1 // SPDX-License-Identifier: GPL-2.0-only
2 #include <linux/module.h>
3 
4 #include <net/sock.h>
5 #include <linux/netlink.h>
6 #include <linux/sock_diag.h>
7 #include <linux/netlink_diag.h>
8 #include <linux/rhashtable.h>
9 
10 #include "af_netlink.h"
11 
sk_diag_dump_groups(struct sock * sk,struct sk_buff * nlskb)12 static int sk_diag_dump_groups(struct sock *sk, struct sk_buff *nlskb)
13 {
14 	struct netlink_sock *nlk = nlk_sk(sk);
15 
16 	if (nlk->groups == NULL)
17 		return 0;
18 
19 	return nla_put(nlskb, NETLINK_DIAG_GROUPS, NLGRPSZ(nlk->ngroups),
20 		       nlk->groups);
21 }
22 
sk_diag_put_flags(struct sock * sk,struct sk_buff * skb)23 static int sk_diag_put_flags(struct sock *sk, struct sk_buff *skb)
24 {
25 	struct netlink_sock *nlk = nlk_sk(sk);
26 	u32 flags = 0;
27 
28 	if (nlk->cb_running)
29 		flags |= NDIAG_FLAG_CB_RUNNING;
30 	if (nlk->flags & NETLINK_F_RECV_PKTINFO)
31 		flags |= NDIAG_FLAG_PKTINFO;
32 	if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR)
33 		flags |= NDIAG_FLAG_BROADCAST_ERROR;
34 	if (nlk->flags & NETLINK_F_RECV_NO_ENOBUFS)
35 		flags |= NDIAG_FLAG_NO_ENOBUFS;
36 	if (nlk->flags & NETLINK_F_LISTEN_ALL_NSID)
37 		flags |= NDIAG_FLAG_LISTEN_ALL_NSID;
38 	if (nlk->flags & NETLINK_F_CAP_ACK)
39 		flags |= NDIAG_FLAG_CAP_ACK;
40 
41 	return nla_put_u32(skb, NETLINK_DIAG_FLAGS, flags);
42 }
43 
sk_diag_fill(struct sock * sk,struct sk_buff * skb,struct netlink_diag_req * req,u32 portid,u32 seq,u32 flags,int sk_ino)44 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
45 			struct netlink_diag_req *req,
46 			u32 portid, u32 seq, u32 flags, int sk_ino)
47 {
48 	struct nlmsghdr *nlh;
49 	struct netlink_diag_msg *rep;
50 	struct netlink_sock *nlk = nlk_sk(sk);
51 
52 	nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep),
53 			flags);
54 	if (!nlh)
55 		return -EMSGSIZE;
56 
57 	rep = nlmsg_data(nlh);
58 	rep->ndiag_family	= AF_NETLINK;
59 	rep->ndiag_type		= sk->sk_type;
60 	rep->ndiag_protocol	= sk->sk_protocol;
61 	rep->ndiag_state	= sk->sk_state;
62 
63 	rep->ndiag_ino		= sk_ino;
64 	rep->ndiag_portid	= nlk->portid;
65 	rep->ndiag_dst_portid	= nlk->dst_portid;
66 	rep->ndiag_dst_group	= nlk->dst_group;
67 	sock_diag_save_cookie(sk, rep->ndiag_cookie);
68 
69 	if ((req->ndiag_show & NDIAG_SHOW_GROUPS) &&
70 	    sk_diag_dump_groups(sk, skb))
71 		goto out_nlmsg_trim;
72 
73 	if ((req->ndiag_show & NDIAG_SHOW_MEMINFO) &&
74 	    sock_diag_put_meminfo(sk, skb, NETLINK_DIAG_MEMINFO))
75 		goto out_nlmsg_trim;
76 
77 	if ((req->ndiag_show & NDIAG_SHOW_FLAGS) &&
78 	    sk_diag_put_flags(sk, skb))
79 		goto out_nlmsg_trim;
80 
81 	nlmsg_end(skb, nlh);
82 	return 0;
83 
84 out_nlmsg_trim:
85 	nlmsg_cancel(skb, nlh);
86 	return -EMSGSIZE;
87 }
88 
__netlink_diag_dump(struct sk_buff * skb,struct netlink_callback * cb,int protocol,int s_num)89 static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
90 				int protocol, int s_num)
91 {
92 	struct rhashtable_iter *hti = (void *)cb->args[2];
93 	struct netlink_table *tbl = &nl_table[protocol];
94 	struct net *net = sock_net(skb->sk);
95 	struct netlink_diag_req *req;
96 	struct netlink_sock *nlsk;
97 	struct sock *sk;
98 	int num = 2;
99 	int ret = 0;
100 
101 	req = nlmsg_data(cb->nlh);
102 
103 	if (s_num > 1)
104 		goto mc_list;
105 
106 	num--;
107 
108 	if (!hti) {
109 		hti = kmalloc(sizeof(*hti), GFP_KERNEL);
110 		if (!hti)
111 			return -ENOMEM;
112 
113 		cb->args[2] = (long)hti;
114 	}
115 
116 	if (!s_num)
117 		rhashtable_walk_enter(&tbl->hash, hti);
118 
119 	rhashtable_walk_start(hti);
120 
121 	while ((nlsk = rhashtable_walk_next(hti))) {
122 		if (IS_ERR(nlsk)) {
123 			ret = PTR_ERR(nlsk);
124 			if (ret == -EAGAIN) {
125 				ret = 0;
126 				continue;
127 			}
128 			break;
129 		}
130 
131 		sk = (struct sock *)nlsk;
132 
133 		if (!net_eq(sock_net(sk), net))
134 			continue;
135 
136 		if (sk_diag_fill(sk, skb, req,
137 				 NETLINK_CB(cb->skb).portid,
138 				 cb->nlh->nlmsg_seq,
139 				 NLM_F_MULTI,
140 				 sock_i_ino(sk)) < 0) {
141 			ret = 1;
142 			break;
143 		}
144 	}
145 
146 	rhashtable_walk_stop(hti);
147 
148 	if (ret)
149 		goto done;
150 
151 	rhashtable_walk_exit(hti);
152 	num++;
153 
154 mc_list:
155 	read_lock(&nl_table_lock);
156 	sk_for_each_bound(sk, &tbl->mc_list) {
157 		if (sk_hashed(sk))
158 			continue;
159 		if (!net_eq(sock_net(sk), net))
160 			continue;
161 		if (num < s_num) {
162 			num++;
163 			continue;
164 		}
165 
166 		if (sk_diag_fill(sk, skb, req,
167 				 NETLINK_CB(cb->skb).portid,
168 				 cb->nlh->nlmsg_seq,
169 				 NLM_F_MULTI,
170 				 sock_i_ino(sk)) < 0) {
171 			ret = 1;
172 			break;
173 		}
174 		num++;
175 	}
176 	read_unlock(&nl_table_lock);
177 
178 done:
179 	cb->args[0] = num;
180 
181 	return ret;
182 }
183 
netlink_diag_dump(struct sk_buff * skb,struct netlink_callback * cb)184 static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
185 {
186 	struct netlink_diag_req *req;
187 	int s_num = cb->args[0];
188 	int err = 0;
189 
190 	req = nlmsg_data(cb->nlh);
191 
192 	if (req->sdiag_protocol == NDIAG_PROTO_ALL) {
193 		int i;
194 
195 		for (i = cb->args[1]; i < MAX_LINKS; i++) {
196 			err = __netlink_diag_dump(skb, cb, i, s_num);
197 			if (err)
198 				break;
199 			s_num = 0;
200 		}
201 		cb->args[1] = i;
202 	} else {
203 		if (req->sdiag_protocol >= MAX_LINKS)
204 			return -ENOENT;
205 
206 		err = __netlink_diag_dump(skb, cb, req->sdiag_protocol, s_num);
207 	}
208 
209 	return err < 0 ? err : skb->len;
210 }
211 
netlink_diag_dump_done(struct netlink_callback * cb)212 static int netlink_diag_dump_done(struct netlink_callback *cb)
213 {
214 	struct rhashtable_iter *hti = (void *)cb->args[2];
215 
216 	if (cb->args[0] == 1)
217 		rhashtable_walk_exit(hti);
218 
219 	kfree(hti);
220 
221 	return 0;
222 }
223 
netlink_diag_handler_dump(struct sk_buff * skb,struct nlmsghdr * h)224 static int netlink_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
225 {
226 	int hdrlen = sizeof(struct netlink_diag_req);
227 	struct net *net = sock_net(skb->sk);
228 
229 	if (nlmsg_len(h) < hdrlen)
230 		return -EINVAL;
231 
232 	if (h->nlmsg_flags & NLM_F_DUMP) {
233 		struct netlink_dump_control c = {
234 			.dump = netlink_diag_dump,
235 			.done = netlink_diag_dump_done,
236 		};
237 		return netlink_dump_start(net->diag_nlsk, skb, h, &c);
238 	} else
239 		return -EOPNOTSUPP;
240 }
241 
242 static const struct sock_diag_handler netlink_diag_handler = {
243 	.family = AF_NETLINK,
244 	.dump = netlink_diag_handler_dump,
245 };
246 
netlink_diag_init(void)247 static int __init netlink_diag_init(void)
248 {
249 	return sock_diag_register(&netlink_diag_handler);
250 }
251 
netlink_diag_exit(void)252 static void __exit netlink_diag_exit(void)
253 {
254 	sock_diag_unregister(&netlink_diag_handler);
255 }
256 
257 module_init(netlink_diag_init);
258 module_exit(netlink_diag_exit);
259 MODULE_LICENSE("GPL");
260 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 16 /* AF_NETLINK */);
261