xref: /xnu-10002.81.5/bsd/kern/vsock_domain.c (revision 5e3eaea39dcf651e66cb99ba7d70e32cc4a99587)
1 /*
2  * Copyright (c) 2020 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <sys/domain.h>
30 #include <sys/socket.h>
31 #include <sys/protosw.h>
32 #include <sys/mcache.h>
33 #include <sys/systm.h>
34 #include <sys/sysctl.h>
35 #include <sys/random.h>
36 #include <sys/mbuf.h>
37 #include <sys/vsock_domain.h>
38 #include <sys/vsock_transport.h>
39 #include <kern/task.h>
40 #include <kern/zalloc.h>
41 #include <kern/locks.h>
42 #include <machine/atomic.h>
43 
44 #define sotovsockpcb(so) ((struct vsockpcb *)(so)->so_pcb)
45 
46 #define VSOCK_PORT_RESERVED 1024
47 
48 /* VSock Protocol Globals */
49 
50 static struct vsock_transport * _Atomic the_vsock_transport = NULL;
51 static ZONE_DEFINE(vsockpcb_zone, "vsockpcbzone",
52     sizeof(struct vsockpcb), ZC_NONE);
53 static LCK_GRP_DECLARE(vsock_lock_grp, "vsock");
54 static struct vsockpcbinfo vsockinfo;
55 
56 static uint32_t vsock_sendspace = VSOCK_MAX_PACKET_SIZE * 8;
57 static uint32_t vsock_recvspace = VSOCK_MAX_PACKET_SIZE * 8;
58 
59 /* VSock PCB Helpers */
60 
61 static uint32_t
vsock_get_peer_space(struct vsockpcb * pcb)62 vsock_get_peer_space(struct vsockpcb *pcb)
63 {
64 	return pcb->peer_buf_alloc - (pcb->tx_cnt - pcb->peer_fwd_cnt);
65 }
66 
67 static struct vsockpcb *
vsock_get_matching_pcb(struct vsock_address src,struct vsock_address dst)68 vsock_get_matching_pcb(struct vsock_address src, struct vsock_address dst)
69 {
70 	struct vsockpcb *preferred = NULL;
71 	struct vsockpcb *match = NULL;
72 	struct vsockpcb *pcb = NULL;
73 
74 	lck_rw_lock_shared(&vsockinfo.bound_lock);
75 	LIST_FOREACH(pcb, &vsockinfo.bound, bound) {
76 		// Source cid and port must match. Only destination port must match. (Allows for a changing CID during migration)
77 		socket_lock(pcb->so, 1);
78 		if ((pcb->so->so_state & SS_ISCONNECTED || pcb->so->so_state & SS_ISCONNECTING) &&
79 		    pcb->local_address.cid == src.cid && pcb->local_address.port == src.port &&
80 		    pcb->remote_address.port == dst.port) {
81 			preferred = pcb;
82 			break;
83 		} else if ((pcb->local_address.cid == src.cid || pcb->local_address.cid == VMADDR_CID_ANY) &&
84 		    pcb->local_address.port == src.port) {
85 			match = pcb;
86 		}
87 		socket_unlock(pcb->so, 1);
88 	}
89 	if (!preferred && match) {
90 		socket_lock(match->so, 1);
91 		preferred = match;
92 	}
93 	lck_rw_done(&vsockinfo.bound_lock);
94 
95 	return preferred;
96 }
97 
98 static errno_t
vsock_bind_address_if_free(struct vsockpcb * pcb,uint32_t local_cid,uint32_t local_port,uint32_t remote_cid,uint32_t remote_port)99 vsock_bind_address_if_free(struct vsockpcb *pcb, uint32_t local_cid, uint32_t local_port, uint32_t remote_cid, uint32_t remote_port)
100 {
101 	socket_lock_assert_owned(pcb->so);
102 
103 	// Privileged ports.
104 	if (local_port != VMADDR_PORT_ANY && local_port < VSOCK_PORT_RESERVED &&
105 	    current_task() != kernel_task && proc_suser(current_proc()) != 0) {
106 		return EACCES;
107 	}
108 
109 	bool taken = false;
110 	const bool check_remote = (remote_cid != VMADDR_CID_ANY && remote_port != VMADDR_PORT_ANY);
111 
112 	struct vsockpcb *pcb_match = NULL;
113 
114 	socket_unlock(pcb->so, 0);
115 	lck_rw_lock_exclusive(&vsockinfo.bound_lock);
116 	LIST_FOREACH(pcb_match, &vsockinfo.bound, bound) {
117 		socket_lock(pcb_match->so, 1);
118 		if (pcb == pcb_match ||
119 		    (!check_remote && pcb_match->local_address.port == local_port) ||
120 		    (check_remote && pcb_match->local_address.port == local_port &&
121 		    pcb_match->remote_address.cid == remote_cid && pcb_match->remote_address.port == remote_port)) {
122 			socket_unlock(pcb_match->so, 1);
123 			taken = true;
124 			break;
125 		}
126 		socket_unlock(pcb_match->so, 1);
127 	}
128 	socket_lock(pcb->so, 0);
129 	if (!taken) {
130 		pcb->local_address = (struct vsock_address) { .cid = local_cid, .port = local_port };
131 		pcb->remote_address = (struct vsock_address) { .cid = remote_cid, .port = remote_port };
132 		LIST_INSERT_HEAD(&vsockinfo.bound, pcb, bound);
133 	}
134 	lck_rw_done(&vsockinfo.bound_lock);
135 
136 	return taken ? EADDRINUSE : 0;
137 }
138 
139 static errno_t
vsock_bind_address(struct vsockpcb * pcb,struct vsock_address laddr,struct vsock_address raddr)140 vsock_bind_address(struct vsockpcb *pcb, struct vsock_address laddr, struct vsock_address raddr)
141 {
142 	if (!pcb) {
143 		return EINVAL;
144 	}
145 
146 	socket_lock_assert_owned(pcb->so);
147 
148 	// Certain CIDs are reserved.
149 	if (laddr.cid == VMADDR_CID_HYPERVISOR || laddr.cid == VMADDR_CID_RESERVED || laddr.cid == VMADDR_CID_HOST) {
150 		return EADDRNOTAVAIL;
151 	}
152 
153 	// Remote address must be fully specified or not specified at all.
154 	if ((raddr.cid == VMADDR_CID_ANY) ^ (raddr.port == VMADDR_PORT_ANY)) {
155 		return EINVAL;
156 	}
157 
158 	// Cannot bind if already bound.
159 	if (pcb->local_address.port != VMADDR_PORT_ANY) {
160 		return EINVAL;
161 	}
162 
163 	uint32_t transport_cid;
164 	struct vsock_transport *transport = pcb->transport;
165 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
166 	if (error) {
167 		return error;
168 	}
169 
170 	// Local CID must be this transport's CID or any.
171 	if (laddr.cid != transport_cid && laddr.cid != VMADDR_CID_ANY) {
172 		return EINVAL;
173 	}
174 
175 	if (laddr.port != VMADDR_PORT_ANY) {
176 		error = vsock_bind_address_if_free(pcb, laddr.cid, laddr.port, raddr.cid, raddr.port);
177 	} else {
178 		lck_mtx_lock(&vsockinfo.port_lock);
179 
180 		const uint32_t first = VSOCK_PORT_RESERVED;
181 		const uint32_t last = VMADDR_PORT_ANY - 1;
182 		uint32_t count = last - first + 1;
183 		uint32_t *last_port = &vsockinfo.last_port;
184 
185 		if (pcb->so->so_flags & SOF_BINDRANDOMPORT) {
186 			uint32_t random = 0;
187 			read_frandom(&random, sizeof(random));
188 			*last_port = first + (random % count);
189 		}
190 
191 		do {
192 			if (count == 0) {
193 				lck_mtx_unlock(&vsockinfo.port_lock);
194 				return EADDRNOTAVAIL;
195 			}
196 			count--;
197 
198 			++*last_port;
199 			if (*last_port < first || *last_port > last) {
200 				*last_port = first;
201 			}
202 
203 			error = vsock_bind_address_if_free(pcb, laddr.cid, *last_port, raddr.cid, raddr.port);
204 		} while (error);
205 
206 		lck_mtx_unlock(&vsockinfo.port_lock);
207 	}
208 
209 	return error;
210 }
211 
212 static void
vsock_unbind_pcb(struct vsockpcb * pcb,bool is_locked)213 vsock_unbind_pcb(struct vsockpcb *pcb, bool is_locked)
214 {
215 	if (!pcb) {
216 		return;
217 	}
218 
219 	socket_lock_assert_owned(pcb->so);
220 
221 	soisdisconnected(pcb->so);
222 
223 	if (!pcb->bound.le_prev) {
224 		return;
225 	}
226 
227 	if (!is_locked) {
228 		socket_unlock(pcb->so, 0);
229 		lck_rw_lock_exclusive(&vsockinfo.bound_lock);
230 		socket_lock(pcb->so, 0);
231 		if (!pcb->bound.le_prev) {
232 			lck_rw_done(&vsockinfo.bound_lock);
233 			return;
234 		}
235 	}
236 
237 	LIST_REMOVE(pcb, bound);
238 	pcb->bound.le_next = NULL;
239 	pcb->bound.le_prev = NULL;
240 
241 	if (!is_locked) {
242 		lck_rw_done(&vsockinfo.bound_lock);
243 	}
244 }
245 
246 static struct sockaddr *
vsock_new_sockaddr(struct vsock_address * address)247 vsock_new_sockaddr(struct vsock_address *address)
248 {
249 	if (!address) {
250 		return NULL;
251 	}
252 
253 	struct sockaddr_vm *addr;
254 	addr = (struct sockaddr_vm *)alloc_sockaddr(sizeof(*addr),
255 	    Z_WAITOK | Z_NOFAIL);
256 
257 	addr->svm_family = AF_VSOCK;
258 	addr->svm_port = address->port;
259 	addr->svm_cid = address->cid;
260 
261 	return (struct sockaddr *)addr;
262 }
263 
264 static errno_t
vsock_pcb_send_message(struct vsockpcb * pcb,enum vsock_operation operation,mbuf_t m)265 vsock_pcb_send_message(struct vsockpcb *pcb, enum vsock_operation operation, mbuf_t m)
266 {
267 	if (!pcb) {
268 		if (m != NULL) {
269 			mbuf_freem_list(m);
270 		}
271 		return EINVAL;
272 	}
273 
274 	socket_lock_assert_owned(pcb->so);
275 
276 	errno_t error;
277 
278 	struct vsock_address dst = pcb->remote_address;
279 	if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
280 		if (m != NULL) {
281 			mbuf_freem_list(m);
282 		}
283 		return EINVAL;
284 	}
285 
286 	struct vsock_address src = pcb->local_address;
287 	if (src.cid == VMADDR_CID_ANY) {
288 		uint32_t transport_cid;
289 		struct vsock_transport *transport = pcb->transport;
290 		error = transport->get_cid(transport->provider, &transport_cid);
291 		if (error) {
292 			if (m != NULL) {
293 				mbuf_freem_list(m);
294 			}
295 			return error;
296 		}
297 		src.cid = transport_cid;
298 	}
299 
300 	uint32_t buf_alloc = pcb->so->so_rcv.sb_hiwat;
301 	uint32_t fwd_cnt = pcb->fwd_cnt;
302 
303 	if (src.cid == dst.cid) {
304 		pcb->last_buf_alloc = buf_alloc;
305 		pcb->last_fwd_cnt = fwd_cnt;
306 
307 		socket_unlock(pcb->so, 0);
308 		error = vsock_put_message(src, dst, operation, buf_alloc, fwd_cnt, m);
309 		socket_lock(pcb->so, 0);
310 	} else {
311 		struct vsock_transport *transport = pcb->transport;
312 		error = transport->put_message(transport->provider, src, dst, operation, buf_alloc, fwd_cnt, m);
313 
314 		if (!error) {
315 			pcb->last_buf_alloc = buf_alloc;
316 			pcb->last_fwd_cnt = fwd_cnt;
317 		}
318 	}
319 
320 	return error;
321 }
322 
323 static errno_t
vsock_pcb_reset_address(struct vsock_address src,struct vsock_address dst)324 vsock_pcb_reset_address(struct vsock_address src, struct vsock_address dst)
325 {
326 	if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
327 		return EINVAL;
328 	}
329 
330 	errno_t error;
331 	struct vsock_transport *transport = NULL;
332 
333 	if (src.cid == VMADDR_CID_ANY) {
334 		transport = os_atomic_load(&the_vsock_transport, relaxed);
335 		if (transport == NULL) {
336 			return ENODEV;
337 		}
338 
339 		uint32_t transport_cid;
340 		error = transport->get_cid(transport->provider, &transport_cid);
341 		if (error) {
342 			return error;
343 		}
344 		src.cid = transport_cid;
345 	}
346 
347 	if (src.cid == dst.cid) {
348 		error = vsock_put_message(src, dst, VSOCK_RESET, 0, 0, NULL);
349 	} else {
350 		if (!transport) {
351 			transport = os_atomic_load(&the_vsock_transport, relaxed);
352 			if (transport == NULL) {
353 				return ENODEV;
354 			}
355 		}
356 		error = transport->put_message(transport->provider, src, dst, VSOCK_RESET, 0, 0, NULL);
357 	}
358 
359 	return error;
360 }
361 
362 static errno_t
vsock_pcb_safe_reset_address(struct vsockpcb * pcb,struct vsock_address src,struct vsock_address dst)363 vsock_pcb_safe_reset_address(struct vsockpcb *pcb, struct vsock_address src, struct vsock_address dst)
364 {
365 	if (pcb) {
366 		socket_lock_assert_owned(pcb->so);
367 		socket_unlock(pcb->so, 0);
368 	}
369 	errno_t error = vsock_pcb_reset_address(src, dst);
370 	if (pcb) {
371 		socket_lock(pcb->so, 0);
372 	}
373 	return error;
374 }
375 
376 static errno_t
vsock_pcb_connect(struct vsockpcb * pcb)377 vsock_pcb_connect(struct vsockpcb *pcb)
378 {
379 	return vsock_pcb_send_message(pcb, VSOCK_REQUEST, NULL);
380 }
381 
382 static errno_t
vsock_pcb_respond(struct vsockpcb * pcb)383 vsock_pcb_respond(struct vsockpcb *pcb)
384 {
385 	return vsock_pcb_send_message(pcb, VSOCK_RESPONSE, NULL);
386 }
387 
388 static errno_t
vsock_pcb_send(struct vsockpcb * pcb,mbuf_t m)389 vsock_pcb_send(struct vsockpcb *pcb, mbuf_t m)
390 {
391 	return vsock_pcb_send_message(pcb, VSOCK_PAYLOAD, m);
392 }
393 
394 static errno_t
vsock_pcb_shutdown_send(struct vsockpcb * pcb)395 vsock_pcb_shutdown_send(struct vsockpcb *pcb)
396 {
397 	return vsock_pcb_send_message(pcb, VSOCK_SHUTDOWN_SEND, NULL);
398 }
399 
400 static errno_t
vsock_pcb_reset(struct vsockpcb * pcb)401 vsock_pcb_reset(struct vsockpcb *pcb)
402 {
403 	return vsock_pcb_send_message(pcb, VSOCK_RESET, NULL);
404 }
405 
406 static errno_t
vsock_pcb_credit_update(struct vsockpcb * pcb)407 vsock_pcb_credit_update(struct vsockpcb *pcb)
408 {
409 	return vsock_pcb_send_message(pcb, VSOCK_CREDIT_UPDATE, NULL);
410 }
411 
412 static errno_t
vsock_pcb_credit_request(struct vsockpcb * pcb)413 vsock_pcb_credit_request(struct vsockpcb *pcb)
414 {
415 	return vsock_pcb_send_message(pcb, VSOCK_CREDIT_REQUEST, NULL);
416 }
417 
418 static errno_t
vsock_disconnect_pcb_common(struct vsockpcb * pcb,bool is_locked)419 vsock_disconnect_pcb_common(struct vsockpcb *pcb, bool is_locked)
420 {
421 	socket_lock_assert_owned(pcb->so);
422 	vsock_unbind_pcb(pcb, is_locked);
423 	return vsock_pcb_reset(pcb);
424 }
425 
426 static errno_t
vsock_disconnect_pcb_locked(struct vsockpcb * pcb)427 vsock_disconnect_pcb_locked(struct vsockpcb *pcb)
428 {
429 	return vsock_disconnect_pcb_common(pcb, true);
430 }
431 
432 static errno_t
vsock_disconnect_pcb(struct vsockpcb * pcb)433 vsock_disconnect_pcb(struct vsockpcb *pcb)
434 {
435 	return vsock_disconnect_pcb_common(pcb, false);
436 }
437 
438 static errno_t
vsock_sockaddr_vm_validate(struct vsockpcb * pcb,struct sockaddr_vm * addr)439 vsock_sockaddr_vm_validate(struct vsockpcb *pcb, struct sockaddr_vm *addr)
440 {
441 	if (!pcb || !pcb->so || !addr) {
442 		return EINVAL;
443 	}
444 
445 	// Validate address length.
446 	if (addr->svm_len < sizeof(struct sockaddr_vm)) {
447 		return EINVAL;
448 	}
449 
450 	// Validate address family.
451 	if (addr->svm_family != AF_UNSPEC && addr->svm_family != AF_VSOCK) {
452 		return EAFNOSUPPORT;
453 	}
454 
455 	// Only stream is supported currently.
456 	if (pcb->so->so_type != SOCK_STREAM) {
457 		return EAFNOSUPPORT;
458 	}
459 
460 	return 0;
461 }
462 /* VSock Receive Handlers */
463 
464 static errno_t
vsock_put_message_connected(struct vsockpcb * pcb,enum vsock_operation op,mbuf_t m)465 vsock_put_message_connected(struct vsockpcb *pcb, enum vsock_operation op, mbuf_t m)
466 {
467 	socket_lock_assert_owned(pcb->so);
468 
469 	errno_t error = 0;
470 
471 	switch (op) {
472 	case VSOCK_SHUTDOWN:
473 		socantsendmore(pcb->so);
474 		socantrcvmore(pcb->so);
475 		break;
476 	case VSOCK_SHUTDOWN_RECEIVE:
477 		socantsendmore(pcb->so);
478 		break;
479 	case VSOCK_SHUTDOWN_SEND:
480 		socantrcvmore(pcb->so);
481 		break;
482 	case VSOCK_PAYLOAD:
483 		// Add data to the receive queue then wakeup any reading threads.
484 		error = !sbappendstream(&pcb->so->so_rcv, m);
485 		if (!error) {
486 			sorwakeup(pcb->so);
487 		}
488 		break;
489 	case VSOCK_RESET:
490 		vsock_unbind_pcb(pcb, false);
491 		break;
492 	default:
493 		error = ENOTSUP;
494 		break;
495 	}
496 
497 	return error;
498 }
499 
500 static errno_t
vsock_put_message_connecting(struct vsockpcb * pcb,enum vsock_operation op)501 vsock_put_message_connecting(struct vsockpcb *pcb, enum vsock_operation op)
502 {
503 	socket_lock_assert_owned(pcb->so);
504 
505 	errno_t error = 0;
506 
507 	switch (op) {
508 	case VSOCK_RESPONSE:
509 		soisconnected(pcb->so);
510 		break;
511 	case VSOCK_RESET:
512 		pcb->so->so_error = EAGAIN;
513 		error = vsock_disconnect_pcb(pcb);
514 		break;
515 	default:
516 		vsock_disconnect_pcb(pcb);
517 		error = ENOTSUP;
518 		break;
519 	}
520 
521 	return error;
522 }
523 
/*
 * Handle an incoming message for a listening socket.
 *
 * Must be called with pcb's socket locked. A VSOCK_REQUEST creates a new
 * socket with sonewconn(), binds it to (dst, src), answers with a
 * VSOCK_RESPONSE and marks it connected. Every other operation is answered
 * with a reset sent from `dst` to `src`; VSOCK_RESET returns that reset's
 * result, anything else returns ENOTSUP.
 */
static errno_t
vsock_put_message_listening(struct vsockpcb *pcb, enum vsock_operation op, struct vsock_address src, struct vsock_address dst)
{
	socket_lock_assert_owned(pcb->so);

	if (op != VSOCK_REQUEST) {
		const errno_t reset_result = vsock_pcb_safe_reset_address(pcb, dst, src);
		return (op == VSOCK_RESET) ? reset_result : ENOTSUP;
	}

	struct sockaddr_vm listen_addr = {
		.svm_len = sizeof(listen_addr),
		.svm_family = AF_VSOCK,
		.svm_reserved1 = 0,
		.svm_port = pcb->local_address.port,
		.svm_cid = pcb->local_address.cid
	};

	struct socket *child_so = sonewconn(pcb->so, 0, (struct sockaddr *)&listen_addr);
	if (child_so == NULL) {
		// It is likely that the backlog is full. Deny this request.
		vsock_pcb_safe_reset_address(pcb, dst, src);
		return ECONNREFUSED;
	}

	errno_t error = 0;
	struct vsockpcb *child = sotovsockpcb(child_so);

	if (child == NULL) {
		error = EINVAL;
	} else {
		error = vsock_bind_address(child, dst, src);
		if (!error) {
			error = vsock_pcb_respond(child);
		}
		if (!error) {
			soisconnected(child_so);
		}
	}

	if (error) {
		// Tear the half-built connection down and tell the peer.
		soisdisconnected(child_so);
		if (child != NULL) {
			vsock_unbind_pcb(child, false);
		}
		socket_unlock(child_so, 1);
		vsock_pcb_reset_address(dst, src);
	} else {
		socket_unlock(child_so, 0);
	}

	// NOTE(review): the listener is re-locked here without a visible unlock
	// in this function; presumably sonewconn() returns with the new socket
	// locked and the listener unlocked - confirm against sonewconn().
	socket_lock(pcb->so, 0);

	return error;
}
595 
596 /* VSock Transport */
597 
598 errno_t
vsock_add_transport(struct vsock_transport * transport)599 vsock_add_transport(struct vsock_transport *transport)
600 {
601 	if (transport == NULL || transport->provider == NULL) {
602 		return EINVAL;
603 	}
604 	if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport, NULL, transport, acq_rel)) {
605 		return EEXIST;
606 	}
607 	return 0;
608 }
609 
610 errno_t
vsock_remove_transport(struct vsock_transport * transport)611 vsock_remove_transport(struct vsock_transport *transport)
612 {
613 	if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport, transport, NULL, acq_rel)) {
614 		return ENODEV;
615 	}
616 	return 0;
617 }
618 
619 errno_t
vsock_reset_transport(struct vsock_transport * transport)620 vsock_reset_transport(struct vsock_transport *transport)
621 {
622 	if (transport == NULL) {
623 		return EINVAL;
624 	}
625 
626 	errno_t error = 0;
627 	struct vsockpcb *pcb = NULL;
628 	struct vsockpcb *tmp_pcb = NULL;
629 
630 	lck_rw_lock_exclusive(&vsockinfo.bound_lock);
631 	LIST_FOREACH_SAFE(pcb, &vsockinfo.bound, bound, tmp_pcb) {
632 		// Disconnect this transport's sockets. Listen and bind sockets must stay alive.
633 		socket_lock(pcb->so, 1);
634 		if (pcb->transport == transport && pcb->so->so_state & (SS_ISCONNECTED | SS_ISCONNECTING | SS_ISDISCONNECTING)) {
635 			errno_t dc_error = vsock_disconnect_pcb_locked(pcb);
636 			if (dc_error && !error) {
637 				error = dc_error;
638 			}
639 		}
640 		socket_unlock(pcb->so, 1);
641 	}
642 	lck_rw_done(&vsockinfo.bound_lock);
643 
644 	return error;
645 }
646 
647 errno_t
vsock_put_message(struct vsock_address src,struct vsock_address dst,enum vsock_operation op,uint32_t buf_alloc,uint32_t fwd_cnt,mbuf_t m)648 vsock_put_message(struct vsock_address src, struct vsock_address dst, enum vsock_operation op, uint32_t buf_alloc, uint32_t fwd_cnt, mbuf_t m)
649 {
650 	struct vsockpcb *pcb = vsock_get_matching_pcb(dst, src);
651 	if (!pcb) {
652 		if (op != VSOCK_RESET) {
653 			vsock_pcb_reset_address(dst, src);
654 		}
655 		if (m != NULL) {
656 			mbuf_freem_list(m);
657 		}
658 		return EINVAL;
659 	}
660 
661 	socket_lock_assert_owned(pcb->so);
662 
663 	struct socket *so = pcb->so;
664 	errno_t error = 0;
665 
666 	// Check if the peer's buffer has changed. Update our view of the peer's forwarded bytes.
667 	int buffers_changed = (pcb->peer_buf_alloc != buf_alloc) || (pcb->peer_fwd_cnt) != fwd_cnt;
668 	pcb->peer_buf_alloc = buf_alloc;
669 	pcb->peer_fwd_cnt = fwd_cnt;
670 
671 	// Peer's buffer has enough space for the next packet. Notify any threads waiting for space.
672 	if (buffers_changed && vsock_get_peer_space(pcb) >= pcb->waiting_send_size) {
673 		sowwakeup(so);
674 	}
675 
676 	switch (op) {
677 	case VSOCK_CREDIT_REQUEST:
678 		error = vsock_pcb_credit_update(pcb);
679 		break;
680 	case VSOCK_CREDIT_UPDATE:
681 		break;
682 	default:
683 		if (so->so_state & SS_ISCONNECTED) {
684 			error = vsock_put_message_connected(pcb, op, m);
685 			m = NULL;
686 		} else if (so->so_state & SS_ISCONNECTING) {
687 			error = vsock_put_message_connecting(pcb, op);
688 		} else if (so->so_options & SO_ACCEPTCONN) {
689 			error = vsock_put_message_listening(pcb, op, src, dst);
690 		} else {
691 			// Reset the connection for other states such as 'disconnecting'.
692 			error = vsock_disconnect_pcb(pcb);
693 			if (!error) {
694 				error = ENODEV;
695 			}
696 		}
697 		break;
698 	}
699 	socket_unlock(so, 1);
700 
701 	if (m != NULL) {
702 		mbuf_freem_list(m);
703 	}
704 
705 	return error;
706 }
707 
708 /* VSock Sysctl */
709 
710 static int
711 vsock_pcblist SYSCTL_HANDLER_ARGS
712 {
713 #pragma unused(oidp,arg2)
714 
715 	int error;
716 
717 	// Only stream is supported.
718 	if ((intptr_t)arg1 != SOCK_STREAM) {
719 		return EINVAL;
720 	}
721 
722 	// Get the generation count and the count of all vsock sockets.
723 	lck_rw_lock_shared(&vsockinfo.all_lock);
724 	uint64_t n = vsockinfo.all_pcb_count;
725 	vsock_gen_t gen_count = vsockinfo.vsock_gencnt;
726 	lck_rw_done(&vsockinfo.all_lock);
727 
728 	const size_t xpcb_len = sizeof(struct xvsockpcb);
729 	struct xvsockpgen xvg;
730 
731 	/*
732 	 * The process of preparing the PCB list is too time-consuming and
733 	 * resource-intensive to repeat twice on every request.
734 	 */
735 	if (req->oldptr == USER_ADDR_NULL) {
736 		req->oldidx = (size_t)(2 * sizeof(xvg) + (n + n / 8) * xpcb_len);
737 		return 0;
738 	}
739 
740 	if (req->newptr != USER_ADDR_NULL) {
741 		return EPERM;
742 	}
743 
744 	bzero(&xvg, sizeof(xvg));
745 	xvg.xvg_len = sizeof(xvg);
746 	xvg.xvg_count = n;
747 	xvg.xvg_gen = gen_count;
748 	xvg.xvg_sogen = so_gencnt;
749 	error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
750 	if (error) {
751 		return error;
752 	}
753 
754 	// Return if no sockets exist.
755 	if (n == 0) {
756 		return 0;
757 	}
758 
759 	lck_rw_lock_shared(&vsockinfo.all_lock);
760 
761 	n = 0;
762 	struct vsockpcb *pcb = NULL;
763 	TAILQ_FOREACH(pcb, &vsockinfo.all, all) {
764 		// Bail if there is not enough user buffer for this next socket.
765 		if (req->oldlen - req->oldidx - sizeof(xvg) < xpcb_len) {
766 			break;
767 		}
768 
769 		// Populate the socket structure.
770 		socket_lock(pcb->so, 1);
771 		if (pcb->vsock_gencnt <= gen_count) {
772 			struct xvsockpcb xpcb;
773 			bzero(&xpcb, xpcb_len);
774 			xpcb.xv_len = xpcb_len;
775 			xpcb.xv_vsockpp = (uint64_t)VM_KERNEL_ADDRHASH(pcb);
776 			xpcb.xvp_local_cid = pcb->local_address.cid;
777 			xpcb.xvp_local_port = pcb->local_address.port;
778 			xpcb.xvp_remote_cid = pcb->remote_address.cid;
779 			xpcb.xvp_remote_port = pcb->remote_address.port;
780 			xpcb.xvp_rxcnt = pcb->fwd_cnt;
781 			xpcb.xvp_txcnt = pcb->tx_cnt;
782 			xpcb.xvp_peer_rxhiwat = pcb->peer_buf_alloc;
783 			xpcb.xvp_peer_rxcnt = pcb->peer_fwd_cnt;
784 			xpcb.xvp_last_pid = pcb->so->last_pid;
785 			xpcb.xvp_gencnt = pcb->vsock_gencnt;
786 			if (pcb->so) {
787 				sotoxsocket(pcb->so, &xpcb.xv_socket);
788 			}
789 			socket_unlock(pcb->so, 1);
790 
791 			error = SYSCTL_OUT(req, &xpcb, xpcb_len);
792 			if (error != 0) {
793 				break;
794 			}
795 			n++;
796 		} else {
797 			socket_unlock(pcb->so, 1);
798 		}
799 	}
800 
801 	// Update the generation count to match the sockets being returned.
802 	gen_count = vsockinfo.vsock_gencnt;
803 
804 	lck_rw_done(&vsockinfo.all_lock);
805 
806 	if (!error) {
807 		/*
808 		 * Give the user an updated idea of our state.
809 		 * If the generation differs from what we told
810 		 * her before, she knows that something happened
811 		 * while we were processing this request, and it
812 		 * might be necessary to retry.
813 		 */
814 		bzero(&xvg, sizeof(xvg));
815 		xvg.xvg_len = sizeof(xvg);
816 		xvg.xvg_count = n;
817 		xvg.xvg_gen = gen_count;
818 		xvg.xvg_sogen = so_gencnt;
819 		error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
820 	}
821 
822 	return error;
823 }
824 
#ifdef SYSCTL_DECL
/* net.vsock: tunable buffer sizes plus the read-only socket list. */
SYSCTL_NODE(_net, OID_AUTO, vsock, CTLFLAG_RW | CTLFLAG_LOCKED, 0, "vsock");
SYSCTL_UINT(_net_vsock, OID_AUTO, sendspace, CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_sendspace, 0, "Maximum outgoing vsock datagram size");
SYSCTL_UINT(_net_vsock, OID_AUTO, recvspace, CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_recvspace, 0, "Maximum incoming vsock datagram size");
/* arg1 carries SOCK_STREAM, the only socket type vsock_pcblist accepts. */
SYSCTL_PROC(_net_vsock, OID_AUTO, pcblist,
    CTLTYPE_STRUCT | CTLFLAG_RD | CTLFLAG_LOCKED,
    (caddr_t)(long)SOCK_STREAM, 0, vsock_pcblist, "S,xvsockpcb",
    "List of active vsock sockets");
#endif
836 
837 /* VSock Protocol */
838 
839 static int
vsock_attach(struct socket * so,int proto,struct proc * p)840 vsock_attach(struct socket *so, int proto, struct proc *p)
841 {
842 	#pragma unused(proto, p)
843 
844 	// Reserve send and receive buffers.
845 	errno_t error = soreserve(so, vsock_sendspace, vsock_recvspace);
846 	if (error) {
847 		return error;
848 	}
849 
850 	// Attach should only be run once per socket.
851 	struct vsockpcb *pcb = sotovsockpcb(so);
852 	if (pcb) {
853 		return EINVAL;
854 	}
855 
856 	// Get the transport for this socket.
857 	struct vsock_transport *transport = os_atomic_load(&the_vsock_transport, relaxed);
858 	if (transport == NULL) {
859 		return ENODEV;
860 	}
861 
862 	// Initialize the vsock protocol control block.
863 	pcb = zalloc_flags(vsockpcb_zone, Z_WAITOK | Z_ZERO | Z_NOFAIL);
864 	pcb->so = so;
865 	pcb->transport = transport;
866 	pcb->local_address = (struct vsock_address) {
867 		.cid = VMADDR_CID_ANY,
868 		.port = VMADDR_PORT_ANY
869 	};
870 	pcb->remote_address = (struct vsock_address) {
871 		.cid = VMADDR_CID_ANY,
872 		.port = VMADDR_PORT_ANY
873 	};
874 	so->so_pcb = pcb;
875 
876 	// Tell the transport that this socket has attached.
877 	error = transport->attach_socket(transport->provider);
878 	if (error) {
879 		zfree(vsockpcb_zone, pcb);
880 		so->so_pcb = NULL;
881 		return error;
882 	}
883 
884 	// Add to the list of all vsock sockets.
885 	lck_rw_lock_exclusive(&vsockinfo.all_lock);
886 	TAILQ_INSERT_TAIL(&vsockinfo.all, pcb, all);
887 	vsockinfo.all_pcb_count++;
888 	pcb->vsock_gencnt = ++vsockinfo.vsock_gencnt;
889 	lck_rw_done(&vsockinfo.all_lock);
890 
891 	return 0;
892 }
893 
894 static int
vsock_control(struct socket * so,u_long cmd,caddr_t data,struct ifnet * ifp,struct proc * p)895 vsock_control(struct socket *so, u_long cmd, caddr_t data, struct ifnet *ifp, struct proc *p)
896 {
897 	#pragma unused(ifp)
898 
899 	VERIFY(so != NULL || p == kernproc);
900 
901 	if (cmd != IOCTL_VM_SOCKETS_GET_LOCAL_CID) {
902 		return EINVAL;
903 	}
904 
905 	struct vsock_transport *transport;
906 	if (so) {
907 		struct vsockpcb *pcb = sotovsockpcb(so);
908 		if (pcb == NULL) {
909 			return EINVAL;
910 		}
911 		transport = pcb->transport;
912 	} else {
913 		transport = os_atomic_load(&the_vsock_transport, relaxed);
914 	}
915 
916 	if (transport == NULL) {
917 		return ENODEV;
918 	}
919 
920 	uint32_t transport_cid;
921 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
922 	if (error) {
923 		return error;
924 	}
925 
926 	memcpy(data, &transport_cid, sizeof(transport_cid));
927 
928 	return 0;
929 }
930 
931 static int
vsock_detach(struct socket * so)932 vsock_detach(struct socket *so)
933 {
934 	struct vsockpcb *pcb = sotovsockpcb(so);
935 	if (pcb == NULL) {
936 		return EINVAL;
937 	}
938 
939 	vsock_unbind_pcb(pcb, false);
940 
941 	// Tell the transport that this socket has detached.
942 	struct vsock_transport *transport = pcb->transport;
943 	errno_t error = transport->detach_socket(transport->provider);
944 	if (error) {
945 		return error;
946 	}
947 
948 	// Remove from the list of all vsock sockets.
949 	lck_rw_lock_exclusive(&vsockinfo.all_lock);
950 	TAILQ_REMOVE(&vsockinfo.all, pcb, all);
951 	pcb->all.tqe_next = NULL;
952 	pcb->all.tqe_prev = NULL;
953 	vsockinfo.all_pcb_count--;
954 	vsockinfo.vsock_gencnt++;
955 	lck_rw_done(&vsockinfo.all_lock);
956 
957 	// Deallocate any resources.
958 	zfree(vsockpcb_zone, pcb);
959 	so->so_pcb = 0;
960 	so->so_flags |= SOF_PCBCLEARING;
961 	sofree(so);
962 
963 	return 0;
964 }
965 
/* Abort a socket: mark it disconnected, then detach it. */
static int
vsock_abort(struct socket *so)
{
	soisdisconnected(so);

	const int detach_result = vsock_detach(so);
	return detach_result;
}
972 
973 static int
vsock_bind(struct socket * so,struct sockaddr * nam,struct proc * p)974 vsock_bind(struct socket *so, struct sockaddr *nam, struct proc *p)
975 {
976 	#pragma unused(p)
977 
978 	struct vsockpcb *pcb = sotovsockpcb(so);
979 	if (pcb == NULL) {
980 		return EINVAL;
981 	}
982 
983 	struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
984 
985 	errno_t error = vsock_sockaddr_vm_validate(pcb, addr);
986 	if (error) {
987 		return error;
988 	}
989 
990 	struct vsock_address laddr = (struct vsock_address) {
991 		.cid = addr->svm_cid,
992 		.port = addr->svm_port,
993 	};
994 
995 	struct vsock_address raddr = (struct vsock_address) {
996 		.cid = VMADDR_CID_ANY,
997 		.port = VMADDR_PORT_ANY,
998 	};
999 
1000 	error = vsock_bind_address(pcb, laddr, raddr);
1001 	if (error) {
1002 		return error;
1003 	}
1004 
1005 	return 0;
1006 }
1007 
1008 static int
vsock_listen(struct socket * so,struct proc * p)1009 vsock_listen(struct socket *so, struct proc *p)
1010 {
1011 	#pragma unused(p)
1012 
1013 	struct vsockpcb *pcb = sotovsockpcb(so);
1014 	if (pcb == NULL) {
1015 		return EINVAL;
1016 	}
1017 
1018 	// Only stream is supported currently.
1019 	if (so->so_type != SOCK_STREAM) {
1020 		return EAFNOSUPPORT;
1021 	}
1022 
1023 	struct vsock_address *addr = &pcb->local_address;
1024 
1025 	if (addr->port == VMADDR_CID_ANY) {
1026 		return EFAULT;
1027 	}
1028 
1029 	struct vsock_transport *transport = pcb->transport;
1030 	uint32_t transport_cid;
1031 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
1032 	if (error) {
1033 		return error;
1034 	}
1035 
1036 	// Can listen on the transport's cid or any.
1037 	if (addr->cid != transport_cid && addr->cid != VMADDR_CID_ANY) {
1038 		return EFAULT;
1039 	}
1040 
1041 	return 0;
1042 }
1043 
1044 static int
vsock_accept(struct socket * so,struct sockaddr ** nam)1045 vsock_accept(struct socket *so, struct sockaddr **nam)
1046 {
1047 	struct vsockpcb *pcb = sotovsockpcb(so);
1048 	if (pcb == NULL) {
1049 		return EINVAL;
1050 	}
1051 
1052 	// Do not accept disconnected sockets.
1053 	if (so->so_state & SS_ISDISCONNECTED) {
1054 		return ECONNABORTED;
1055 	}
1056 
1057 	*nam = vsock_new_sockaddr(&pcb->remote_address);
1058 
1059 	return 0;
1060 }
1061 
1062 static int
vsock_connect(struct socket * so,struct sockaddr * nam,struct proc * p)1063 vsock_connect(struct socket *so, struct sockaddr *nam, struct proc *p)
1064 {
1065 	#pragma unused(p)
1066 
1067 	struct vsockpcb *pcb = sotovsockpcb(so);
1068 	if (pcb == NULL) {
1069 		return EINVAL;
1070 	}
1071 
1072 	struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
1073 
1074 	errno_t error = vsock_sockaddr_vm_validate(pcb, addr);
1075 	if (error) {
1076 		return error;
1077 	}
1078 
1079 	uint32_t transport_cid;
1080 	struct vsock_transport *transport = pcb->transport;
1081 	error = transport->get_cid(transport->provider, &transport_cid);
1082 	if (error) {
1083 		return error;
1084 	}
1085 
1086 	// Only supporting connections to the host, hypervisor, or self for now.
1087 	if (addr->svm_cid != VMADDR_CID_HOST &&
1088 	    addr->svm_cid != VMADDR_CID_HYPERVISOR &&
1089 	    addr->svm_cid != transport_cid) {
1090 		return EFAULT;
1091 	}
1092 
1093 	soisconnecting(so);
1094 
1095 	// Set the remote and local address.
1096 	struct vsock_address remote_addr = (struct vsock_address) {
1097 		.cid = addr->svm_cid,
1098 		.port = addr->svm_port,
1099 	};
1100 
1101 	struct vsock_address local_addr = (struct vsock_address) {
1102 		.cid = transport_cid,
1103 		.port = VMADDR_PORT_ANY,
1104 	};
1105 
1106 	// Bind to the address.
1107 	error = vsock_bind_address(pcb, local_addr, remote_addr);
1108 	if (error) {
1109 		goto cleanup;
1110 	}
1111 
1112 	// Attempt a connection using the socket's transport.
1113 	error = vsock_pcb_connect(pcb);
1114 	if (error) {
1115 		goto cleanup;
1116 	}
1117 
1118 	if ((so->so_state & SS_ISCONNECTED) == 0) {
1119 		// Don't wait for peer's response if non-blocking.
1120 		if (so->so_state & SS_NBIO) {
1121 			error = EINPROGRESS;
1122 			goto done;
1123 		}
1124 
1125 		struct timespec ts = (struct timespec) {
1126 			.tv_sec = so->so_snd.sb_timeo.tv_sec,
1127 			.tv_nsec = so->so_snd.sb_timeo.tv_usec * 1000,
1128 		};
1129 
1130 		lck_mtx_t *mutex_held;
1131 		if (so->so_proto->pr_getlock != NULL) {
1132 			mutex_held = (*so->so_proto->pr_getlock)(so, PR_F_WILLUNLOCK);
1133 		} else {
1134 			mutex_held = so->so_proto->pr_domain->dom_mtx;
1135 		}
1136 
1137 		// Wait until we receive a response to the connect request.
1138 		error = msleep((caddr_t)&so->so_timeo, mutex_held, PSOCK | PCATCH, "vsock_connect", &ts);
1139 		if (error) {
1140 			if (error == EAGAIN) {
1141 				error = ETIMEDOUT;
1142 			}
1143 			goto cleanup;
1144 		}
1145 	}
1146 
1147 cleanup:
1148 	if (so->so_error && !error) {
1149 		error = so->so_error;
1150 		so->so_error = 0;
1151 	}
1152 	if (!error) {
1153 		error = !(so->so_state & SS_ISCONNECTED);
1154 	}
1155 	if (error) {
1156 		vsock_unbind_pcb(pcb, false);
1157 	}
1158 
1159 done:
1160 	return error;
1161 }
1162 
1163 static int
vsock_disconnect(struct socket * so)1164 vsock_disconnect(struct socket *so)
1165 {
1166 	struct vsockpcb *pcb = sotovsockpcb(so);
1167 	if (pcb == NULL) {
1168 		return EINVAL;
1169 	}
1170 
1171 	return vsock_disconnect_pcb(pcb);
1172 }
1173 
1174 static int
vsock_sockaddr(struct socket * so,struct sockaddr ** nam)1175 vsock_sockaddr(struct socket *so, struct sockaddr **nam)
1176 {
1177 	struct vsockpcb *pcb = sotovsockpcb(so);
1178 	if (pcb == NULL) {
1179 		return EINVAL;
1180 	}
1181 
1182 	*nam = vsock_new_sockaddr(&pcb->local_address);
1183 
1184 	return 0;
1185 }
1186 
1187 static int
vsock_peeraddr(struct socket * so,struct sockaddr ** nam)1188 vsock_peeraddr(struct socket *so, struct sockaddr **nam)
1189 {
1190 	struct vsockpcb *pcb = sotovsockpcb(so);
1191 	if (pcb == NULL) {
1192 		return EINVAL;
1193 	}
1194 
1195 	*nam = vsock_new_sockaddr(&pcb->remote_address);
1196 
1197 	return 0;
1198 }
1199 
1200 static int
vsock_send(struct socket * so,int flags,struct mbuf * m,struct sockaddr * nam,struct mbuf * control,proc_t p)1201 vsock_send(struct socket *so, int flags, struct mbuf *m, struct sockaddr *nam, struct mbuf *control, proc_t p)
1202 {
1203 	#pragma unused(flags, nam, p)
1204 
1205 	struct vsockpcb *pcb = sotovsockpcb(so);
1206 	if (pcb == NULL || m == NULL) {
1207 		return EINVAL;
1208 	}
1209 
1210 	if (control != NULL) {
1211 		m_freem(control);
1212 		return EOPNOTSUPP;
1213 	}
1214 
1215 	// Ensure this socket is connected.
1216 	if ((so->so_state & SS_ISCONNECTED) == 0) {
1217 		if (m != NULL) {
1218 			mbuf_freem_list(m);
1219 		}
1220 		return EPERM;
1221 	}
1222 
1223 	errno_t error;
1224 
1225 	// rdar://84098487 (SEED: Web: Virtio-socket sent data lost after 128KB)
1226 	// For writes larger than the default `sosendmaxchain` of 65536, vsock_send() is called multiple times per write().
1227 	// Only the first call to vsock_send() is passed a valid mbuf packet, while subsequent calls are not marked as a packet
1228 	// with a valid length. We should mark all mbufs as a packet and set the correct packet length so that the downstream
1229 	// socket transport layer can correctly generate physical segments.
1230 	if (!(mbuf_flags(m) & MBUF_PKTHDR)) {
1231 		if (!(mbuf_flags(m) & M_EXT)) {
1232 			struct mbuf *header = NULL;
1233 			MGETHDR(header, M_WAITOK, MT_HEADER);
1234 			if (header == NULL) {
1235 				if (m != NULL) {
1236 					mbuf_freem_list(m);
1237 				}
1238 				return ENOBUFS;
1239 			}
1240 			header->m_next = m;
1241 			m = header;
1242 		} else {
1243 			mbuf_setflags(m, mbuf_flags(m) | MBUF_PKTHDR);
1244 		}
1245 
1246 		size_t len = 0;
1247 		struct mbuf *next = m;
1248 		while (next) {
1249 			len += mbuf_len(next);
1250 			next = mbuf_next(next);
1251 		}
1252 		mbuf_pkthdr_setlen(m, len);
1253 	}
1254 
1255 	const size_t len = mbuf_pkthdr_len(m);
1256 	uint32_t free_space = vsock_get_peer_space(pcb);
1257 
1258 	// Ensure the peer has enough space in their receive buffer.
1259 	while (len > free_space) {
1260 		// Record the number of free peer bytes necessary before we can send.
1261 		if (len > pcb->waiting_send_size) {
1262 			pcb->waiting_send_size = len;
1263 		}
1264 
1265 		// Send a credit request.
1266 		error = vsock_pcb_credit_request(pcb);
1267 		if (error) {
1268 			if (m != NULL) {
1269 				mbuf_freem_list(m);
1270 			}
1271 			return error;
1272 		}
1273 
1274 		// Check again in case free space was automatically updated in loopback case.
1275 		free_space = vsock_get_peer_space(pcb);
1276 		if (len <= free_space) {
1277 			pcb->waiting_send_size = 0;
1278 			break;
1279 		}
1280 
1281 		// Bail if this is a non-blocking socket.
1282 		if (so->so_state & SS_NBIO) {
1283 			if (m != NULL) {
1284 				mbuf_freem_list(m);
1285 			}
1286 			return EWOULDBLOCK;
1287 		}
1288 
1289 		// Wait until our peer has enough free space in their receive buffer.
1290 		error = sbwait(&so->so_snd);
1291 		pcb->waiting_send_size = 0;
1292 		if (error) {
1293 			if (m != NULL) {
1294 				mbuf_freem_list(m);
1295 			}
1296 			return error;
1297 		}
1298 
1299 		// Bail if an error occured or we can't send more.
1300 		if (so->so_state & SS_CANTSENDMORE) {
1301 			if (m != NULL) {
1302 				mbuf_freem_list(m);
1303 			}
1304 			return EPIPE;
1305 		} else if (so->so_error) {
1306 			error = so->so_error;
1307 			so->so_error = 0;
1308 			if (m != NULL) {
1309 				mbuf_freem_list(m);
1310 			}
1311 			return error;
1312 		}
1313 
1314 		free_space = vsock_get_peer_space(pcb);
1315 	}
1316 
1317 	// Send a payload over the transport.
1318 	error = vsock_pcb_send(pcb, m);
1319 	if (error) {
1320 		return error;
1321 	}
1322 
1323 	pcb->tx_cnt += len;
1324 
1325 	return 0;
1326 }
1327 
1328 static int
vsock_shutdown(struct socket * so)1329 vsock_shutdown(struct socket *so)
1330 {
1331 	struct vsockpcb *pcb = sotovsockpcb(so);
1332 	if (pcb == NULL) {
1333 		return EINVAL;
1334 	}
1335 
1336 	socantsendmore(so);
1337 
1338 	// Tell peer we will no longer send.
1339 	errno_t error = vsock_pcb_shutdown_send(pcb);
1340 	if (error) {
1341 		return error;
1342 	}
1343 
1344 	return 0;
1345 }
1346 
1347 static int
vsock_soreceive(struct socket * so,struct sockaddr ** psa,struct uio * uio,struct mbuf ** mp0,struct mbuf ** controlp,int * flagsp)1348 vsock_soreceive(struct socket *so, struct sockaddr **psa, struct uio *uio,
1349     struct mbuf **mp0, struct mbuf **controlp, int *flagsp)
1350 {
1351 	struct vsockpcb *pcb = sotovsockpcb(so);
1352 	if (pcb == NULL) {
1353 		return EINVAL;
1354 	}
1355 
1356 	user_ssize_t length = uio_resid(uio);
1357 	int result = soreceive(so, psa, uio, mp0, controlp, flagsp);
1358 	length -= uio_resid(uio);
1359 
1360 	socket_lock(so, 1);
1361 
1362 	pcb->fwd_cnt += length;
1363 
1364 	const uint32_t threshold = VSOCK_MAX_PACKET_SIZE;
1365 
1366 	// Send a credit update if is possible that the peer will no longer send.
1367 	if ((pcb->fwd_cnt - pcb->last_fwd_cnt + threshold) >= pcb->last_buf_alloc) {
1368 		errno_t error = vsock_pcb_credit_update(pcb);
1369 		if (!result && error) {
1370 			result = error;
1371 		}
1372 	}
1373 
1374 	socket_unlock(so, 1);
1375 
1376 	return result;
1377 }
1378 
1379 static struct pr_usrreqs vsock_usrreqs = {
1380 	.pru_abort =            vsock_abort,
1381 	.pru_attach =           vsock_attach,
1382 	.pru_control =          vsock_control,
1383 	.pru_detach =           vsock_detach,
1384 	.pru_bind =             vsock_bind,
1385 	.pru_listen =           vsock_listen,
1386 	.pru_accept =           vsock_accept,
1387 	.pru_connect =          vsock_connect,
1388 	.pru_disconnect =       vsock_disconnect,
1389 	.pru_send =             vsock_send,
1390 	.pru_shutdown =         vsock_shutdown,
1391 	.pru_sockaddr =         vsock_sockaddr,
1392 	.pru_peeraddr =         vsock_peeraddr,
1393 	.pru_sosend =           sosend,
1394 	.pru_soreceive =        vsock_soreceive,
1395 };
1396 
1397 static void
vsock_init(struct protosw * pp,struct domain * dp)1398 vsock_init(struct protosw *pp, struct domain *dp)
1399 {
1400 	#pragma unused(dp)
1401 
1402 	static int vsock_initialized = 0;
1403 	VERIFY((pp->pr_flags & (PR_INITIALIZED | PR_ATTACHED)) == PR_ATTACHED);
1404 	if (!os_atomic_cmpxchg((volatile int *)&vsock_initialized, 0, 1, acq_rel)) {
1405 		return;
1406 	}
1407 
1408 	// Setup VSock protocol info struct.
1409 	lck_rw_init(&vsockinfo.all_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1410 	lck_rw_init(&vsockinfo.bound_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1411 	lck_mtx_init(&vsockinfo.port_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1412 	TAILQ_INIT(&vsockinfo.all);
1413 	LIST_INIT(&vsockinfo.bound);
1414 	vsockinfo.last_port = VMADDR_PORT_ANY;
1415 }
1416 
1417 static struct protosw vsocksw[] = {
1418 	{
1419 		.pr_type =              SOCK_STREAM,
1420 		.pr_protocol =          0,
1421 		.pr_flags =             PR_CONNREQUIRED | PR_WANTRCVD,
1422 		.pr_init =              vsock_init,
1423 		.pr_usrreqs =           &vsock_usrreqs,
1424 	}
1425 };
1426 
1427 static const int vsock_proto_count = (sizeof(vsocksw) / sizeof(struct protosw));
1428 
1429 /* VSock Domain */
1430 
1431 static struct domain *vsock_domain = NULL;
1432 
1433 static void
vsock_dinit(struct domain * dp)1434 vsock_dinit(struct domain *dp)
1435 {
1436 	// The VSock domain is initialized with a singleton pattern.
1437 	VERIFY(!(dp->dom_flags & DOM_INITIALIZED));
1438 	VERIFY(vsock_domain == NULL);
1439 	vsock_domain = dp;
1440 
1441 	// Add protocols and initialize.
1442 	for (int i = 0; i < vsock_proto_count; i++) {
1443 		net_add_proto((struct protosw *)&vsocksw[i], dp, 1);
1444 	}
1445 }
1446 
1447 struct domain vsockdomain_s = {
1448 	.dom_family =           PF_VSOCK,
1449 	.dom_name =             "vsock",
1450 	.dom_init =             vsock_dinit,
1451 	.dom_maxrtkey =         sizeof(struct sockaddr_vm),
1452 	.dom_protohdrlen =      sizeof(struct sockaddr_vm),
1453 };
1454