1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (c) 2016, Linaro Limited
4  */
5 
6 #include <sys/types.h>
7 #include <stdbool.h>
8 #include <arpa/inet.h>
9 #include <err.h>
10 #include <string.h>
11 #include <stdlib.h>
12 #include <errno.h>
13 #include <netdb.h>
14 #include <netinet/in.h>
15 #include <poll.h>
16 #include <sys/socket.h>
17 #include <unistd.h>
18 
19 #include "sock_server.h"
20 
/* State of one sock_server() invocation */
struct server_state {
	/* Per-slot handler and bind entry, indexed in parallel with @fds */
	struct sock_state *socks;
	/* Array handed to poll(); a slot with fd == -1 is free for reuse */
	struct pollfd *fds;
	/* Number of slots in both @socks and @fds */
	nfds_t nfds;
	/* Set by server_quit_cb() to make sock_server() leave its loop */
	bool got_quit;
	/* I/O callbacks, copied from the struct sock_server being served */
	struct sock_io_cb *cb;
};
28 
/* NOTE(review): not referenced anywhere in the visible part of this file */
#define SOCK_BUF_SIZE	512
30 
/* Bookkeeping for one slot of server_state.fds */
struct sock_state {
	/*
	 * Handler called by sock_server() when poll() reports events on the
	 * slot; returning false makes sock_server() stop with an error.
	 */
	bool (*cb)(struct server_state *srvst, size_t idx);
	/* Bind entry the fd belongs to, NULL for the quit fd */
	struct sock_server_bind *serv;
};
35 
server_io_cb(struct server_state * srvst,size_t idx)36 static bool server_io_cb(struct server_state *srvst, size_t idx)
37 {
38 	short revents = srvst->fds[idx].revents;
39 	short *events = &srvst->fds[idx].events;
40 	struct sock_io_cb *cb = srvst->cb;
41 	int fd = 0;
42 
43 	fd = srvst->fds[idx].fd;
44 	if (revents & POLLIN) {
45 		if (!cb->read)
46 			*events &= ~POLLIN;
47 		else if (!cb->read(cb->ptr, fd, events))
48 			goto close;
49 	}
50 
51 	if (revents & POLLOUT) {
52 		if (!cb->write)
53 			*events &= ~POLLOUT;
54 		else if (!cb->write(cb->ptr, fd, events))
55 			goto close;
56 	}
57 
58 	if (!(revents & ~(POLLIN | POLLOUT)))
59 		return true;
60 close:
61 	if (close(fd)) {
62 		warn("server_io_cb: close(%d)", fd);
63 		return false;
64 	}
65 	srvst->fds[idx].fd = -1;
66 	return true;
67 }
68 
server_add_state(struct server_state * srvst,bool (* cb)(struct server_state * srvst,size_t idx),struct sock_server_bind * serv,int fd,short poll_events)69 static bool server_add_state(struct server_state *srvst,
70 			     bool (*cb)(struct server_state *srvst, size_t idx),
71 			     struct sock_server_bind *serv, int fd,
72 			     short poll_events)
73 {
74 	void *p = NULL;
75 	size_t n = 0;
76 
77 	for (n = 0; n < srvst->nfds; n++) {
78 		if (srvst->fds[n].fd == -1) {
79 			srvst->socks[n].cb = cb;
80 			srvst->socks[n].serv = serv;
81 			srvst->fds[n].fd = fd;
82 			srvst->fds[n].events = poll_events;
83 			srvst->fds[n].revents = 0;
84 			return true;
85 		}
86 	}
87 
88 	p = realloc(srvst->socks, sizeof(*srvst->socks) * (srvst->nfds + 1));
89 	if (!p)
90 		return false;
91 	srvst->socks = p;
92 	srvst->socks[srvst->nfds].cb = cb;
93 	srvst->socks[srvst->nfds].serv = serv;
94 
95 	p = realloc(srvst->fds, sizeof(*srvst->fds) * (srvst->nfds + 1));
96 	if (!p)
97 		return false;
98 	srvst->fds = p;
99 	srvst->fds[srvst->nfds].fd = fd;
100 	srvst->fds[srvst->nfds].events = poll_events;
101 	srvst->fds[srvst->nfds].revents = 0;
102 
103 	srvst->nfds++;
104 	return true;
105 }
106 
tcp_server_accept_cb(struct server_state * srvst,size_t idx)107 static bool tcp_server_accept_cb(struct server_state *srvst, size_t idx)
108 {
109 	short revents = srvst->fds[idx].revents;
110 	struct sockaddr_storage sass = { };
111 	struct sockaddr *sa = (struct sockaddr *)&sass;
112 	socklen_t len = sizeof(sass);
113 	int fd = 0;
114 	short io_events = POLLIN | POLLOUT;
115 
116 	if (!(revents & POLLIN))
117 		return false;
118 
119 	fd = accept(srvst->fds[idx].fd, sa, &len);
120 	if (fd == -1) {
121 		if (errno == EAGAIN || errno == EWOULDBLOCK ||
122 		    errno == ECONNABORTED)
123 			return true;
124 		return false;
125 	}
126 
127 	if (srvst->cb->accept &&
128 	    !srvst->cb->accept(srvst->cb->ptr, fd, &io_events)) {
129 		if (close(fd))
130 			warn("server_accept_cb: close(%d)", fd);
131 		return true;
132 	}
133 
134 	return server_add_state(srvst, server_io_cb, srvst->socks[idx].serv,
135 				fd, io_events);
136 }
137 
udp_server_cb(struct server_state * srvst,size_t idx)138 static bool udp_server_cb(struct server_state *srvst, size_t idx)
139 {
140 	short revents = srvst->fds[idx].revents;
141 
142 	if (!(revents & POLLIN))
143 		return false;
144 
145 	return srvst->cb->accept(srvst->cb->ptr, srvst->fds[idx].fd, NULL);
146 }
147 
server_quit_cb(struct server_state * srvst,size_t idx)148 static bool server_quit_cb(struct server_state *srvst, size_t idx)
149 {
150 	(void)idx;
151 	srvst->got_quit = true;
152 	return true;
153 }
154 
/*
 * Main loop of the server thread.
 *
 * Registers every bound socket in @ts with @cb as poll handler, plus
 * ts->quit_fd with server_quit_cb(), and then dispatches poll() events to
 * the registered handlers until server_quit_cb() has run or something
 * fails. On failure ts->error is set.
 *
 * The lock of @ts is held at all times except while sleeping in poll().
 */
static void sock_server(struct sock_server *ts,
			bool (*cb)(struct server_state *srvst, size_t idx))
{
	struct server_state srvst = { .cb = ts->cb };
	int pres = 0;
	size_t n = 0;
	char b = 0;

	sock_server_lock(ts);

	for (n = 0; n < ts->num_binds; n++) {
		if (!server_add_state(&srvst, cb, ts->bind + n,
				      ts->bind[n].fd, POLLIN))
			goto bad;
	}

	/* serv == NULL marks the quit fd, see the cleanup loop below */
	if (!server_add_state(&srvst, server_quit_cb, NULL,
			      ts->quit_fd, POLLIN))
		goto bad;

	while (true) {
		sock_server_unlock(ts);
		/*
		 * First sleep 5 ms to make it easier to test send timeouts
		 * due to this rate limit.
		 */
		poll(NULL, 0, 5);
		pres = poll(srvst.fds, srvst.nfds, -1);
		sock_server_lock(ts);
		if (pres < 0)
			goto bad;

		/* pres counts the slots with events, stop once all are handled */
		for (n = 0; pres && n < srvst.nfds; n++) {
			if (srvst.fds[n].revents) {
				pres--;
				if (!srvst.socks[n].cb(&srvst, n))
					goto bad;
			}
		}

		if (srvst.got_quit)
			goto out;
	}

bad:
	ts->error = true;
out:
	/*
	 * Only fds that belong to a bind entry but differ from the bind's
	 * own fd are closed here. The bound sockets and the quit fd are
	 * closed by sock_server_uninit().
	 */
	for (n = 0; n < srvst.nfds; n++) {
		/* Don't close accept and quit fds */
		if (srvst.fds[n].fd != -1 && srvst.socks[n].serv &&
		    srvst.fds[n].fd != srvst.socks[n].serv->fd) {
			if (close(srvst.fds[n].fd))
				warn("sock_server: close(%d)", srvst.fds[n].fd);
		}
	}
	free(srvst.socks);
	free(srvst.fds);
	/*
	 * NOTE(review): anything but exactly one byte read from the quit fd
	 * is flagged as an error. No writer to the other end of the pipe is
	 * visible in this file - confirm who is expected to supply the byte.
	 */
	if (read(ts->quit_fd, &b, 1) != 1)
		ts->error = true;

	sock_server_unlock(ts);
}
217 
sock_server_stream(void * arg)218 static void *sock_server_stream(void *arg)
219 {
220 	sock_server(arg, tcp_server_accept_cb);
221 	return NULL;
222 }
223 
sock_server_dgram(void * arg)224 static void *sock_server_dgram(void *arg)
225 {
226 	sock_server(arg, udp_server_cb);
227 	return NULL;
228 }
229 
sock_server_add_fd(struct sock_server * ts,struct addrinfo * ai)230 static void sock_server_add_fd(struct sock_server *ts, struct addrinfo *ai)
231 {
232 	struct sock_server_bind serv = { };
233 	struct sockaddr_storage sass = { };
234 	struct sockaddr *sa = (struct sockaddr *)&sass;
235 	struct sockaddr_in *sain = (struct sockaddr_in *)&sass;
236 	struct sockaddr_in6 *sain6 = (struct sockaddr_in6 *)&sass;
237 	void *src = NULL;
238 	socklen_t len = sizeof(sass);
239 	struct sock_server_bind *p = NULL;
240 
241 	serv.fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
242 	if (serv.fd < 0)
243 		return;
244 
245 	if (bind(serv.fd, ai->ai_addr, ai->ai_addrlen))
246 		goto bad;
247 
248 	if (ai->ai_socktype == SOCK_STREAM && listen(serv.fd, 5))
249 		goto bad;
250 
251 	if (getsockname(serv.fd, sa, &len))
252 		goto bad;
253 
254 	switch (sa->sa_family) {
255 	case AF_INET:
256 		src = &sain->sin_addr;
257 		serv.port = ntohs(sain->sin_port);
258 		break;
259 	case AF_INET6:
260 		src = &sain6->sin6_addr;
261 		serv.port = ntohs(sain6->sin6_port);
262 	default:
263 		goto bad;
264 	}
265 
266 	if (!inet_ntop(sa->sa_family, src, serv.host, sizeof(serv.host)))
267 		goto bad;
268 
269 	p = realloc(ts->bind, sizeof(*p) * (ts->num_binds + 1));
270 	if (!p)
271 		goto bad;
272 
273 	ts->bind = p;
274 	p[ts->num_binds] = serv;
275 	ts->num_binds++;
276 	return;
277 bad:
278 	if (close(serv.fd))
279 		warn("sock_server_add_fd: close(%d)", serv.fd);
280 }
281 
sock_server_uninit(struct sock_server * ts)282 void sock_server_uninit(struct sock_server *ts)
283 {
284 	size_t n = 0;
285 	int e = 0;
286 
287 	if (ts->stop_fd != -1) {
288 		if (close(ts->stop_fd))
289 			warn("sock_server_uninit: close(%d)", ts->stop_fd);
290 		ts->stop_fd = -1;
291 		e = pthread_join(ts->thr, NULL);
292 		if (e)
293 			warnx("sock_server_uninit: pthread_join: %s",
294 			      strerror(e));
295 	}
296 
297 	e = pthread_mutex_destroy(&ts->mu);
298 	if (e)
299 		warnx("sock_server_uninit: pthread_mutex_destroy: %s",
300 		      strerror(e));
301 
302 	for (n = 0; n < ts->num_binds; n++)
303 		if (close(ts->bind[n].fd))
304 			warn("sock_server_uninit: close(%d)", ts->bind[n].fd);
305 	free(ts->bind);
306 	if (ts->quit_fd != -1 && close(ts->quit_fd))
307 		warn("sock_server_uninit: close(%d)", ts->quit_fd);
308 	memset(ts, 0, sizeof(*ts));
309 	ts->quit_fd = -1;
310 	ts->stop_fd = -1;
311 }
312 
/*
 * Initialize @ts and start a server thread for it.
 *
 * One socket of type @socktype is bound, on a system-assigned port, for
 * each wildcard address returned by getaddrinfo(). A pipe is created whose
 * read end becomes ts->quit_fd and whose write end becomes ts->stop_fd,
 * and a thread running sock_server_stream() (for SOCK_STREAM) or
 * sock_server_dgram() (otherwise) is started.
 *
 * @cb holds the I/O callbacks used by the server thread.
 *
 * Returns true on success. On failure false is returned and, once the
 * mutex has been initialized, everything acquired so far is released with
 * sock_server_uninit().
 */
static bool sock_server_init(struct sock_server *ts, struct sock_io_cb *cb,
			     int socktype)
{
	struct addrinfo hints = { };
	struct addrinfo *ai = NULL;
	struct addrinfo *ai0 = NULL;
	int fd_pair[2] = { };
	int e = 0;

	memset(ts, 0, sizeof(*ts));
	ts->quit_fd = -1;
	ts->stop_fd = -1;
	ts->cb = cb;

	e = pthread_mutex_init(&ts->mu, NULL);
	if (e) {
		warnx("sock_server_init: pthread_mutex_init: %s", strerror(e));
		return false;
	}

	hints.ai_flags = AI_PASSIVE;
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = socktype;

	/* Service "0" lets the system pick a free port for each socket */
	e = getaddrinfo(NULL, "0", &hints, &ai0);
	if (e) {
		warnx("sock_server_init: getaddrinfo: %s", gai_strerror(e));
		/* Destroys the mutex initialized above */
		sock_server_uninit(ts);
		return false;
	}

	/* Addresses that can't be bound are skipped, see the check below */
	for (ai = ai0; ai; ai = ai->ai_next)
		sock_server_add_fd(ts, ai);

	freeaddrinfo(ai0);

	if (!ts->num_binds) {
		/* Destroys the mutex initialized above */
		sock_server_uninit(ts);
		return false;
	}

	if (pipe(fd_pair)) {
		sock_server_uninit(ts);
		return false;
	}

	ts->quit_fd = fd_pair[0];

	if (socktype == SOCK_STREAM)
		e = pthread_create(&ts->thr, NULL, sock_server_stream, ts);
	else
		e = pthread_create(&ts->thr, NULL, sock_server_dgram, ts);
	if (e) {
		warnx("sock_server_init: pthread_create: %s", strerror(e));
		if (close(fd_pair[1]))
			warn("sock_server_init: close(%d)", fd_pair[1]);
		sock_server_uninit(ts);
		return false;
	}

	/*
	 * Only set once the thread exists: sock_server_uninit() joins the
	 * thread whenever stop_fd != -1.
	 */
	ts->stop_fd = fd_pair[1];
	return true;
}
370 
/*
 * Initialize @sock_serv as a TCP (SOCK_STREAM) server using the callbacks
 * in @cb. Returns the result of sock_server_init().
 */
bool sock_server_init_tcp(struct sock_server *sock_serv, struct sock_io_cb *cb)
{
	const int socktype = SOCK_STREAM;

	return sock_server_init(sock_serv, cb, socktype);
}
375 
/*
 * Initialize @sock_serv as a UDP (SOCK_DGRAM) server using the callbacks
 * in @cb. Returns the result of sock_server_init().
 */
bool sock_server_init_udp(struct sock_server *sock_serv, struct sock_io_cb *cb)
{
	const int socktype = SOCK_DGRAM;

	return sock_server_init(sock_serv, cb, socktype);
}
380 
sock_server_lock(struct sock_server * ts)381 void sock_server_lock(struct sock_server *ts)
382 {
383 	int e = pthread_mutex_lock(&ts->mu);
384 
385 	if (e)
386 		errx(1, "sock_server_lock: pthread_mutex_lock: %s", strerror(e));
387 }
388 
sock_server_unlock(struct sock_server * ts)389 void sock_server_unlock(struct sock_server *ts)
390 {
391 	int e = pthread_mutex_unlock(&ts->mu);
392 
393 	if (e)
394 		errx(1, "sock_server_unlock: pthread_mutex_unlock: %s",
395 		     strerror(e));
396 }
397