xref: /xnu-12377.81.4/bsd/kern/kpi_socket.c (revision 043036a2b3718f7f0be807e2870f8f47d3fa0796)
1 /*
2  * Copyright (c) 2003-2021 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #define __KPI__
30 #include <sys/systm.h>
31 #include <sys/kernel.h>
32 #include <sys/types.h>
33 #include <sys/socket.h>
34 #include <sys/socketvar.h>
35 #include <sys/param.h>
36 #include <sys/proc.h>
37 #include <sys/errno.h>
38 #include <sys/malloc.h>
39 #include <sys/protosw.h>
40 #include <sys/domain.h>
41 #include <sys/mbuf.h>
42 #include <sys/mcache.h>
43 #include <sys/fcntl.h>
44 #include <sys/filio.h>
45 #include <sys/uio_internal.h>
46 #include <kern/locks.h>
47 #include <net/net_api_stats.h>
48 #include <netinet/in.h>
49 #include <libkern/OSAtomic.h>
50 #include <stdbool.h>
51 #include <net/sockaddr_utils.h>
52 
53 #if SKYWALK
54 #include <skywalk/core/skywalk_var.h>
55 #endif /* SKYWALK */
56 
57 #define SOCK_SEND_MBUF_MODE_VERBOSE     0x0001
58 
59 static errno_t sock_send_internal(socket_t, const struct msghdr *,
60     mbuf_t, int, size_t *);
61 
62 #undef sock_accept
63 #undef sock_socket
64 errno_t sock_accept(socket_t so, struct sockaddr *__sized_by(fromlen) from, int fromlen,
65     int flags, sock_upcall callback, void *cookie, socket_t *new_so);
66 errno_t sock_socket(int domain, int type, int protocol, sock_upcall callback,
67     void *context, socket_t *new_so);
68 
69 static errno_t sock_accept_common(socket_t sock, struct sockaddr *__sized_by(fromlen) from,
70     int fromlen, int flags, sock_upcall callback, void *cookie,
71     socket_t *new_sock, bool is_internal);
72 static errno_t sock_socket_common(int domain, int type, int protocol,
73     sock_upcall callback, void *context, socket_t *new_so, bool is_internal);
74 
75 errno_t
sock_accept_common(socket_t sock,struct sockaddr * __sized_by (fromlen)from,int fromlen,int flags,sock_upcall callback,void * cookie,socket_t * new_sock,bool is_internal)76 sock_accept_common(socket_t sock, struct sockaddr *__sized_by(fromlen) from, int fromlen, int flags,
77     sock_upcall callback, void *cookie, socket_t *new_sock, bool is_internal)
78 {
79 	struct sockaddr *__single sa;
80 	struct socket *new_so;
81 	lck_mtx_t *mutex_held;
82 	int dosocklock;
83 	errno_t error = 0;
84 
85 	if (sock == NULL || new_sock == NULL) {
86 		return EINVAL;
87 	}
88 
89 	socket_lock(sock, 1);
90 	if ((sock->so_options & SO_ACCEPTCONN) == 0) {
91 		socket_unlock(sock, 1);
92 		return EINVAL;
93 	}
94 	if ((flags & ~(MSG_DONTWAIT)) != 0) {
95 		socket_unlock(sock, 1);
96 		return ENOTSUP;
97 	}
98 check_again:
99 	if (((flags & MSG_DONTWAIT) != 0 || (sock->so_state & SS_NBIO) != 0) &&
100 	    sock->so_comp.tqh_first == NULL) {
101 		socket_unlock(sock, 1);
102 		return EWOULDBLOCK;
103 	}
104 
105 	if (sock->so_proto->pr_getlock != NULL) {
106 		mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
107 		dosocklock = 1;
108 	} else {
109 		mutex_held = sock->so_proto->pr_domain->dom_mtx;
110 		dosocklock = 0;
111 	}
112 
113 	while (TAILQ_EMPTY(&sock->so_comp) && sock->so_error == 0) {
114 		if (sock->so_state & SS_CANTRCVMORE) {
115 			sock->so_error = ECONNABORTED;
116 			break;
117 		}
118 		error = msleep((caddr_t)&sock->so_timeo, mutex_held,
119 		    PSOCK | PCATCH, "sock_accept", NULL);
120 		if (error != 0) {
121 			socket_unlock(sock, 1);
122 			return error;
123 		}
124 	}
125 	if (sock->so_error != 0) {
126 		error = sock->so_error;
127 		sock->so_error = 0;
128 		socket_unlock(sock, 1);
129 		return error;
130 	}
131 
132 	so_acquire_accept_list(sock, NULL);
133 	if (TAILQ_EMPTY(&sock->so_comp)) {
134 		so_release_accept_list(sock);
135 		goto check_again;
136 	}
137 	new_so = TAILQ_FIRST(&sock->so_comp);
138 	TAILQ_REMOVE(&sock->so_comp, new_so, so_list);
139 	new_so->so_state &= ~SS_COMP;
140 	new_so->so_head = NULL;
141 	sock->so_qlen--;
142 
143 	so_release_accept_list(sock);
144 
145 	/*
146 	 * Count the accepted socket as an in-kernel socket
147 	 */
148 	new_so->so_flags1 |= SOF1_IN_KERNEL_SOCKET;
149 	INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_total);
150 	if (is_internal) {
151 		INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_os_total);
152 	}
153 
154 	/*
155 	 * Pass the pre-accepted socket to any interested socket filter(s).
156 	 * Upon failure, the socket would have been closed by the callee.
157 	 */
158 	if (new_so->so_filt != NULL) {
159 		/*
160 		 * Temporarily drop the listening socket's lock before we
161 		 * hand off control over to the socket filter(s), but keep
162 		 * a reference so that it won't go away.  We'll grab it
163 		 * again once we're done with the filter(s).
164 		 */
165 		socket_unlock(sock, 0);
166 		if ((error = soacceptfilter(new_so, sock)) != 0) {
167 			/* Drop reference on listening socket */
168 			sodereference(sock);
169 			return error;
170 		}
171 		socket_lock(sock, 0);
172 	}
173 
174 	if (dosocklock) {
175 		LCK_MTX_ASSERT(new_so->so_proto->pr_getlock(new_so, 0),
176 		    LCK_MTX_ASSERT_NOTOWNED);
177 		socket_lock(new_so, 1);
178 	}
179 
180 	(void) soacceptlock(new_so, &sa, 0);
181 
182 	socket_unlock(sock, 1); /* release the head */
183 
184 	/* see comments in sock_setupcall() */
185 	if (callback != NULL) {
186 #if defined(__arm64__)
187 		sock_setupcalls_locked(new_so, callback, cookie, callback, cookie, 0);
188 #else /* defined(__arm64__) */
189 		sock_setupcalls_locked(new_so, callback, cookie, NULL, NULL, 0);
190 #endif /* defined(__arm64__) */
191 	}
192 
193 	if (sa != NULL && from != NULL) {
194 		SOCKADDR_COPY(sa, from, MIN(fromlen, sa->sa_len));
195 	}
196 	free_sockaddr(sa);
197 
198 	/*
199 	 * If the socket has been marked as inactive by sosetdefunct(),
200 	 * disallow further operations on it.
201 	 */
202 	if (new_so->so_flags & SOF_DEFUNCT) {
203 		(void) sodefunct(current_proc(), new_so,
204 		    SHUTDOWN_SOCKET_LEVEL_DISCONNECT_INTERNAL);
205 	}
206 	*new_sock = new_so;
207 	if (dosocklock) {
208 		socket_unlock(new_so, 1);
209 	}
210 	return error;
211 }
212 
213 errno_t
sock_accept(socket_t sock,struct sockaddr * __sized_by (fromlen)from,int fromlen,int flags,sock_upcall callback,void * cookie,socket_t * new_sock)214 sock_accept(socket_t sock, struct sockaddr *__sized_by(fromlen) from, int fromlen, int flags,
215     sock_upcall callback, void *cookie, socket_t *new_sock)
216 {
217 	return sock_accept_common(sock, from, fromlen, flags,
218 	           callback, cookie, new_sock, false);
219 }
220 
221 errno_t
sock_accept_internal(socket_t sock,struct sockaddr * __sized_by (fromlen)from,int fromlen,int flags,sock_upcall callback,void * cookie,socket_t * new_sock)222 sock_accept_internal(socket_t sock, struct sockaddr *__sized_by(fromlen) from, int fromlen, int flags,
223     sock_upcall callback, void *cookie, socket_t *new_sock)
224 {
225 	return sock_accept_common(sock, from, fromlen, flags,
226 	           callback, cookie, new_sock, true);
227 }
228 
229 errno_t
sock_bind(socket_t sock,const struct sockaddr * to)230 sock_bind(socket_t sock, const struct sockaddr *to)
231 {
232 	int error = 0;
233 	struct sockaddr *sa = NULL;
234 	struct sockaddr_storage ss;
235 
236 	if (sock == NULL || to == NULL) {
237 		return EINVAL;
238 	}
239 
240 	if (to->sa_len > sizeof(ss)) {
241 		sa = kalloc_data(to->sa_len, Z_WAITOK | Z_ZERO | Z_NOFAIL);
242 	} else {
243 		sa = (struct sockaddr *)&ss;
244 	}
245 	SOCKADDR_COPY(to, sa, to->sa_len);
246 
247 	error = sobindlock(sock, sa, 1);        /* will lock socket */
248 
249 	if (sa != (struct sockaddr *)&ss) {
250 		kfree_data(sa, sa->sa_len);
251 	}
252 
253 	return error;
254 }
255 
256 errno_t
sock_connect(socket_t sock,const struct sockaddr * to,int flags)257 sock_connect(socket_t sock, const struct sockaddr *to, int flags)
258 {
259 	int error = 0;
260 	lck_mtx_t *mutex_held;
261 	struct sockaddr *sa = NULL;
262 	struct sockaddr_storage ss;
263 
264 	if (sock == NULL || to == NULL) {
265 		return EINVAL;
266 	}
267 
268 	if (to->sa_len > sizeof(ss)) {
269 		sa = kalloc_data(to->sa_len,
270 		    (flags & MSG_DONTWAIT) ? Z_NOWAIT : Z_WAITOK);
271 		if (sa == NULL) {
272 			return ENOBUFS;
273 		}
274 	} else {
275 		sa = (struct sockaddr *)&ss;
276 	}
277 	SOCKADDR_COPY(to, sa, to->sa_len);
278 
279 	socket_lock(sock, 1);
280 
281 	if ((sock->so_state & SS_ISCONNECTING) &&
282 	    ((sock->so_state & SS_NBIO) != 0 || (flags & MSG_DONTWAIT) != 0)) {
283 		error = EALREADY;
284 		goto out;
285 	}
286 
287 #if SKYWALK
288 	sk_protect_t protect = sk_async_transmit_protect();
289 #endif /* SKYWALK */
290 
291 	error = soconnectlock(sock, sa, 0);
292 
293 #if SKYWALK
294 	sk_async_transmit_unprotect(protect);
295 #endif /* SKYWALK */
296 
297 	if (!error) {
298 		if ((sock->so_state & SS_ISCONNECTING) &&
299 		    ((sock->so_state & SS_NBIO) != 0 ||
300 		    (flags & MSG_DONTWAIT) != 0)) {
301 			error = EINPROGRESS;
302 			goto out;
303 		}
304 
305 		if (sock->so_proto->pr_getlock != NULL) {
306 			mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
307 		} else {
308 			mutex_held = sock->so_proto->pr_domain->dom_mtx;
309 		}
310 
311 		while ((sock->so_state & SS_ISCONNECTING) &&
312 		    sock->so_error == 0) {
313 			error = msleep((caddr_t)&sock->so_timeo,
314 			    mutex_held, PSOCK | PCATCH, "sock_connect", NULL);
315 			if (error != 0) {
316 				break;
317 			}
318 		}
319 
320 		if (error == 0) {
321 			error = sock->so_error;
322 			sock->so_error = 0;
323 		}
324 	} else {
325 		sock->so_state &= ~SS_ISCONNECTING;
326 	}
327 out:
328 	socket_unlock(sock, 1);
329 
330 	if (sa != (struct sockaddr *)&ss) {
331 		kfree_data(sa, sa->sa_len);
332 	}
333 
334 	return error;
335 }
336 
337 errno_t
sock_connectwait(socket_t sock,const struct timeval * tv)338 sock_connectwait(socket_t sock, const struct timeval *tv)
339 {
340 	lck_mtx_t *mutex_held;
341 	errno_t retval = 0;
342 	struct timespec ts;
343 
344 	socket_lock(sock, 1);
345 
346 	/* Check if we're already connected or if we've already errored out */
347 	if ((sock->so_state & SS_ISCONNECTING) == 0 || sock->so_error != 0) {
348 		if (sock->so_error != 0) {
349 			retval = sock->so_error;
350 			sock->so_error = 0;
351 		} else {
352 			if ((sock->so_state & SS_ISCONNECTED) != 0) {
353 				retval = 0;
354 			} else {
355 				retval = EINVAL;
356 			}
357 		}
358 		goto done;
359 	}
360 
361 	/* copied translation from timeval to hertz from SO_RCVTIMEO handling */
362 	if (tv->tv_sec < 0 || tv->tv_sec > SHRT_MAX / hz ||
363 	    tv->tv_usec < 0 || tv->tv_usec >= 1000000) {
364 		retval = EDOM;
365 		goto done;
366 	}
367 
368 	ts.tv_sec = tv->tv_sec;
369 	ts.tv_nsec = (tv->tv_usec * (integer_t)NSEC_PER_USEC);
370 	if ((ts.tv_sec + (ts.tv_nsec / (long)NSEC_PER_SEC)) / 100 > SHRT_MAX) {
371 		retval = EDOM;
372 		goto done;
373 	}
374 
375 	if (sock->so_proto->pr_getlock != NULL) {
376 		mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
377 	} else {
378 		mutex_held = sock->so_proto->pr_domain->dom_mtx;
379 	}
380 
381 	msleep((caddr_t)&sock->so_timeo, mutex_held,
382 	    PSOCK, "sock_connectwait", &ts);
383 
384 	/* Check if we're still waiting to connect */
385 	if ((sock->so_state & SS_ISCONNECTING) && sock->so_error == 0) {
386 		retval = EINPROGRESS;
387 		goto done;
388 	}
389 
390 	if (sock->so_error != 0) {
391 		retval = sock->so_error;
392 		sock->so_error = 0;
393 	}
394 
395 done:
396 	socket_unlock(sock, 1);
397 	return retval;
398 }
399 
400 errno_t
sock_nointerrupt(socket_t sock,int on)401 sock_nointerrupt(socket_t sock, int on)
402 {
403 	socket_lock(sock, 1);
404 
405 	if (on) {
406 		sock->so_rcv.sb_flags |= SB_NOINTR;     /* This isn't safe */
407 		sock->so_snd.sb_flags |= SB_NOINTR;     /* This isn't safe */
408 	} else {
409 		sock->so_rcv.sb_flags &= ~SB_NOINTR;    /* This isn't safe */
410 		sock->so_snd.sb_flags &= ~SB_NOINTR;    /* This isn't safe */
411 	}
412 
413 	socket_unlock(sock, 1);
414 
415 	return 0;
416 }
417 
418 errno_t
sock_getpeername(socket_t sock,struct sockaddr * __sized_by (peernamelen)peername,int peernamelen)419 sock_getpeername(socket_t sock, struct sockaddr *__sized_by(peernamelen) peername, int peernamelen)
420 {
421 	int error;
422 	struct sockaddr *__single sa = NULL;
423 
424 	if (sock == NULL || peername == NULL || peernamelen < 0) {
425 		return EINVAL;
426 	}
427 
428 	socket_lock(sock, 1);
429 	if (!(sock->so_state & (SS_ISCONNECTED | SS_ISCONFIRMING))) {
430 		socket_unlock(sock, 1);
431 		return ENOTCONN;
432 	}
433 	error = sogetaddr_locked(sock, &sa, 1);
434 	socket_unlock(sock, 1);
435 	if (error == 0) {
436 		SOCKADDR_COPY(sa, peername, MIN(peernamelen, sa->sa_len));
437 		free_sockaddr(sa);
438 	}
439 	return error;
440 }
441 
442 errno_t
sock_getsockname(socket_t sock,struct sockaddr * __sized_by (socknamelen)sockname,int socknamelen)443 sock_getsockname(socket_t sock, struct sockaddr *__sized_by(socknamelen) sockname, int socknamelen)
444 {
445 	int error;
446 	struct sockaddr *__single sa = NULL;
447 
448 	if (sock == NULL || sockname == NULL || socknamelen < 0) {
449 		return EINVAL;
450 	}
451 
452 	socket_lock(sock, 1);
453 	error = sogetaddr_locked(sock, &sa, 0);
454 	socket_unlock(sock, 1);
455 	if (error == 0) {
456 		SOCKADDR_COPY(sa, sockname, MIN(socknamelen, sa->sa_len));
457 		free_sockaddr(sa);
458 	}
459 	return error;
460 }
461 
462 __private_extern__ int
sogetaddr_locked(struct socket * so,struct sockaddr ** psa,int peer)463 sogetaddr_locked(struct socket *so, struct sockaddr **psa, int peer)
464 {
465 	int error;
466 
467 	if (so == NULL || psa == NULL) {
468 		return EINVAL;
469 	}
470 
471 	*psa = NULL;
472 	error = peer ? so->so_proto->pr_usrreqs->pru_peeraddr(so, psa) :
473 	    so->so_proto->pr_usrreqs->pru_sockaddr(so, psa);
474 
475 	if (error == 0 && *psa == NULL) {
476 		error = ENOMEM;
477 	} else if (error != 0) {
478 		free_sockaddr(*psa);
479 	}
480 	return error;
481 }
482 
483 errno_t
sock_getaddr(socket_t sock,struct sockaddr ** psa,int peer)484 sock_getaddr(socket_t sock, struct sockaddr **psa, int peer)
485 {
486 	int error;
487 
488 	if (sock == NULL || psa == NULL) {
489 		return EINVAL;
490 	}
491 
492 	socket_lock(sock, 1);
493 	error = sogetaddr_locked(sock, psa, peer);
494 	socket_unlock(sock, 1);
495 
496 	return error;
497 }
498 
/*
 * Release a socket address by passing it to free_sockaddr().
 */
void
sock_freeaddr(struct sockaddr *sa)
{
	free_sockaddr(sa);
	return;
}
504 
505 errno_t
sock_getsockopt(socket_t sock,int level,int optname,void * optval,int * optlen)506 sock_getsockopt(socket_t sock, int level, int optname, void *optval,
507     int *optlen)
508 {
509 	int error = 0;
510 	struct sockopt  sopt;
511 
512 	if (sock == NULL || optval == NULL || optlen == NULL) {
513 		return EINVAL;
514 	}
515 
516 	sopt.sopt_dir = SOPT_GET;
517 	sopt.sopt_level = level;
518 	sopt.sopt_name = optname;
519 	sopt.sopt_val = CAST_USER_ADDR_T(optval);
520 	sopt.sopt_valsize = *optlen;
521 	sopt.sopt_p = kernproc;
522 	error = sogetoptlock(sock, &sopt, 1);   /* will lock socket */
523 	if (error == 0) {
524 		*optlen = (uint32_t)sopt.sopt_valsize;
525 	}
526 	return error;
527 }
528 
529 errno_t
sock_ioctl(socket_t sock,unsigned long request,void * __sized_by (IOCPARM_LEN (request))argp)530 sock_ioctl(socket_t sock, unsigned long request, void *__sized_by(IOCPARM_LEN(request)) argp)
531 {
532 	return soioctl(sock, request, argp, kernproc); /* will lock socket */
533 }
534 
535 errno_t
sock_setsockopt(socket_t sock,int level,int optname,const void * optval,int optlen)536 sock_setsockopt(socket_t sock, int level, int optname, const void *optval,
537     int optlen)
538 {
539 	struct sockopt  sopt;
540 
541 	if (sock == NULL || optval == NULL) {
542 		return EINVAL;
543 	}
544 
545 	sopt.sopt_dir = SOPT_SET;
546 	sopt.sopt_level = level;
547 	sopt.sopt_name = optname;
548 	sopt.sopt_val = CAST_USER_ADDR_T(optval);
549 	sopt.sopt_valsize = optlen;
550 	sopt.sopt_p = kernproc;
551 	return sosetoptlock(sock, &sopt, 1); /* will lock socket */
552 }
553 
554 /*
555  * This follows the recommended mappings between DSCP code points
556  * and WMM access classes.
557  */
558 static uint32_t
so_tc_from_dscp(uint8_t dscp)559 so_tc_from_dscp(uint8_t dscp)
560 {
561 	uint32_t tc;
562 
563 	if (dscp >= 0x30 && dscp <= 0x3f) {
564 		tc = SO_TC_VO;
565 	} else if (dscp >= 0x20 && dscp <= 0x2f) {
566 		tc = SO_TC_VI;
567 	} else if (dscp >= 0x08 && dscp <= 0x17) {
568 		tc = SO_TC_BK_SYS;
569 	} else {
570 		tc = SO_TC_BE;
571 	}
572 
573 	return tc;
574 }
575 
576 errno_t
sock_settclassopt(socket_t sock,const void * optval,size_t optlen)577 sock_settclassopt(socket_t sock, const void *optval, size_t optlen)
578 {
579 	errno_t error = 0;
580 	struct sockopt sopt;
581 	int sotc;
582 
583 	if (sock == NULL || optval == NULL || optlen != sizeof(int)) {
584 		return EINVAL;
585 	}
586 
587 	socket_lock(sock, 1);
588 	if (!(sock->so_state & SS_ISCONNECTED)) {
589 		/*
590 		 * If the socket is not connected then we don't know
591 		 * if the destination is on LAN  or not. Skip
592 		 * setting traffic class in this case
593 		 */
594 		error = ENOTCONN;
595 		goto out;
596 	}
597 
598 	if (sock->so_proto == NULL || sock->so_proto->pr_domain == NULL ||
599 	    sock->so_pcb == NULL) {
600 		error = EINVAL;
601 		goto out;
602 	}
603 
604 	/*
605 	 * Set the socket traffic class based on the passed DSCP code point
606 	 * regardless of the scope of the destination
607 	 */
608 	sotc = so_tc_from_dscp((uint8_t)((*(const int *)optval) >> 2));
609 
610 	sopt.sopt_dir = SOPT_SET;
611 	sopt.sopt_val = CAST_USER_ADDR_T(&sotc);
612 	sopt.sopt_valsize = sizeof(sotc);
613 	sopt.sopt_p = kernproc;
614 	sopt.sopt_level = SOL_SOCKET;
615 	sopt.sopt_name = SO_TRAFFIC_CLASS;
616 
617 	error = sosetoptlock(sock, &sopt, 0);   /* already locked */
618 
619 	if (error != 0) {
620 		printf("%s: sosetopt SO_TRAFFIC_CLASS failed %d\n",
621 		    __func__, error);
622 		goto out;
623 	}
624 
625 	/*
626 	 * Check if the destination address is LAN or link local address.
627 	 * We do not want to set traffic class bits if the destination
628 	 * is not local.
629 	 */
630 	if (!so_isdstlocal(sock)) {
631 		goto out;
632 	}
633 
634 	sopt.sopt_dir = SOPT_SET;
635 	sopt.sopt_val = CAST_USER_ADDR_T(optval);
636 	sopt.sopt_valsize = optlen;
637 	sopt.sopt_p = kernproc;
638 
639 	switch (SOCK_DOM(sock)) {
640 	case PF_INET:
641 		sopt.sopt_level = IPPROTO_IP;
642 		sopt.sopt_name = IP_TOS;
643 		break;
644 	case PF_INET6:
645 		sopt.sopt_level = IPPROTO_IPV6;
646 		sopt.sopt_name = IPV6_TCLASS;
647 		break;
648 	default:
649 		error = EINVAL;
650 		goto out;
651 	}
652 
653 	error = sosetoptlock(sock, &sopt, 0);   /* already locked */
654 	socket_unlock(sock, 1);
655 	return error;
656 out:
657 	socket_unlock(sock, 1);
658 	return error;
659 }
660 
661 errno_t
sock_gettclassopt(socket_t sock,void * optval,size_t * optlen)662 sock_gettclassopt(socket_t sock, void *optval, size_t *optlen)
663 {
664 	errno_t error = 0;
665 	struct sockopt sopt;
666 
667 	if (sock == NULL || optval == NULL || optlen == NULL) {
668 		return EINVAL;
669 	}
670 
671 	sopt.sopt_dir = SOPT_GET;
672 	sopt.sopt_val = CAST_USER_ADDR_T(optval);
673 	sopt.sopt_valsize = *optlen;
674 	sopt.sopt_p = kernproc;
675 
676 	socket_lock(sock, 1);
677 	if (sock->so_proto == NULL || sock->so_proto->pr_domain == NULL) {
678 		socket_unlock(sock, 1);
679 		return EINVAL;
680 	}
681 
682 	switch (SOCK_DOM(sock)) {
683 	case PF_INET:
684 		sopt.sopt_level = IPPROTO_IP;
685 		sopt.sopt_name = IP_TOS;
686 		break;
687 	case PF_INET6:
688 		sopt.sopt_level = IPPROTO_IPV6;
689 		sopt.sopt_name = IPV6_TCLASS;
690 		break;
691 	default:
692 		socket_unlock(sock, 1);
693 		return EINVAL;
694 	}
695 	error = sogetoptlock(sock, &sopt, 0);   /* already locked */
696 	socket_unlock(sock, 1);
697 	if (error == 0) {
698 		*optlen = sopt.sopt_valsize;
699 	}
700 	return error;
701 }
702 
703 errno_t
sock_listen(socket_t sock,int backlog)704 sock_listen(socket_t sock, int backlog)
705 {
706 	if (sock == NULL) {
707 		return EINVAL;
708 	}
709 
710 	return solisten(sock, backlog); /* will lock socket */
711 }
712 
713 errno_t
sock_receive_internal(socket_t sock,struct msghdr * msg,mbuf_t * data,int flags,size_t * recvdlen)714 sock_receive_internal(socket_t sock, struct msghdr *msg, mbuf_t *data,
715     int flags, size_t *recvdlen)
716 {
717 	uio_t auio;
718 	mbuf_ref_t control = NULL;
719 	int error = 0;
720 	user_ssize_t length = 0;
721 	struct sockaddr *__single fromsa = NULL;
722 	UIO_STACKBUF(uio_buf, (msg != NULL) ? msg->msg_iovlen : 0);
723 
724 	if (sock == NULL) {
725 		return EINVAL;
726 	}
727 
728 	auio = uio_createwithbuffer(((msg != NULL) ? msg->msg_iovlen : 0),
729 	    0, UIO_SYSSPACE, UIO_READ, &uio_buf[0], sizeof(uio_buf));
730 	if (msg != NULL && data == NULL) {
731 		int i;
732 		struct iovec *tempp = __unsafe_forge_bidi_indexable(struct iovec *,
733 		    msg->msg_iov,
734 		    sizeof(struct iovec) * msg->msg_iovlen);
735 
736 		for (i = 0; i < msg->msg_iovlen; i++) {
737 			uio_addiov(auio,
738 			    CAST_USER_ADDR_T((tempp + i)->iov_base),
739 			    (tempp + i)->iov_len);
740 		}
741 		if (uio_resid(auio) < 0) {
742 			return EINVAL;
743 		}
744 	} else if (recvdlen != NULL) {
745 		uio_setresid(auio, (uio_resid(auio) + *recvdlen));
746 	}
747 	length = uio_resid(auio);
748 
749 	if (recvdlen != NULL) {
750 		*recvdlen = 0;
751 	}
752 
753 	/* let pru_soreceive handle the socket locking */
754 	error = sock->so_proto->pr_usrreqs->pru_soreceive(sock, &fromsa, auio,
755 	    data, (msg && msg->msg_control) ? &control : NULL, &flags);
756 	if (error != 0) {
757 		goto cleanup;
758 	}
759 
760 	if (recvdlen != NULL) {
761 		*recvdlen = length - uio_resid(auio);
762 	}
763 	if (msg != NULL) {
764 		msg->msg_flags = flags;
765 
766 		if (msg->msg_name != NULL) {
767 			int salen;
768 			salen = msg->msg_namelen;
769 			if (msg->msg_namelen > 0 && fromsa != NULL) {
770 				salen = MIN(salen, fromsa->sa_len);
771 				SOCKADDR_COPY(fromsa, msg->msg_name,
772 				    msg->msg_namelen > fromsa->sa_len ?
773 				    fromsa->sa_len : msg->msg_namelen);
774 			}
775 		}
776 
777 		if (msg->msg_control != NULL) {
778 			struct mbuf *m = control;
779 			int clen = msg->msg_controllen;
780 			u_char *original_ctl = msg->msg_control;
781 			u_char *ctlbuf = msg->msg_control;
782 
783 			msg->msg_control = NULL;
784 			msg->msg_controllen = 0;
785 
786 			while (m != NULL && clen > 0) {
787 				unsigned int tocopy;
788 
789 				if (clen >= m->m_len) {
790 					tocopy = m->m_len;
791 				} else {
792 					msg->msg_flags |= MSG_CTRUNC;
793 					tocopy = clen;
794 				}
795 				memcpy(ctlbuf, mtod(m, caddr_t), tocopy);
796 				ctlbuf += tocopy;
797 				clen -= tocopy;
798 				m = m->m_next;
799 			}
800 			msg->msg_control = original_ctl;
801 			msg->msg_controllen = (socklen_t)(ctlbuf - original_ctl);
802 		}
803 	}
804 
805 cleanup:
806 	if (control != NULL) {
807 		m_freem(control);
808 	}
809 	free_sockaddr(fromsa);
810 	return error;
811 }
812 
813 errno_t
sock_receive(socket_t sock,struct msghdr * msg,int flags,size_t * recvdlen)814 sock_receive(socket_t sock, struct msghdr *msg, int flags, size_t *recvdlen)
815 {
816 	if ((msg == NULL) || (msg->msg_iovlen < 1) ||
817 	    (msg->msg_iov[0].iov_len == 0) ||
818 	    (msg->msg_iov[0].iov_base == NULL)) {
819 		return EINVAL;
820 	}
821 
822 	return sock_receive_internal(sock, msg, NULL, flags, recvdlen);
823 }
824 
825 errno_t
sock_receivembuf(socket_t sock,struct msghdr * msg,mbuf_t * data,int flags,size_t * recvlen)826 sock_receivembuf(socket_t sock, struct msghdr *msg, mbuf_t *data, int flags,
827     size_t *recvlen)
828 {
829 	if (data == NULL || recvlen == 0 || *recvlen <= 0 || (msg != NULL &&
830 	    (msg->msg_iov != NULL || msg->msg_iovlen != 0))) {
831 		return EINVAL;
832 	}
833 
834 	return sock_receive_internal(sock, msg, data, flags, recvlen);
835 }
836 
837 errno_t
sock_send_internal(socket_t sock,const struct msghdr * msg,mbuf_t data,int flags,size_t * sentlen)838 sock_send_internal(socket_t sock, const struct msghdr *msg, mbuf_t data,
839     int flags, size_t *sentlen)
840 {
841 	uio_t auio = NULL;
842 	mbuf_ref_t control = NULL;
843 	int error = 0;
844 	user_ssize_t datalen = 0;
845 
846 	if (sock == NULL) {
847 		error = EINVAL;
848 		goto errorout;
849 	}
850 
851 	if (data == NULL && msg != NULL) {
852 		struct iovec *tempp = __unsafe_forge_bidi_indexable(struct iovec *,
853 		    msg->msg_iov,
854 		    sizeof(struct iovec) * msg->msg_iovlen);
855 
856 		auio = uio_create(msg->msg_iovlen, 0, UIO_SYSSPACE, UIO_WRITE);
857 		if (auio == NULL) {
858 #if (DEBUG || DEVELOPMENT)
859 			printf("sock_send_internal: so %p uio_createwithbuffer(%lu) failed, ENOMEM\n",
860 			    sock, UIO_SIZEOF(msg->msg_iovlen));
861 #endif /* (DEBUG || DEVELOPMENT) */
862 			error = ENOMEM;
863 			goto errorout;
864 		}
865 		if (tempp != NULL) {
866 			int i;
867 
868 			for (i = 0; i < msg->msg_iovlen; i++) {
869 				uio_addiov(auio,
870 				    CAST_USER_ADDR_T((tempp + i)->iov_base),
871 				    (tempp + i)->iov_len);
872 			}
873 
874 			if (uio_resid(auio) < 0) {
875 				error = EINVAL;
876 				goto errorout;
877 			}
878 		}
879 	}
880 
881 	if (sentlen != NULL) {
882 		*sentlen = 0;
883 	}
884 
885 	if (auio != NULL) {
886 		datalen = uio_resid(auio);
887 	} else {
888 		datalen = data->m_pkthdr.len;
889 	}
890 
891 	if (msg != NULL && msg->msg_control) {
892 		if ((size_t)msg->msg_controllen < sizeof(struct cmsghdr)) {
893 			error = EINVAL;
894 			goto errorout;
895 		}
896 
897 		if ((size_t)msg->msg_controllen > MLEN) {
898 			error = EINVAL;
899 			goto errorout;
900 		}
901 
902 		control = m_get(M_NOWAIT, MT_CONTROL);
903 		if (control == NULL) {
904 			error = ENOMEM;
905 			goto errorout;
906 		}
907 		memcpy(mtod(control, caddr_t), msg->msg_control,
908 		    msg->msg_controllen);
909 		control->m_len = msg->msg_controllen;
910 	}
911 
912 #if SKYWALK
913 	sk_protect_t protect = sk_async_transmit_protect();
914 #endif /* SKYWALK */
915 
916 	error = sock->so_proto->pr_usrreqs->pru_sosend(sock, msg != NULL ?
917 	    (struct sockaddr *)msg->msg_name : NULL, auio, data,
918 	    control, flags);
919 
920 #if SKYWALK
921 	sk_async_transmit_unprotect(protect);
922 #endif /* SKYWALK */
923 
924 	/*
925 	 * Residual data is possible in the case of IO vectors but not
926 	 * in the mbuf case since the latter is treated as atomic send.
927 	 * If pru_sosend() consumed a portion of the iovecs data and
928 	 * the error returned is transient, treat it as success; this
929 	 * is consistent with sendit() behavior.
930 	 */
931 	if (auio != NULL && uio_resid(auio) != datalen &&
932 	    (error == ERESTART || error == EINTR || error == EWOULDBLOCK)) {
933 		error = 0;
934 	}
935 
936 	if (error == 0 && sentlen != NULL) {
937 		if (auio != NULL) {
938 			*sentlen = datalen - uio_resid(auio);
939 		} else {
940 			*sentlen = datalen;
941 		}
942 	}
943 	if (auio != NULL) {
944 		uio_free(auio);
945 	}
946 
947 	return error;
948 
949 /*
950  * In cases where we detect an error before returning, we need to
951  * free the mbuf chain if there is one. sosend (and pru_sosend) will
952  * free the mbuf chain if they encounter an error.
953  */
954 errorout:
955 	if (control) {
956 		m_freem(control);
957 	}
958 	if (data) {
959 		m_freem(data);
960 	}
961 	if (sentlen) {
962 		*sentlen = 0;
963 	}
964 	if (auio != NULL) {
965 		uio_free(auio);
966 	}
967 	return error;
968 }
969 
970 errno_t
sock_send(socket_t sock,const struct msghdr * msg,int flags,size_t * sentlen)971 sock_send(socket_t sock, const struct msghdr *msg, int flags, size_t *sentlen)
972 {
973 	if (msg == NULL || msg->msg_iov == NULL || msg->msg_iovlen < 1) {
974 		return EINVAL;
975 	}
976 
977 	return sock_send_internal(sock, msg, NULL, flags, sentlen);
978 }
979 
980 errno_t
sock_sendmbuf(socket_t sock,const struct msghdr * msg,mbuf_t data,int flags,size_t * sentlen)981 sock_sendmbuf(socket_t sock, const struct msghdr *msg, mbuf_t data,
982     int flags, size_t *sentlen)
983 {
984 	int error;
985 
986 	if (data == NULL || (msg != NULL && (msg->msg_iov != NULL ||
987 	    msg->msg_iovlen != 0))) {
988 		if (data != NULL) {
989 			m_freem(data);
990 		}
991 		error = EINVAL;
992 		goto done;
993 	}
994 	error = sock_send_internal(sock, msg, data, flags, sentlen);
995 done:
996 	return error;
997 }
998 
999 errno_t
sock_sendmbuf_can_wait(socket_t sock,const struct msghdr * msg,mbuf_t data,int flags,size_t * sentlen)1000 sock_sendmbuf_can_wait(socket_t sock, const struct msghdr *msg, mbuf_t data,
1001     int flags, size_t *sentlen)
1002 {
1003 	int error;
1004 	int count = 0;
1005 	int i;
1006 	mbuf_t m;
1007 	struct msghdr msg_temp = {};
1008 
1009 	if (data == NULL || (msg != NULL && (msg->msg_iov != NULL ||
1010 	    msg->msg_iovlen != 0))) {
1011 		error = EINVAL;
1012 		goto done;
1013 	}
1014 
1015 	/*
1016 	 * Use the name and control
1017 	 */
1018 	msg_temp.msg_name = msg->msg_name;
1019 	msg_temp.msg_namelen = msg->msg_namelen;
1020 	msg_temp.msg_control = msg->msg_control;
1021 	msg_temp.msg_controllen = msg->msg_controllen;
1022 
1023 	/*
1024 	 * Count the number of mbufs in the chain
1025 	 */
1026 	for (m = data; m != NULL; m = mbuf_next(m)) {
1027 		count++;
1028 	}
1029 
1030 	struct iovec *msg_iov = kalloc_type(struct iovec, count, Z_WAITOK | Z_ZERO);
1031 	if (msg_iov == NULL) {
1032 		error = ENOMEM;
1033 		goto done;
1034 	}
1035 
1036 	msg_temp.msg_iov = msg_iov;
1037 	msg_temp.msg_iovlen = count;
1038 
1039 	for (i = 0, m = data; m != NULL; i++, m = mbuf_next(m)) {
1040 		msg_iov[i].iov_base = mtod(m, void*);
1041 		msg_iov[i].iov_len = mbuf_len(m);
1042 	}
1043 
1044 	error = sock_send_internal(sock, &msg_temp, NULL, flags, sentlen);
1045 done:
1046 	if (data != NULL) {
1047 		m_freem(data);
1048 	}
1049 	if (msg_temp.msg_iov != NULL) {
1050 		kfree_type(struct iovec, count, msg_temp.msg_iov);
1051 	}
1052 	return error;
1053 }
1054 
1055 errno_t
sock_shutdown(socket_t sock,int how)1056 sock_shutdown(socket_t sock, int how)
1057 {
1058 	if (sock == NULL) {
1059 		return EINVAL;
1060 	}
1061 
1062 	return soshutdown(sock, how);
1063 }
1064 
1065 errno_t
sock_socket_common(int domain,int type,int protocol,sock_upcall callback,void * context,socket_t * new_so,bool is_internal)1066 sock_socket_common(int domain, int type, int protocol, sock_upcall callback,
1067     void *context, socket_t *new_so, bool is_internal)
1068 {
1069 	int error = 0;
1070 
1071 	if (new_so == NULL) {
1072 		return EINVAL;
1073 	}
1074 
1075 	/* socreate will create an initial so_count */
1076 	error = socreate(domain, new_so, type, protocol);
1077 	if (error == 0) {
1078 		/*
1079 		 * This is an in-kernel socket
1080 		 */
1081 		(*new_so)->so_flags1 |= SOF1_IN_KERNEL_SOCKET;
1082 		INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_total);
1083 		if (is_internal) {
1084 			INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_os_total);
1085 		}
1086 
1087 		/* see comments in sock_setupcall() */
1088 		if (callback != NULL) {
1089 			sock_setupcall(*new_so, callback, context);
1090 		}
1091 		/*
1092 		 * last_pid and last_upid should be zero for sockets
1093 		 * created using sock_socket
1094 		 */
1095 		(*new_so)->last_pid = 0;
1096 		(*new_so)->last_upid = 0;
1097 	}
1098 	return error;
1099 }
1100 
1101 errno_t
sock_socket_internal(int domain,int type,int protocol,sock_upcall callback,void * context,socket_t * new_so)1102 sock_socket_internal(int domain, int type, int protocol, sock_upcall callback,
1103     void *context, socket_t *new_so)
1104 {
1105 	return sock_socket_common(domain, type, protocol, callback,
1106 	           context, new_so, true);
1107 }
1108 
1109 errno_t
sock_socket(int domain,int type,int protocol,sock_upcall callback,void * context,socket_t * new_so)1110 sock_socket(int domain, int type, int protocol, sock_upcall callback,
1111     void *context, socket_t *new_so)
1112 {
1113 	return sock_socket_common(domain, type, protocol, callback,
1114 	           context, new_so, false);
1115 }
1116 
1117 void
sock_close(socket_t sock)1118 sock_close(socket_t sock)
1119 {
1120 	if (sock == NULL) {
1121 		return;
1122 	}
1123 
1124 	soclose(sock);
1125 }
1126 
1127 /* Do we want this to be APPLE_PRIVATE API?: YES (LD 12/23/04) */
1128 void
sock_retain(socket_t sock)1129 sock_retain(socket_t sock)
1130 {
1131 	if (sock == NULL) {
1132 		return;
1133 	}
1134 
1135 	socket_lock(sock, 1);
1136 	sock->so_retaincnt++;
1137 	sock->so_usecount++;    /* add extra reference for holding the socket */
1138 	socket_unlock(sock, 1);
1139 }
1140 
1141 /* Do we want this to be APPLE_PRIVATE API? */
1142 void
sock_release(socket_t sock)1143 sock_release(socket_t sock)
1144 {
1145 	if (sock == NULL) {
1146 		return;
1147 	}
1148 
1149 	socket_lock(sock, 1);
1150 	if (sock->so_upcallusecount > 0) {
1151 		soclose_wait_locked(sock);
1152 	}
1153 
1154 	sock->so_retaincnt--;
1155 	if (sock->so_retaincnt < 0) {
1156 		panic("%s: negative retain count (%d) for sock=%p",
1157 		    __func__, sock->so_retaincnt, sock);
1158 		/* NOTREACHED */
1159 	}
1160 	/*
1161 	 * The so_usecount values '2' and '3' are special because they
1162 	 * indicate how many references are on the socket when it is
1163 	 * ready for closing:
1164 	 *  - there is always one use count that was just taken by this function;
1165 	 *  - '2' works for most kinds of socket as there is one use count
1166 	 *    for the socket held by the file or by the KEXT;
1167 	 *  - '3' works for connected Unix domain sockets as each peer
1168 	 *    holds a connection to the other peer.
1169 	 * Check SS_NOFDREF in case a close happened as sock_retain()
1170 	 * was grabbing the lock
1171 	 */
1172 	if ((sock->so_retaincnt == 0) &&
1173 	    ((SOCK_DOM(sock) != PF_LOCAL && sock->so_usecount == 2) ||
1174 	    (SOCK_DOM(sock) == PF_LOCAL && (sock->so_state & SS_ISCONNECTED) && sock->so_usecount == 3)) &&
1175 	    (!(sock->so_state & SS_NOFDREF) || (sock->so_flags & SOF_MP_SUBFLOW))) {
1176 		/* close socket only if the FD is not holding it */
1177 		soclose_locked(sock);
1178 	} else {
1179 		/* remove extra reference holding the socket */
1180 		VERIFY(sock->so_usecount > 1);
1181 		sock->so_usecount--;
1182 	}
1183 	socket_unlock(sock, 1);
1184 }
1185 
1186 errno_t
sock_setpriv(socket_t sock,int on)1187 sock_setpriv(socket_t sock, int on)
1188 {
1189 	if (sock == NULL) {
1190 		return EINVAL;
1191 	}
1192 
1193 	socket_lock(sock, 1);
1194 	if (on) {
1195 		sock->so_state |= SS_PRIV;
1196 	} else {
1197 		sock->so_state &= ~SS_PRIV;
1198 	}
1199 	socket_unlock(sock, 1);
1200 	return 0;
1201 }
1202 
1203 int
sock_isconnected(socket_t sock)1204 sock_isconnected(socket_t sock)
1205 {
1206 	int retval;
1207 
1208 	socket_lock(sock, 1);
1209 	retval = ((sock->so_state & SS_ISCONNECTED) ? 1 : 0);
1210 	socket_unlock(sock, 1);
1211 	return retval;
1212 }
1213 
1214 int
sock_isnonblocking(socket_t sock)1215 sock_isnonblocking(socket_t sock)
1216 {
1217 	int retval;
1218 
1219 	socket_lock(sock, 1);
1220 	retval = ((sock->so_state & SS_NBIO) ? 1 : 0);
1221 	socket_unlock(sock, 1);
1222 	return retval;
1223 }
1224 
1225 errno_t
sock_gettype(socket_t sock,int * outDomain,int * outType,int * outProtocol)1226 sock_gettype(socket_t sock, int *outDomain, int *outType, int *outProtocol)
1227 {
1228 	socket_lock(sock, 1);
1229 	if (outDomain != NULL) {
1230 		*outDomain = SOCK_DOM(sock);
1231 	}
1232 	if (outType != NULL) {
1233 		*outType = sock->so_type;
1234 	}
1235 	if (outProtocol != NULL) {
1236 		*outProtocol = SOCK_PROTO(sock);
1237 	}
1238 	socket_unlock(sock, 1);
1239 	return 0;
1240 }
1241 
1242 /*
1243  * Return the listening socket of a pre-accepted socket.  It returns the
1244  * listener (so_head) value of a given socket.  This is intended to be
1245  * called by a socket filter during a filter attach (sf_attach) callback.
1246  * The value returned by this routine is safe to be used only in the
1247  * context of that callback, because we hold the listener's lock across
1248  * the sflt_initsock() call.
1249  */
1250 socket_t
sock_getlistener(socket_t sock)1251 sock_getlistener(socket_t sock)
1252 {
1253 	return sock->so_head;
1254 }
1255 
1256 static inline void
sock_set_tcp_stream_priority(socket_t sock)1257 sock_set_tcp_stream_priority(socket_t sock)
1258 {
1259 	if ((SOCK_DOM(sock) == PF_INET || SOCK_DOM(sock) == PF_INET6) &&
1260 	    SOCK_TYPE(sock) == SOCK_STREAM) {
1261 		set_tcp_stream_priority(sock);
1262 	}
1263 }
1264 
1265 /*
1266  * Caller must have ensured socket is valid and won't be going away.
1267  */
1268 void
socket_set_traffic_mgt_flags_locked(socket_t sock,u_int8_t flags)1269 socket_set_traffic_mgt_flags_locked(socket_t sock, u_int8_t flags)
1270 {
1271 	u_int32_t soflags1 = 0;
1272 
1273 	if ((flags & TRAFFIC_MGT_SO_BACKGROUND)) {
1274 		soflags1 |= SOF1_TRAFFIC_MGT_SO_BACKGROUND;
1275 	}
1276 	if ((flags & TRAFFIC_MGT_TCP_RECVBG)) {
1277 		soflags1 |= SOF1_TRAFFIC_MGT_TCP_RECVBG;
1278 	}
1279 
1280 	(void) OSBitOrAtomic(soflags1, &sock->so_flags1);
1281 
1282 	sock_set_tcp_stream_priority(sock);
1283 }
1284 
1285 void
socket_set_traffic_mgt_flags(socket_t sock,u_int8_t flags)1286 socket_set_traffic_mgt_flags(socket_t sock, u_int8_t flags)
1287 {
1288 	socket_lock(sock, 1);
1289 	socket_set_traffic_mgt_flags_locked(sock, flags);
1290 	socket_unlock(sock, 1);
1291 }
1292 
1293 /*
1294  * Caller must have ensured socket is valid and won't be going away.
1295  */
1296 void
socket_clear_traffic_mgt_flags_locked(socket_t sock,u_int8_t flags)1297 socket_clear_traffic_mgt_flags_locked(socket_t sock, u_int8_t flags)
1298 {
1299 	u_int32_t soflags1 = 0;
1300 
1301 	if ((flags & TRAFFIC_MGT_SO_BACKGROUND)) {
1302 		soflags1 |= SOF1_TRAFFIC_MGT_SO_BACKGROUND;
1303 	}
1304 	if ((flags & TRAFFIC_MGT_TCP_RECVBG)) {
1305 		soflags1 |= SOF1_TRAFFIC_MGT_TCP_RECVBG;
1306 	}
1307 
1308 	(void) OSBitAndAtomic(~soflags1, &sock->so_flags1);
1309 
1310 	sock_set_tcp_stream_priority(sock);
1311 }
1312 
1313 void
socket_clear_traffic_mgt_flags(socket_t sock,u_int8_t flags)1314 socket_clear_traffic_mgt_flags(socket_t sock, u_int8_t flags)
1315 {
1316 	socket_lock(sock, 1);
1317 	socket_clear_traffic_mgt_flags_locked(sock, flags);
1318 	socket_unlock(sock, 1);
1319 }
1320 
1321 
1322 /*
1323  * Caller must have ensured socket is valid and won't be going away.
1324  */
1325 errno_t
socket_defunct(struct proc * p,socket_t so,int level)1326 socket_defunct(struct proc *p, socket_t so, int level)
1327 {
1328 	errno_t retval;
1329 
1330 	if (level != SHUTDOWN_SOCKET_LEVEL_DISCONNECT_SVC &&
1331 	    level != SHUTDOWN_SOCKET_LEVEL_DISCONNECT_ALL) {
1332 		return EINVAL;
1333 	}
1334 
1335 	socket_lock(so, 1);
1336 	/*
1337 	 * SHUTDOWN_SOCKET_LEVEL_DISCONNECT_SVC level is meant to tear down
1338 	 * all of mDNSResponder IPC sockets, currently those of AF_UNIX; note
1339 	 * that this is an implementation artifact of mDNSResponder.  We do
1340 	 * a quick test against the socket buffers for SB_UNIX, since that
1341 	 * would have been set by unp_attach() at socket creation time.
1342 	 */
1343 	if (level == SHUTDOWN_SOCKET_LEVEL_DISCONNECT_SVC &&
1344 	    (so->so_rcv.sb_flags & so->so_snd.sb_flags & SB_UNIX) != SB_UNIX) {
1345 		socket_unlock(so, 1);
1346 		return EOPNOTSUPP;
1347 	}
1348 	retval = sosetdefunct(p, so, level, TRUE);
1349 	if (retval == 0) {
1350 		retval = sodefunct(p, so, level);
1351 	}
1352 	socket_unlock(so, 1);
1353 	return retval;
1354 }
1355 
1356 void
sock_setupcalls_locked(socket_t sock,sock_upcall rcallback,void * rcontext,sock_upcall wcallback,void * wcontext,int locked)1357 sock_setupcalls_locked(socket_t sock, sock_upcall rcallback, void *rcontext,
1358     sock_upcall wcallback, void *wcontext, int locked)
1359 {
1360 	if (rcallback != NULL) {
1361 		sock->so_rcv.sb_flags |= SB_UPCALL;
1362 		if (locked) {
1363 			sock->so_rcv.sb_flags |= SB_UPCALL_LOCK;
1364 		}
1365 		sock->so_rcv.sb_upcall = rcallback;
1366 		sock->so_rcv.sb_upcallarg = rcontext;
1367 	} else {
1368 		sock->so_rcv.sb_flags &= ~(SB_UPCALL | SB_UPCALL_LOCK);
1369 		sock->so_rcv.sb_upcall = NULL;
1370 		sock->so_rcv.sb_upcallarg = NULL;
1371 	}
1372 
1373 	if (wcallback != NULL) {
1374 		sock->so_snd.sb_flags |= SB_UPCALL;
1375 		if (locked) {
1376 			sock->so_snd.sb_flags |= SB_UPCALL_LOCK;
1377 		}
1378 		sock->so_snd.sb_upcall = wcallback;
1379 		sock->so_snd.sb_upcallarg = wcontext;
1380 	} else {
1381 		sock->so_snd.sb_flags &= ~(SB_UPCALL | SB_UPCALL_LOCK);
1382 		sock->so_snd.sb_upcall = NULL;
1383 		sock->so_snd.sb_upcallarg = NULL;
1384 	}
1385 }
1386 
1387 errno_t
sock_setupcall(socket_t sock,sock_upcall callback,void * context)1388 sock_setupcall(socket_t sock, sock_upcall callback, void *context)
1389 {
1390 	if (sock == NULL) {
1391 		return EINVAL;
1392 	}
1393 
1394 	/*
1395 	 * Note that we don't wait for any in progress upcall to complete.
1396 	 * On embedded, sock_setupcall() causes both read and write
1397 	 * callbacks to be set; on desktop, only read callback is set
1398 	 * to maintain legacy KPI behavior.
1399 	 *
1400 	 * The newer sock_setupcalls() KPI should be used instead to set
1401 	 * the read and write callbacks and their respective parameters.
1402 	 */
1403 	socket_lock(sock, 1);
1404 #if defined(__arm64__)
1405 	sock_setupcalls_locked(sock, callback, context, callback, context, 0);
1406 #else /* defined(__arm64__) */
1407 	sock_setupcalls_locked(sock, callback, context, NULL, NULL, 0);
1408 #endif /* defined(__arm64__) */
1409 	socket_unlock(sock, 1);
1410 
1411 	return 0;
1412 }
1413 
1414 errno_t
sock_setupcalls(socket_t sock,sock_upcall rcallback,void * rcontext,sock_upcall wcallback,void * wcontext)1415 sock_setupcalls(socket_t sock, sock_upcall rcallback, void *rcontext,
1416     sock_upcall wcallback, void *wcontext)
1417 {
1418 	if (sock == NULL) {
1419 		return EINVAL;
1420 	}
1421 
1422 	/*
1423 	 * Note that we don't wait for any in progress upcall to complete.
1424 	 */
1425 	socket_lock(sock, 1);
1426 	sock_setupcalls_locked(sock, rcallback, rcontext, wcallback, wcontext, 0);
1427 	socket_unlock(sock, 1);
1428 
1429 	return 0;
1430 }
1431 
1432 void
sock_catchevents_locked(socket_t sock,sock_evupcall ecallback,void * econtext,uint32_t emask)1433 sock_catchevents_locked(socket_t sock, sock_evupcall ecallback, void *econtext,
1434     uint32_t emask)
1435 {
1436 	socket_lock_assert_owned(sock);
1437 
1438 	/*
1439 	 * Note that we don't wait for any in progress upcall to complete.
1440 	 */
1441 	if (ecallback != NULL) {
1442 		sock->so_event = ecallback;
1443 		sock->so_eventarg = econtext;
1444 		sock->so_eventmask = emask;
1445 	} else {
1446 		sock->so_event = sonullevent;
1447 		sock->so_eventarg = NULL;
1448 		sock->so_eventmask = 0;
1449 	}
1450 }
1451 
1452 errno_t
sock_catchevents(socket_t sock,sock_evupcall ecallback,void * econtext,uint32_t emask)1453 sock_catchevents(socket_t sock, sock_evupcall ecallback, void *econtext,
1454     uint32_t emask)
1455 {
1456 	if (sock == NULL) {
1457 		return EINVAL;
1458 	}
1459 
1460 	socket_lock(sock, 1);
1461 	sock_catchevents_locked(sock, ecallback, econtext, emask);
1462 	socket_unlock(sock, 1);
1463 
1464 	return 0;
1465 }
1466 
1467 /*
1468  * Returns true whether or not a socket belongs to the kernel.
1469  */
1470 int
sock_iskernel(socket_t so)1471 sock_iskernel(socket_t so)
1472 {
1473 	return so && so->last_pid == 0;
1474 }
1475