xref: /xnu-11417.121.6/bsd/kern/vsock_domain.c (revision a1e26a70f38d1d7daa7b49b258e2f8538ad81650)
1 /*
2  * Copyright (c) 2020 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <sys/domain.h>
30 #include <sys/socket.h>
31 #include <sys/protosw.h>
32 #include <sys/mcache.h>
33 #include <sys/systm.h>
34 #include <sys/sysctl.h>
35 #include <sys/random.h>
36 #include <sys/mbuf.h>
37 #include <sys/vsock_domain.h>
38 #include <sys/vsock_transport.h>
39 #include <kern/task.h>
40 #include <kern/zalloc.h>
41 #include <kern/locks.h>
42 #include <machine/atomic.h>
43 
/* Fetch the vsock PCB stored in a socket's so_pcb field. */
#define sotovsockpcb(so) ((struct vsockpcb *)(so)->so_pcb)

/*
 * Local ports below this value are privileged (see
 * vsock_bind_address_if_free()); ephemeral port selection in
 * vsock_bind_address() also starts here.
 */
#define VSOCK_PORT_RESERVED 1024

/* VSock Protocol Globals */

/* The registered transport; installed/cleared with atomic compare-and-swap. */
static struct vsock_transport * _Atomic the_vsock_transport = NULL;
/* Zone backing every struct vsockpcb. */
static ZONE_DEFINE(vsockpcb_zone, "vsockpcbzone",
    sizeof(struct vsockpcb), ZC_NONE);
/* Lock group named "vsock". */
static LCK_GRP_DECLARE(vsock_lock_grp, "vsock");
/*
 * Global bookkeeping: the bound-PCB list (bound_lock), the all-PCB list with
 * its count and generation (all_lock), and the ephemeral port cursor
 * (port_lock).
 */
static struct vsockpcbinfo vsockinfo;

/* Buffer sizes passed to soreserve() by vsock_attach() for each new socket. */
static uint32_t vsock_sendspace = VSOCK_MAX_PACKET_SIZE * 8;
static uint32_t vsock_recvspace = VSOCK_MAX_PACKET_SIZE * 8;
58 
59 /* VSock PCB Helpers */
60 
61 static uint32_t
vsock_get_peer_space(struct vsockpcb * pcb)62 vsock_get_peer_space(struct vsockpcb *pcb)
63 {
64 	return pcb->peer_buf_alloc - (pcb->tx_cnt - pcb->peer_fwd_cnt);
65 }
66 
67 static struct vsockpcb *
vsock_get_matching_pcb(struct vsock_address src,struct vsock_address dst)68 vsock_get_matching_pcb(struct vsock_address src, struct vsock_address dst)
69 {
70 	struct vsockpcb *preferred = NULL;
71 	struct vsockpcb *match = NULL;
72 	struct vsockpcb *pcb = NULL;
73 
74 	lck_rw_lock_shared(&vsockinfo.bound_lock);
75 	LIST_FOREACH(pcb, &vsockinfo.bound, bound) {
76 		// Source cid and port must match. Only destination port must match. (Allows for a changing CID during migration)
77 		socket_lock(pcb->so, 1);
78 		if ((pcb->so->so_state & SS_ISCONNECTED || pcb->so->so_state & SS_ISCONNECTING) &&
79 		    pcb->local_address.cid == src.cid && pcb->local_address.port == src.port &&
80 		    pcb->remote_address.port == dst.port) {
81 			preferred = pcb;
82 			break;
83 		} else if ((pcb->local_address.cid == src.cid || pcb->local_address.cid == VMADDR_CID_ANY) &&
84 		    pcb->local_address.port == src.port) {
85 			match = pcb;
86 		}
87 		socket_unlock(pcb->so, 1);
88 	}
89 	if (!preferred && match) {
90 		socket_lock(match->so, 1);
91 		preferred = match;
92 	}
93 	lck_rw_done(&vsockinfo.bound_lock);
94 
95 	return preferred;
96 }
97 
98 static errno_t
vsock_bind_address_if_free(struct vsockpcb * pcb,uint32_t local_cid,uint32_t local_port,uint32_t remote_cid,uint32_t remote_port)99 vsock_bind_address_if_free(struct vsockpcb *pcb, uint32_t local_cid, uint32_t local_port, uint32_t remote_cid, uint32_t remote_port)
100 {
101 	socket_lock_assert_owned(pcb->so);
102 
103 	// Privileged ports.
104 	if (local_port != VMADDR_PORT_ANY && local_port < VSOCK_PORT_RESERVED &&
105 	    current_task() != kernel_task && proc_suser(current_proc()) != 0) {
106 		return EACCES;
107 	}
108 
109 	bool taken = false;
110 	const bool check_remote = (remote_cid != VMADDR_CID_ANY && remote_port != VMADDR_PORT_ANY);
111 
112 	struct vsockpcb *pcb_match = NULL;
113 
114 	socket_unlock(pcb->so, 0);
115 	lck_rw_lock_exclusive(&vsockinfo.bound_lock);
116 	LIST_FOREACH(pcb_match, &vsockinfo.bound, bound) {
117 		socket_lock(pcb_match->so, 1);
118 		if (pcb == pcb_match ||
119 		    (!check_remote && pcb_match->local_address.port == local_port) ||
120 		    (check_remote && pcb_match->local_address.port == local_port &&
121 		    pcb_match->remote_address.cid == remote_cid && pcb_match->remote_address.port == remote_port)) {
122 			socket_unlock(pcb_match->so, 1);
123 			taken = true;
124 			break;
125 		}
126 		socket_unlock(pcb_match->so, 1);
127 	}
128 	socket_lock(pcb->so, 0);
129 	if (!taken) {
130 		pcb->local_address = (struct vsock_address) { .cid = local_cid, .port = local_port };
131 		pcb->remote_address = (struct vsock_address) { .cid = remote_cid, .port = remote_port };
132 		LIST_INSERT_HEAD(&vsockinfo.bound, pcb, bound);
133 	}
134 	lck_rw_done(&vsockinfo.bound_lock);
135 
136 	return taken ? EADDRINUSE : 0;
137 }
138 
139 static errno_t
vsock_bind_address(struct vsockpcb * pcb,struct vsock_address laddr,struct vsock_address raddr)140 vsock_bind_address(struct vsockpcb *pcb, struct vsock_address laddr, struct vsock_address raddr)
141 {
142 	if (!pcb) {
143 		return EINVAL;
144 	}
145 
146 	socket_lock_assert_owned(pcb->so);
147 
148 	// Certain CIDs are reserved.
149 	if (laddr.cid == VMADDR_CID_HYPERVISOR || laddr.cid == VMADDR_CID_RESERVED || laddr.cid == VMADDR_CID_HOST) {
150 		return EADDRNOTAVAIL;
151 	}
152 
153 	// Remote address must be fully specified or not specified at all.
154 	if ((raddr.cid == VMADDR_CID_ANY) ^ (raddr.port == VMADDR_PORT_ANY)) {
155 		return EINVAL;
156 	}
157 
158 	// Cannot bind if already bound.
159 	if (pcb->local_address.port != VMADDR_PORT_ANY) {
160 		return EINVAL;
161 	}
162 
163 	uint32_t transport_cid;
164 	struct vsock_transport *transport = pcb->transport;
165 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
166 	if (error) {
167 		return error;
168 	}
169 
170 	// Local CID must be this transport's CID or any.
171 	if (laddr.cid != transport_cid && laddr.cid != VMADDR_CID_ANY) {
172 		return EINVAL;
173 	}
174 
175 	if (laddr.port != VMADDR_PORT_ANY) {
176 		error = vsock_bind_address_if_free(pcb, laddr.cid, laddr.port, raddr.cid, raddr.port);
177 	} else {
178 		socket_unlock(pcb->so, 0);
179 		lck_mtx_lock(&vsockinfo.port_lock);
180 		socket_lock(pcb->so, 0);
181 
182 		const uint32_t first = VSOCK_PORT_RESERVED;
183 		const uint32_t last = VMADDR_PORT_ANY - 1;
184 		uint32_t count = last - first + 1;
185 		uint32_t *last_port = &vsockinfo.last_port;
186 
187 		if (pcb->so->so_flags & SOF_BINDRANDOMPORT) {
188 			uint32_t random = 0;
189 			read_frandom(&random, sizeof(random));
190 			*last_port = first + (random % count);
191 		}
192 
193 		do {
194 			if (count == 0) {
195 				lck_mtx_unlock(&vsockinfo.port_lock);
196 				return EADDRNOTAVAIL;
197 			}
198 			count--;
199 
200 			++*last_port;
201 			if (*last_port < first || *last_port > last) {
202 				*last_port = first;
203 			}
204 
205 			error = vsock_bind_address_if_free(pcb, laddr.cid, *last_port, raddr.cid, raddr.port);
206 		} while (error);
207 
208 		lck_mtx_unlock(&vsockinfo.port_lock);
209 	}
210 
211 	return error;
212 }
213 
214 static void
vsock_unbind_pcb_locked(struct vsockpcb * pcb,bool is_locked)215 vsock_unbind_pcb_locked(struct vsockpcb *pcb, bool is_locked)
216 {
217 	if (!pcb) {
218 		return;
219 	}
220 
221 	struct socket *so = pcb->so;
222 	socket_lock_assert_owned(so);
223 
224 	// Bail if disconnect and already unbound.
225 	if (so->so_state & SS_ISDISCONNECTED) {
226 		assert(pcb->bound.le_next == NULL);
227 		assert(pcb->bound.le_prev == NULL);
228 		return;
229 	}
230 
231 	if (!is_locked) {
232 		socket_unlock(so, 0);
233 		lck_rw_lock_exclusive(&vsockinfo.bound_lock);
234 		socket_lock(so, 0);
235 
236 		// Case where some other thread also called unbind() on this socket while waiting to acquire its lock.
237 		if (!pcb->bound.le_prev) {
238 			soisdisconnected(so);
239 			lck_rw_done(&vsockinfo.bound_lock);
240 			return;
241 		}
242 	}
243 
244 	soisdisconnected(so);
245 
246 	LIST_REMOVE(pcb, bound);
247 	pcb->bound.le_next = NULL;
248 	pcb->bound.le_prev = NULL;
249 
250 	if (!is_locked) {
251 		lck_rw_done(&vsockinfo.bound_lock);
252 	}
253 }
254 
/*
 * Unbind `pcb`, letting vsock_unbind_pcb_locked() acquire
 * vsockinfo.bound_lock itself. The pcb's socket must be locked.
 */
static void
vsock_unbind_pcb(struct vsockpcb *pcb)
{
	const bool caller_holds_bound_lock = false;

	vsock_unbind_pcb_locked(pcb, caller_holds_bound_lock);
}
260 
261 static struct sockaddr *
vsock_new_sockaddr(struct vsock_address * address)262 vsock_new_sockaddr(struct vsock_address *address)
263 {
264 	if (!address) {
265 		return NULL;
266 	}
267 
268 	struct sockaddr_vm *addr;
269 	addr = (struct sockaddr_vm *)alloc_sockaddr(sizeof(*addr),
270 	    Z_WAITOK | Z_NOFAIL);
271 
272 	addr->svm_family = AF_VSOCK;
273 	addr->svm_port = address->port;
274 	addr->svm_cid = address->cid;
275 
276 	return (struct sockaddr *)addr;
277 }
278 
279 static errno_t
vsock_pcb_send_message(struct vsockpcb * pcb,enum vsock_operation operation,mbuf_t m)280 vsock_pcb_send_message(struct vsockpcb *pcb, enum vsock_operation operation, mbuf_t m)
281 {
282 	if (!pcb) {
283 		if (m != NULL) {
284 			mbuf_freem_list(m);
285 		}
286 		return EINVAL;
287 	}
288 
289 	socket_lock_assert_owned(pcb->so);
290 
291 	errno_t error;
292 
293 	struct vsock_address dst = pcb->remote_address;
294 	if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
295 		if (m != NULL) {
296 			mbuf_freem_list(m);
297 		}
298 		return EINVAL;
299 	}
300 
301 	struct vsock_address src = pcb->local_address;
302 	if (src.cid == VMADDR_CID_ANY) {
303 		uint32_t transport_cid;
304 		struct vsock_transport *transport = pcb->transport;
305 		error = transport->get_cid(transport->provider, &transport_cid);
306 		if (error) {
307 			if (m != NULL) {
308 				mbuf_freem_list(m);
309 			}
310 			return error;
311 		}
312 		src.cid = transport_cid;
313 	}
314 
315 	uint32_t buf_alloc = pcb->so->so_rcv.sb_hiwat;
316 	uint32_t fwd_cnt = pcb->fwd_cnt;
317 
318 	if (src.cid == dst.cid) {
319 		pcb->last_buf_alloc = buf_alloc;
320 		pcb->last_fwd_cnt = fwd_cnt;
321 
322 		socket_unlock(pcb->so, 0);
323 		error = vsock_put_message(src, dst, operation, buf_alloc, fwd_cnt, m);
324 		socket_lock(pcb->so, 0);
325 	} else {
326 		struct vsock_transport *transport = pcb->transport;
327 		error = transport->put_message(transport->provider, src, dst, operation, buf_alloc, fwd_cnt, m);
328 
329 		if (!error) {
330 			pcb->last_buf_alloc = buf_alloc;
331 			pcb->last_fwd_cnt = fwd_cnt;
332 		}
333 	}
334 
335 	return error;
336 }
337 
338 static errno_t
vsock_pcb_reset_address(struct vsock_address src,struct vsock_address dst)339 vsock_pcb_reset_address(struct vsock_address src, struct vsock_address dst)
340 {
341 	if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
342 		return EINVAL;
343 	}
344 
345 	errno_t error = 0;
346 	struct vsock_transport *transport = NULL;
347 
348 	if (src.cid == VMADDR_CID_ANY) {
349 		transport = os_atomic_load(&the_vsock_transport, relaxed);
350 		if (transport == NULL) {
351 			return ENODEV;
352 		}
353 
354 		uint32_t transport_cid;
355 		error = transport->get_cid(transport->provider, &transport_cid);
356 		if (error) {
357 			return error;
358 		}
359 		src.cid = transport_cid;
360 	}
361 
362 	if (src.cid == dst.cid) {
363 		// Reset both sockets.
364 		struct vsockpcb *pcb = vsock_get_matching_pcb(src, dst);
365 		if (pcb) {
366 			socket_lock_assert_owned(pcb->so);
367 			vsock_unbind_pcb(pcb);
368 			socket_unlock(pcb->so, 1);
369 		}
370 	} else {
371 		if (!transport) {
372 			transport = os_atomic_load(&the_vsock_transport, relaxed);
373 			if (transport == NULL) {
374 				return ENODEV;
375 			}
376 		}
377 		error = transport->put_message(transport->provider, src, dst, VSOCK_RESET, 0, 0, NULL);
378 	}
379 
380 	return error;
381 }
382 
383 static errno_t
vsock_pcb_safe_reset_address(struct vsockpcb * pcb,struct vsock_address src,struct vsock_address dst)384 vsock_pcb_safe_reset_address(struct vsockpcb *pcb, struct vsock_address src, struct vsock_address dst)
385 {
386 	if (pcb) {
387 		socket_lock_assert_owned(pcb->so);
388 		socket_unlock(pcb->so, 0);
389 	}
390 	errno_t error = vsock_pcb_reset_address(src, dst);
391 	if (pcb) {
392 		socket_lock(pcb->so, 0);
393 	}
394 	return error;
395 }
396 
397 static errno_t
vsock_pcb_connect(struct vsockpcb * pcb)398 vsock_pcb_connect(struct vsockpcb *pcb)
399 {
400 	return vsock_pcb_send_message(pcb, VSOCK_REQUEST, NULL);
401 }
402 
403 static errno_t
vsock_pcb_respond(struct vsockpcb * pcb)404 vsock_pcb_respond(struct vsockpcb *pcb)
405 {
406 	return vsock_pcb_send_message(pcb, VSOCK_RESPONSE, NULL);
407 }
408 
409 static errno_t
vsock_pcb_send(struct vsockpcb * pcb,mbuf_t m)410 vsock_pcb_send(struct vsockpcb *pcb, mbuf_t m)
411 {
412 	return vsock_pcb_send_message(pcb, VSOCK_PAYLOAD, m);
413 }
414 
415 static errno_t
vsock_pcb_shutdown_send(struct vsockpcb * pcb)416 vsock_pcb_shutdown_send(struct vsockpcb *pcb)
417 {
418 	return vsock_pcb_send_message(pcb, VSOCK_SHUTDOWN_SEND, NULL);
419 }
420 
421 static errno_t
vsock_pcb_reset(struct vsockpcb * pcb)422 vsock_pcb_reset(struct vsockpcb *pcb)
423 {
424 	return vsock_pcb_send_message(pcb, VSOCK_RESET, NULL);
425 }
426 
427 static errno_t
vsock_pcb_credit_update(struct vsockpcb * pcb)428 vsock_pcb_credit_update(struct vsockpcb *pcb)
429 {
430 	return vsock_pcb_send_message(pcb, VSOCK_CREDIT_UPDATE, NULL);
431 }
432 
433 static errno_t
vsock_pcb_credit_request(struct vsockpcb * pcb)434 vsock_pcb_credit_request(struct vsockpcb *pcb)
435 {
436 	return vsock_pcb_send_message(pcb, VSOCK_CREDIT_REQUEST, NULL);
437 }
438 
439 static errno_t
vsock_disconnect_pcb_common(struct vsockpcb * pcb,bool is_locked)440 vsock_disconnect_pcb_common(struct vsockpcb *pcb, bool is_locked)
441 {
442 	socket_lock_assert_owned(pcb->so);
443 	vsock_unbind_pcb_locked(pcb, is_locked);
444 	return vsock_pcb_reset(pcb);
445 }
446 
447 static errno_t
vsock_disconnect_pcb_locked(struct vsockpcb * pcb)448 vsock_disconnect_pcb_locked(struct vsockpcb *pcb)
449 {
450 	return vsock_disconnect_pcb_common(pcb, true);
451 }
452 
453 static errno_t
vsock_disconnect_pcb(struct vsockpcb * pcb)454 vsock_disconnect_pcb(struct vsockpcb *pcb)
455 {
456 	return vsock_disconnect_pcb_common(pcb, false);
457 }
458 
459 static errno_t
vsock_sockaddr_vm_validate(struct vsockpcb * pcb,struct sockaddr_vm * addr)460 vsock_sockaddr_vm_validate(struct vsockpcb *pcb, struct sockaddr_vm *addr)
461 {
462 	if (!pcb || !pcb->so || !addr) {
463 		return EINVAL;
464 	}
465 
466 	// Validate address length.
467 	if (addr->svm_len < sizeof(struct sockaddr_vm)) {
468 		return EINVAL;
469 	}
470 
471 	// Validate address family.
472 	if (addr->svm_family != AF_UNSPEC && addr->svm_family != AF_VSOCK) {
473 		return EAFNOSUPPORT;
474 	}
475 
476 	// Only stream is supported currently.
477 	if (pcb->so->so_type != SOCK_STREAM) {
478 		return EAFNOSUPPORT;
479 	}
480 
481 	return 0;
482 }
483 
484 /* VSock Receive Handlers */
485 
486 static errno_t
vsock_put_message_connected(struct vsockpcb * pcb,enum vsock_operation op,mbuf_t m)487 vsock_put_message_connected(struct vsockpcb *pcb, enum vsock_operation op, mbuf_t m)
488 {
489 	socket_lock_assert_owned(pcb->so);
490 
491 	errno_t error = 0;
492 
493 	switch (op) {
494 	case VSOCK_SHUTDOWN:
495 		socantsendmore(pcb->so);
496 		socantrcvmore(pcb->so);
497 		break;
498 	case VSOCK_SHUTDOWN_RECEIVE:
499 		socantsendmore(pcb->so);
500 		break;
501 	case VSOCK_SHUTDOWN_SEND:
502 		socantrcvmore(pcb->so);
503 		break;
504 	case VSOCK_PAYLOAD:
505 		// Add data to the receive queue then wakeup any reading threads.
506 		error = !sbappendstream(&pcb->so->so_rcv, m);
507 		if (!error) {
508 			sorwakeup(pcb->so);
509 		}
510 		break;
511 	case VSOCK_RESET:
512 		vsock_unbind_pcb(pcb);
513 		break;
514 	default:
515 		error = ENOTSUP;
516 		break;
517 	}
518 
519 	return error;
520 }
521 
522 static errno_t
vsock_put_message_connecting(struct vsockpcb * pcb,enum vsock_operation op)523 vsock_put_message_connecting(struct vsockpcb *pcb, enum vsock_operation op)
524 {
525 	socket_lock_assert_owned(pcb->so);
526 
527 	errno_t error = 0;
528 
529 	switch (op) {
530 	case VSOCK_RESPONSE:
531 		soisconnected(pcb->so);
532 		break;
533 	case VSOCK_RESET:
534 		pcb->so->so_error = EAGAIN;
535 		error = vsock_disconnect_pcb(pcb);
536 		break;
537 	default:
538 		vsock_disconnect_pcb(pcb);
539 		error = ENOTSUP;
540 		break;
541 	}
542 
543 	return error;
544 }
545 
546 static errno_t
vsock_put_message_listening(struct vsockpcb * pcb,enum vsock_operation op,struct vsock_address src,struct vsock_address dst)547 vsock_put_message_listening(struct vsockpcb *pcb, enum vsock_operation op, struct vsock_address src, struct vsock_address dst)
548 {
549 	socket_lock_assert_owned(pcb->so);
550 
551 	struct sockaddr_vm addr;
552 	struct socket *so2 = NULL;
553 	struct vsockpcb *pcb2 = NULL;
554 
555 	errno_t error = 0;
556 
557 	switch (op) {
558 	case VSOCK_REQUEST:
559 		addr = (struct sockaddr_vm) {
560 			.svm_len = sizeof(addr),
561 			.svm_family = AF_VSOCK,
562 			.svm_reserved1 = 0,
563 			.svm_port = pcb->local_address.port,
564 			.svm_cid = pcb->local_address.cid
565 		};
566 		so2 = sonewconn(pcb->so, 0, (struct sockaddr *)&addr);
567 		if (!so2) {
568 			// It is likely that the backlog is full. Deny this request.
569 			vsock_pcb_safe_reset_address(pcb, dst, src);
570 			error = ECONNREFUSED;
571 			break;
572 		}
573 
574 		pcb2 = sotovsockpcb(so2);
575 		if (!pcb2) {
576 			error = EINVAL;
577 			goto done;
578 		}
579 
580 		error = vsock_bind_address(pcb2, dst, src);
581 		if (error) {
582 			goto done;
583 		}
584 
585 		error = vsock_pcb_respond(pcb2);
586 		if (error) {
587 			goto done;
588 		}
589 
590 		soisconnected(so2);
591 
592 done:
593 		if (error) {
594 			if (pcb2) {
595 				vsock_unbind_pcb(pcb2);
596 			} else {
597 				soisdisconnected(so2);
598 			}
599 			socket_unlock(so2, 1);
600 			vsock_pcb_reset_address(dst, src);
601 		} else {
602 			socket_unlock(so2, 0);
603 		}
604 		socket_lock(pcb->so, 0);
605 
606 		break;
607 	case VSOCK_RESET:
608 		error = vsock_pcb_safe_reset_address(pcb, dst, src);
609 		break;
610 	default:
611 		vsock_pcb_safe_reset_address(pcb, dst, src);
612 		error = ENOTSUP;
613 		break;
614 	}
615 
616 	return error;
617 }
618 
619 /* VSock Transport */
620 
621 errno_t
vsock_add_transport(struct vsock_transport * transport)622 vsock_add_transport(struct vsock_transport *transport)
623 {
624 	if (transport == NULL || transport->provider == NULL) {
625 		return EINVAL;
626 	}
627 	if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport, NULL, transport, acq_rel)) {
628 		return EEXIST;
629 	}
630 	return 0;
631 }
632 
633 errno_t
vsock_remove_transport(struct vsock_transport * transport)634 vsock_remove_transport(struct vsock_transport *transport)
635 {
636 	if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport, transport, NULL, acq_rel)) {
637 		return ENODEV;
638 	}
639 	return 0;
640 }
641 
642 errno_t
vsock_reset_transport(struct vsock_transport * transport)643 vsock_reset_transport(struct vsock_transport *transport)
644 {
645 	if (transport == NULL) {
646 		return EINVAL;
647 	}
648 
649 	errno_t error = 0;
650 	struct vsockpcb *pcb = NULL;
651 	struct vsockpcb *tmp_pcb = NULL;
652 
653 	lck_rw_lock_exclusive(&vsockinfo.bound_lock);
654 	LIST_FOREACH_SAFE(pcb, &vsockinfo.bound, bound, tmp_pcb) {
655 		// Disconnect this transport's sockets. Listen and bind sockets must stay alive.
656 		socket_lock(pcb->so, 1);
657 		if (pcb->transport == transport && pcb->so->so_state & (SS_ISCONNECTED | SS_ISCONNECTING | SS_ISDISCONNECTING)) {
658 			errno_t dc_error = vsock_disconnect_pcb_locked(pcb);
659 			if (dc_error && !error) {
660 				error = dc_error;
661 			}
662 		}
663 		socket_unlock(pcb->so, 1);
664 	}
665 	lck_rw_done(&vsockinfo.bound_lock);
666 
667 	return error;
668 }
669 
670 errno_t
vsock_put_message(struct vsock_address src,struct vsock_address dst,enum vsock_operation op,uint32_t buf_alloc,uint32_t fwd_cnt,mbuf_t m)671 vsock_put_message(struct vsock_address src, struct vsock_address dst, enum vsock_operation op, uint32_t buf_alloc, uint32_t fwd_cnt, mbuf_t m)
672 {
673 	struct vsockpcb *pcb = vsock_get_matching_pcb(dst, src);
674 	if (!pcb) {
675 		if (op != VSOCK_RESET) {
676 			vsock_pcb_reset_address(dst, src);
677 		}
678 		if (m != NULL) {
679 			mbuf_freem_list(m);
680 		}
681 		return EINVAL;
682 	}
683 
684 	socket_lock_assert_owned(pcb->so);
685 
686 	struct socket *so = pcb->so;
687 	errno_t error = 0;
688 
689 	// Check if the peer's buffer has changed. Update our view of the peer's forwarded bytes.
690 	int buffers_changed = (pcb->peer_buf_alloc != buf_alloc) || (pcb->peer_fwd_cnt) != fwd_cnt;
691 	pcb->peer_buf_alloc = buf_alloc;
692 	pcb->peer_fwd_cnt = fwd_cnt;
693 
694 	// Peer's buffer has enough space for the next packet. Notify any threads waiting for space.
695 	if (buffers_changed && vsock_get_peer_space(pcb) >= pcb->waiting_send_size) {
696 		sowwakeup(so);
697 	}
698 
699 	switch (op) {
700 	case VSOCK_CREDIT_REQUEST:
701 		error = vsock_pcb_credit_update(pcb);
702 		break;
703 	case VSOCK_CREDIT_UPDATE:
704 		break;
705 	default:
706 		if (so->so_state & SS_ISCONNECTED) {
707 			error = vsock_put_message_connected(pcb, op, m);
708 			m = NULL;
709 		} else if (so->so_state & SS_ISCONNECTING) {
710 			error = vsock_put_message_connecting(pcb, op);
711 		} else if (so->so_options & SO_ACCEPTCONN) {
712 			error = vsock_put_message_listening(pcb, op, src, dst);
713 		} else {
714 			// Reset the connection for other states such as 'disconnecting'.
715 			error = vsock_disconnect_pcb(pcb);
716 			if (!error) {
717 				error = ENODEV;
718 			}
719 		}
720 		break;
721 	}
722 	socket_unlock(so, 1);
723 
724 	if (m != NULL) {
725 		mbuf_freem_list(m);
726 	}
727 
728 	return error;
729 }
730 
731 /* VSock Sysctl */
732 
733 static int
734 vsock_pcblist SYSCTL_HANDLER_ARGS
735 {
736 #pragma unused(oidp,arg2)
737 
738 	int error;
739 
740 	// Only stream is supported.
741 	if ((intptr_t)arg1 != SOCK_STREAM) {
742 		return EINVAL;
743 	}
744 
745 	// Get the generation count and the count of all vsock sockets.
746 	lck_rw_lock_shared(&vsockinfo.all_lock);
747 	uint64_t n = vsockinfo.all_pcb_count;
748 	vsock_gen_t gen_count = vsockinfo.vsock_gencnt;
749 	lck_rw_done(&vsockinfo.all_lock);
750 
751 	const size_t xpcb_len = sizeof(struct xvsockpcb);
752 	struct xvsockpgen xvg;
753 
754 	/*
755 	 * The process of preparing the PCB list is too time-consuming and
756 	 * resource-intensive to repeat twice on every request.
757 	 */
758 	if (req->oldptr == USER_ADDR_NULL) {
759 		req->oldidx = (size_t)(2 * sizeof(xvg) + (n + n / 8) * xpcb_len);
760 		return 0;
761 	}
762 
763 	if (req->newptr != USER_ADDR_NULL) {
764 		return EPERM;
765 	}
766 
767 	bzero(&xvg, sizeof(xvg));
768 	xvg.xvg_len = sizeof(xvg);
769 	xvg.xvg_count = n;
770 	xvg.xvg_gen = gen_count;
771 	xvg.xvg_sogen = so_gencnt;
772 	error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
773 	if (error) {
774 		return error;
775 	}
776 
777 	// Return if no sockets exist.
778 	if (n == 0) {
779 		return 0;
780 	}
781 
782 	lck_rw_lock_shared(&vsockinfo.all_lock);
783 
784 	n = 0;
785 	struct vsockpcb *pcb = NULL;
786 	TAILQ_FOREACH(pcb, &vsockinfo.all, all) {
787 		// Bail if there is not enough user buffer for this next socket.
788 		if (req->oldlen - req->oldidx - sizeof(xvg) < xpcb_len) {
789 			break;
790 		}
791 
792 		// Populate the socket structure.
793 		socket_lock(pcb->so, 1);
794 		if (pcb->vsock_gencnt <= gen_count) {
795 			struct xvsockpcb xpcb;
796 			bzero(&xpcb, xpcb_len);
797 			xpcb.xv_len = xpcb_len;
798 			xpcb.xv_vsockpp = (uint64_t)VM_KERNEL_ADDRHASH(pcb);
799 			xpcb.xvp_local_cid = pcb->local_address.cid;
800 			xpcb.xvp_local_port = pcb->local_address.port;
801 			xpcb.xvp_remote_cid = pcb->remote_address.cid;
802 			xpcb.xvp_remote_port = pcb->remote_address.port;
803 			xpcb.xvp_rxcnt = pcb->fwd_cnt;
804 			xpcb.xvp_txcnt = pcb->tx_cnt;
805 			xpcb.xvp_peer_rxhiwat = pcb->peer_buf_alloc;
806 			xpcb.xvp_peer_rxcnt = pcb->peer_fwd_cnt;
807 			xpcb.xvp_last_pid = pcb->so->last_pid;
808 			xpcb.xvp_gencnt = pcb->vsock_gencnt;
809 			if (pcb->so) {
810 				sotoxsocket(pcb->so, &xpcb.xv_socket);
811 			}
812 			socket_unlock(pcb->so, 1);
813 
814 			error = SYSCTL_OUT(req, &xpcb, xpcb_len);
815 			if (error != 0) {
816 				break;
817 			}
818 			n++;
819 		} else {
820 			socket_unlock(pcb->so, 1);
821 		}
822 	}
823 
824 	// Update the generation count to match the sockets being returned.
825 	gen_count = vsockinfo.vsock_gencnt;
826 
827 	lck_rw_done(&vsockinfo.all_lock);
828 
829 	if (!error) {
830 		/*
831 		 * Give the user an updated idea of our state.
832 		 * If the generation differs from what we told
833 		 * her before, she knows that something happened
834 		 * while we were processing this request, and it
835 		 * might be necessary to retry.
836 		 */
837 		bzero(&xvg, sizeof(xvg));
838 		xvg.xvg_len = sizeof(xvg);
839 		xvg.xvg_count = n;
840 		xvg.xvg_gen = gen_count;
841 		xvg.xvg_sogen = so_gencnt;
842 		error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
843 	}
844 
845 	return error;
846 }
847 
#ifdef SYSCTL_DECL
/* net.vsock: vsock tunables and statistics. */
SYSCTL_NODE(_net, OID_AUTO, vsock, CTLFLAG_RW | CTLFLAG_LOCKED, 0, "vsock");

/*
 * vsock_attach() passes these to soreserve() for every new socket, so they
 * are per-socket buffer sizes, not datagram sizes (only SOCK_STREAM is
 * supported).
 */
SYSCTL_UINT(_net_vsock, OID_AUTO, sendspace, CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_sendspace, 0, "Send buffer size reserved for new vsock sockets");
SYSCTL_UINT(_net_vsock, OID_AUTO, recvspace, CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_recvspace, 0, "Receive buffer size reserved for new vsock sockets");
SYSCTL_PROC(_net_vsock, OID_AUTO, pcblist,
    CTLTYPE_STRUCT | CTLFLAG_RD | CTLFLAG_LOCKED,
    (caddr_t)(long)SOCK_STREAM, 0, vsock_pcblist, "S,xvsockpcb",
    "List of active vsock sockets");

/*
 * NOTE(review): vsock_pcblist() copies all_pcb_count into a uint64_t. If the
 * field is 64-bit, exporting it via SYSCTL_UINT through a (u_int *) cast only
 * exposes part of it -- confirm the declared type.
 */
SYSCTL_UINT(_net_vsock, OID_AUTO, pcbcount, CTLFLAG_RD | CTLFLAG_LOCKED,
    (u_int *)&vsockinfo.all_pcb_count, 0, "Number of vsock sockets");
#endif
862 
863 /* VSock Protocol */
864 
865 static int
vsock_attach(struct socket * so,int proto,struct proc * p)866 vsock_attach(struct socket *so, int proto, struct proc *p)
867 {
868 	#pragma unused(proto, p)
869 
870 	// Reserve send and receive buffers.
871 	errno_t error = soreserve(so, vsock_sendspace, vsock_recvspace);
872 	if (error) {
873 		return error;
874 	}
875 
876 	// Attach should only be run once per socket.
877 	struct vsockpcb *pcb = sotovsockpcb(so);
878 	if (pcb) {
879 		return EINVAL;
880 	}
881 
882 	// Get the transport for this socket.
883 	struct vsock_transport *transport = os_atomic_load(&the_vsock_transport, relaxed);
884 	if (transport == NULL) {
885 		return ENODEV;
886 	}
887 
888 	// Initialize the vsock protocol control block.
889 	pcb = zalloc_flags(vsockpcb_zone, Z_WAITOK | Z_ZERO | Z_NOFAIL);
890 	pcb->so = so;
891 	pcb->transport = transport;
892 	pcb->local_address = (struct vsock_address) {
893 		.cid = VMADDR_CID_ANY,
894 		.port = VMADDR_PORT_ANY
895 	};
896 	pcb->remote_address = (struct vsock_address) {
897 		.cid = VMADDR_CID_ANY,
898 		.port = VMADDR_PORT_ANY
899 	};
900 	so->so_pcb = pcb;
901 
902 	// Tell the transport that this socket has attached.
903 	error = transport->attach_socket(transport->provider);
904 	if (error) {
905 		zfree(vsockpcb_zone, pcb);
906 		so->so_pcb = NULL;
907 		return error;
908 	}
909 
910 	// Add to the list of all vsock sockets.
911 	lck_rw_lock_exclusive(&vsockinfo.all_lock);
912 	TAILQ_INSERT_TAIL(&vsockinfo.all, pcb, all);
913 	vsockinfo.all_pcb_count++;
914 	pcb->vsock_gencnt = ++vsockinfo.vsock_gencnt;
915 	lck_rw_done(&vsockinfo.all_lock);
916 
917 	return 0;
918 }
919 
920 static int
vsock_control(struct socket * so,u_long cmd,caddr_t __sized_by (IOCPARM_LEN (cmd))data,struct ifnet * ifp,struct proc * p)921 vsock_control(struct socket *so, u_long cmd, caddr_t __sized_by(IOCPARM_LEN(cmd)) data, struct ifnet *ifp, struct proc *p)
922 {
923 	#pragma unused(ifp)
924 
925 	VERIFY(so != NULL || p == kernproc);
926 
927 	if (cmd != IOCTL_VM_SOCKETS_GET_LOCAL_CID) {
928 		return EINVAL;
929 	}
930 
931 	struct vsock_transport *transport;
932 	if (so) {
933 		struct vsockpcb *pcb = sotovsockpcb(so);
934 		if (pcb == NULL) {
935 			return EINVAL;
936 		}
937 		transport = pcb->transport;
938 	} else {
939 		transport = os_atomic_load(&the_vsock_transport, relaxed);
940 	}
941 
942 	if (transport == NULL) {
943 		return ENODEV;
944 	}
945 
946 	uint32_t transport_cid;
947 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
948 	if (error) {
949 		return error;
950 	}
951 
952 	memcpy(data, &transport_cid, sizeof(transport_cid));
953 
954 	return 0;
955 }
956 
957 static int
vsock_detach(struct socket * so)958 vsock_detach(struct socket *so)
959 {
960 	struct vsockpcb *pcb = sotovsockpcb(so);
961 	if (pcb == NULL) {
962 		return EINVAL;
963 	}
964 
965 	vsock_unbind_pcb(pcb);
966 
967 	// Tell the transport that this socket has detached.
968 	struct vsock_transport *transport = pcb->transport;
969 	errno_t error = transport->detach_socket(transport->provider);
970 	if (error) {
971 		return error;
972 	}
973 
974 	// Mark this socket for deallocation.
975 	so->so_flags |= SOF_PCBCLEARING;
976 
977 	// Reorder locks.
978 	socket_unlock(so, 0);
979 	lck_rw_lock_exclusive(&vsockinfo.all_lock);
980 	socket_lock(so, 0);
981 
982 	// Remove from the list of all vsock sockets.
983 	TAILQ_REMOVE(&vsockinfo.all, pcb, all);
984 	pcb->all.tqe_next = NULL;
985 	pcb->all.tqe_prev = NULL;
986 	vsockinfo.all_pcb_count--;
987 	vsockinfo.vsock_gencnt++;
988 	lck_rw_done(&vsockinfo.all_lock);
989 
990 	return 0;
991 }
992 
/* Abort handler: identical to detaching the socket. */
static int
vsock_abort(struct socket *so)
{
	const int error = vsock_detach(so);

	return error;
}
998 
999 static int
vsock_bind(struct socket * so,struct sockaddr * nam,struct proc * p)1000 vsock_bind(struct socket *so, struct sockaddr *nam, struct proc *p)
1001 {
1002 	#pragma unused(p)
1003 
1004 	struct vsockpcb *pcb = sotovsockpcb(so);
1005 	if (pcb == NULL) {
1006 		return EINVAL;
1007 	}
1008 
1009 	struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
1010 
1011 	errno_t error = vsock_sockaddr_vm_validate(pcb, addr);
1012 	if (error) {
1013 		return error;
1014 	}
1015 
1016 	struct vsock_address laddr = (struct vsock_address) {
1017 		.cid = addr->svm_cid,
1018 		.port = addr->svm_port,
1019 	};
1020 
1021 	struct vsock_address raddr = (struct vsock_address) {
1022 		.cid = VMADDR_CID_ANY,
1023 		.port = VMADDR_PORT_ANY,
1024 	};
1025 
1026 	error = vsock_bind_address(pcb, laddr, raddr);
1027 	if (error) {
1028 		return error;
1029 	}
1030 
1031 	return 0;
1032 }
1033 
1034 static int
vsock_listen(struct socket * so,struct proc * p)1035 vsock_listen(struct socket *so, struct proc *p)
1036 {
1037 	#pragma unused(p)
1038 
1039 	struct vsockpcb *pcb = sotovsockpcb(so);
1040 	if (pcb == NULL) {
1041 		return EINVAL;
1042 	}
1043 
1044 	// Only stream is supported currently.
1045 	if (so->so_type != SOCK_STREAM) {
1046 		return EAFNOSUPPORT;
1047 	}
1048 
1049 	struct vsock_address *addr = &pcb->local_address;
1050 
1051 	if (addr->port == VMADDR_CID_ANY) {
1052 		return EFAULT;
1053 	}
1054 
1055 	struct vsock_transport *transport = pcb->transport;
1056 	uint32_t transport_cid;
1057 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
1058 	if (error) {
1059 		return error;
1060 	}
1061 
1062 	// Can listen on the transport's cid or any.
1063 	if (addr->cid != transport_cid && addr->cid != VMADDR_CID_ANY) {
1064 		return EFAULT;
1065 	}
1066 
1067 	return 0;
1068 }
1069 
1070 static int
vsock_accept(struct socket * so,struct sockaddr ** nam)1071 vsock_accept(struct socket *so, struct sockaddr **nam)
1072 {
1073 	struct vsockpcb *pcb = sotovsockpcb(so);
1074 	if (pcb == NULL) {
1075 		return EINVAL;
1076 	}
1077 
1078 	// Do not accept disconnected sockets.
1079 	if (so->so_state & SS_ISDISCONNECTED) {
1080 		return ECONNABORTED;
1081 	}
1082 
1083 	*nam = vsock_new_sockaddr(&pcb->remote_address);
1084 
1085 	return 0;
1086 }
1087 
1088 static int
vsock_connect(struct socket * so,struct sockaddr * nam,struct proc * p)1089 vsock_connect(struct socket *so, struct sockaddr *nam, struct proc *p)
1090 {
1091 	#pragma unused(p)
1092 
1093 	struct vsockpcb *pcb = sotovsockpcb(so);
1094 	if (pcb == NULL) {
1095 		return EINVAL;
1096 	}
1097 
1098 	struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
1099 
1100 	errno_t error = vsock_sockaddr_vm_validate(pcb, addr);
1101 	if (error) {
1102 		return error;
1103 	}
1104 
1105 	uint32_t transport_cid;
1106 	struct vsock_transport *transport = pcb->transport;
1107 	error = transport->get_cid(transport->provider, &transport_cid);
1108 	if (error) {
1109 		return error;
1110 	}
1111 
1112 	// Only supporting connections to the host, hypervisor, or self for now.
1113 	if (addr->svm_cid != VMADDR_CID_HOST &&
1114 	    addr->svm_cid != VMADDR_CID_HYPERVISOR &&
1115 	    addr->svm_cid != transport_cid) {
1116 		return EFAULT;
1117 	}
1118 
1119 	soisconnecting(so);
1120 
1121 	// Set the remote and local address.
1122 	struct vsock_address remote_addr = (struct vsock_address) {
1123 		.cid = addr->svm_cid,
1124 		.port = addr->svm_port,
1125 	};
1126 
1127 	struct vsock_address local_addr = (struct vsock_address) {
1128 		.cid = transport_cid,
1129 		.port = VMADDR_PORT_ANY,
1130 	};
1131 
1132 	// Bind to the address.
1133 	error = vsock_bind_address(pcb, local_addr, remote_addr);
1134 	if (error) {
1135 		goto cleanup;
1136 	}
1137 
1138 	// Attempt a connection using the socket's transport.
1139 	error = vsock_pcb_connect(pcb);
1140 	if (error) {
1141 		goto cleanup;
1142 	}
1143 
1144 	if ((so->so_state & SS_ISCONNECTED) == 0) {
1145 		// Don't wait for peer's response if non-blocking.
1146 		if (so->so_state & SS_NBIO) {
1147 			goto done;
1148 		}
1149 
1150 		struct timespec ts = (struct timespec) {
1151 			.tv_sec = so->so_snd.sb_timeo.tv_sec,
1152 			.tv_nsec = so->so_snd.sb_timeo.tv_usec * 1000,
1153 		};
1154 
1155 		lck_mtx_t *mutex_held;
1156 		if (so->so_proto->pr_getlock != NULL) {
1157 			mutex_held = (*so->so_proto->pr_getlock)(so, PR_F_WILLUNLOCK);
1158 		} else {
1159 			mutex_held = so->so_proto->pr_domain->dom_mtx;
1160 		}
1161 
1162 		// Wait until we receive a response to the connect request.
1163 		error = msleep((caddr_t)&so->so_timeo, mutex_held, PSOCK | PCATCH, "vsock_connect", &ts);
1164 		if (error) {
1165 			if (error == EAGAIN) {
1166 				error = ETIMEDOUT;
1167 			}
1168 			goto cleanup;
1169 		}
1170 	}
1171 
1172 cleanup:
1173 	if (so->so_error && !error) {
1174 		error = so->so_error;
1175 		so->so_error = 0;
1176 	}
1177 	if (!error) {
1178 		error = !(so->so_state & SS_ISCONNECTED);
1179 	}
1180 	if (error) {
1181 		vsock_unbind_pcb(pcb);
1182 	}
1183 
1184 done:
1185 	return error;
1186 }
1187 
1188 static int
vsock_disconnect(struct socket * so)1189 vsock_disconnect(struct socket *so)
1190 {
1191 	struct vsockpcb *pcb = sotovsockpcb(so);
1192 	if (pcb == NULL) {
1193 		return EINVAL;
1194 	}
1195 
1196 	return vsock_disconnect_pcb(pcb);
1197 }
1198 
1199 static int
vsock_sockaddr(struct socket * so,struct sockaddr ** nam)1200 vsock_sockaddr(struct socket *so, struct sockaddr **nam)
1201 {
1202 	struct vsockpcb *pcb = sotovsockpcb(so);
1203 	if (pcb == NULL) {
1204 		return EINVAL;
1205 	}
1206 
1207 	*nam = vsock_new_sockaddr(&pcb->local_address);
1208 
1209 	return 0;
1210 }
1211 
1212 static int
vsock_peeraddr(struct socket * so,struct sockaddr ** nam)1213 vsock_peeraddr(struct socket *so, struct sockaddr **nam)
1214 {
1215 	struct vsockpcb *pcb = sotovsockpcb(so);
1216 	if (pcb == NULL) {
1217 		return EINVAL;
1218 	}
1219 
1220 	*nam = vsock_new_sockaddr(&pcb->remote_address);
1221 
1222 	return 0;
1223 }
1224 
1225 static int
vsock_send(struct socket * so,int flags,struct mbuf * m,struct sockaddr * nam,struct mbuf * control,proc_t p)1226 vsock_send(struct socket *so, int flags, struct mbuf *m, struct sockaddr *nam, struct mbuf *control, proc_t p)
1227 {
1228 	#pragma unused(flags, nam, p)
1229 
1230 	errno_t error = 0;
1231 	struct vsockpcb *pcb = sotovsockpcb(so);
1232 	if (pcb == NULL || m == NULL) {
1233 		error = EINVAL;
1234 		goto out;
1235 	}
1236 
1237 	if (control != NULL) {
1238 		error = EOPNOTSUPP;
1239 		goto out;
1240 	}
1241 
1242 	// Ensure this socket is connected.
1243 	if ((so->so_state & SS_ISCONNECTED) == 0) {
1244 		error = EPERM;
1245 		goto out;
1246 	}
1247 
1248 	// rdar://84098487 (SEED: Web: Virtio-socket sent data lost after 128KB)
1249 	// For writes larger than the default `sosendmaxchain` of 65536, vsock_send() is called multiple times per write().
1250 	// Only the first call to vsock_send() is passed a valid mbuf packet, while subsequent calls are not marked as a packet
1251 	// with a valid length. We should mark all mbufs as a packet and set the correct packet length so that the downstream
1252 	// socket transport layer can correctly generate physical segments.
1253 	if (!(mbuf_flags(m) & MBUF_PKTHDR)) {
1254 		if (!(mbuf_flags(m) & M_EXT)) {
1255 			struct mbuf *header = NULL;
1256 			MGETHDR(header, M_WAITOK, MT_HEADER);
1257 			if (header == NULL) {
1258 				error = ENOBUFS;
1259 				goto out;
1260 			}
1261 			header->m_next = m;
1262 			m = header;
1263 		} else {
1264 			mbuf_setflags(m, mbuf_flags(m) | MBUF_PKTHDR);
1265 		}
1266 
1267 		size_t len = 0;
1268 		struct mbuf *next = m;
1269 		while (next) {
1270 			len += mbuf_len(next);
1271 			next = mbuf_next(next);
1272 		}
1273 		mbuf_pkthdr_setlen(m, len);
1274 	}
1275 
1276 	const size_t len = mbuf_pkthdr_len(m);
1277 	uint32_t free_space = vsock_get_peer_space(pcb);
1278 
1279 	// Ensure the peer has enough space in their receive buffer.
1280 	while (len > free_space) {
1281 		// Record the number of free peer bytes necessary before we can send.
1282 		if (len > pcb->waiting_send_size) {
1283 			pcb->waiting_send_size = len;
1284 		}
1285 
1286 		// Send a credit request.
1287 		error = vsock_pcb_credit_request(pcb);
1288 		if (error) {
1289 			goto out;
1290 		}
1291 
1292 		// Check again in case free space was automatically updated in loopback case.
1293 		free_space = vsock_get_peer_space(pcb);
1294 		if (len <= free_space) {
1295 			pcb->waiting_send_size = 0;
1296 			break;
1297 		}
1298 
1299 		// Bail if this is a non-blocking socket.
1300 		if (so->so_state & SS_NBIO) {
1301 			error = EWOULDBLOCK;
1302 			goto out;
1303 		}
1304 
1305 		// Wait until our peer has enough free space in their receive buffer.
1306 		error = sbwait(&so->so_snd);
1307 		pcb->waiting_send_size = 0;
1308 		if (error) {
1309 			goto out;
1310 		}
1311 
1312 		// Bail if an error occured or we can't send more.
1313 		if (so->so_state & SS_CANTSENDMORE) {
1314 			error = EPIPE;
1315 			goto out;
1316 		} else if (so->so_error) {
1317 			error = so->so_error;
1318 			so->so_error = 0;
1319 			goto out;
1320 		}
1321 
1322 		free_space = vsock_get_peer_space(pcb);
1323 	}
1324 
1325 	// Send a payload over the transport.
1326 	error = vsock_pcb_send(pcb, m);
1327 	if (error) {
1328 		return error;
1329 	}
1330 
1331 	pcb->tx_cnt += len;
1332 
1333 	return 0;
1334 
1335 out:
1336 	if (control != NULL) {
1337 		m_freem(control);
1338 	}
1339 	if (m != NULL) {
1340 		mbuf_freem_list(m);
1341 	}
1342 	return error;
1343 }
1344 
1345 static int
vsock_shutdown(struct socket * so)1346 vsock_shutdown(struct socket *so)
1347 {
1348 	struct vsockpcb *pcb = sotovsockpcb(so);
1349 	if (pcb == NULL) {
1350 		return EINVAL;
1351 	}
1352 
1353 	socantsendmore(so);
1354 
1355 	// Tell peer we will no longer send.
1356 	errno_t error = vsock_pcb_shutdown_send(pcb);
1357 	if (error) {
1358 		return error;
1359 	}
1360 
1361 	return 0;
1362 }
1363 
1364 static int
vsock_soreceive(struct socket * so,struct sockaddr ** psa,struct uio * uio,struct mbuf ** mp0,struct mbuf ** controlp,int * flagsp)1365 vsock_soreceive(struct socket *so, struct sockaddr **psa, struct uio *uio,
1366     struct mbuf **mp0, struct mbuf **controlp, int *flagsp)
1367 {
1368 	struct vsockpcb *pcb = sotovsockpcb(so);
1369 	if (pcb == NULL) {
1370 		return EINVAL;
1371 	}
1372 
1373 	user_ssize_t length = uio_resid(uio);
1374 	int result = soreceive(so, psa, uio, mp0, controlp, flagsp);
1375 	length -= uio_resid(uio);
1376 
1377 	socket_lock(so, 1);
1378 
1379 	pcb->fwd_cnt += length;
1380 
1381 	const uint32_t threshold = VSOCK_MAX_PACKET_SIZE;
1382 
1383 	// Send a credit update if is possible that the peer will no longer send.
1384 	if ((pcb->fwd_cnt - pcb->last_fwd_cnt + threshold) >= pcb->last_buf_alloc) {
1385 		errno_t error = vsock_pcb_credit_update(pcb);
1386 		if (!result && error) {
1387 			result = error;
1388 		}
1389 	}
1390 
1391 	socket_unlock(so, 1);
1392 
1393 	return result;
1394 }
1395 
1396 static struct pr_usrreqs vsock_usrreqs = {
1397 	.pru_abort =            vsock_abort,
1398 	.pru_attach =           vsock_attach,
1399 	.pru_control =          vsock_control,
1400 	.pru_detach =           vsock_detach,
1401 	.pru_bind =             vsock_bind,
1402 	.pru_listen =           vsock_listen,
1403 	.pru_accept =           vsock_accept,
1404 	.pru_connect =          vsock_connect,
1405 	.pru_disconnect =       vsock_disconnect,
1406 	.pru_send =             vsock_send,
1407 	.pru_shutdown =         vsock_shutdown,
1408 	.pru_sockaddr =         vsock_sockaddr,
1409 	.pru_peeraddr =         vsock_peeraddr,
1410 	.pru_sosend =           sosend,
1411 	.pru_soreceive =        vsock_soreceive,
1412 };
1413 
1414 static void
vsock_init(struct protosw * pp,struct domain * dp)1415 vsock_init(struct protosw *pp, struct domain *dp)
1416 {
1417 	#pragma unused(dp)
1418 
1419 	static int vsock_initialized = 0;
1420 	VERIFY((pp->pr_flags & (PR_INITIALIZED | PR_ATTACHED)) == PR_ATTACHED);
1421 	if (!os_atomic_cmpxchg((volatile int *)&vsock_initialized, 0, 1, acq_rel)) {
1422 		return;
1423 	}
1424 
1425 	// Setup VSock protocol info struct.
1426 	lck_rw_init(&vsockinfo.all_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1427 	lck_rw_init(&vsockinfo.bound_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1428 	lck_mtx_init(&vsockinfo.port_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1429 	TAILQ_INIT(&vsockinfo.all);
1430 	LIST_INIT(&vsockinfo.bound);
1431 	vsockinfo.last_port = VMADDR_PORT_ANY;
1432 }
1433 
1434 static int
vsock_sofreelastref(struct socket * so,int dealloc)1435 vsock_sofreelastref(struct socket *so, int dealloc)
1436 {
1437 	socket_lock_assert_owned(so);
1438 
1439 	struct vsockpcb *pcb = sotovsockpcb(so);
1440 	if (pcb != NULL) {
1441 		zfree(vsockpcb_zone, pcb);
1442 	}
1443 
1444 	so->so_pcb = NULL;
1445 	sofreelastref(so, dealloc);
1446 
1447 	return 0;
1448 }
1449 
1450 static int
vsock_unlock(struct socket * so,int refcount,void * lr_saved)1451 vsock_unlock(struct socket *so, int refcount, void *lr_saved)
1452 {
1453 	lck_mtx_t *mutex_held = so->so_proto->pr_domain->dom_mtx;
1454 #ifdef MORE_LOCKING_DEBUG
1455 	LCK_MTX_ASSERT(mutex_held, LCK_MTX_ASSERT_OWNED);
1456 #endif
1457 	so->unlock_lr[so->next_unlock_lr] = lr_saved;
1458 	so->next_unlock_lr = (so->next_unlock_lr + 1) % SO_LCKDBG_MAX;
1459 
1460 	if (refcount) {
1461 		if (so->so_usecount <= 0) {
1462 			panic("%s: bad refcount=%d so=%p (%d, %d, %d) "
1463 			    "lrh=%s", __func__, so->so_usecount, so,
1464 			    SOCK_DOM(so), so->so_type,
1465 			    SOCK_PROTO(so), solockhistory_nr(so));
1466 			/* NOTREACHED */
1467 		}
1468 
1469 		so->so_usecount--;
1470 		if (so->so_usecount == 0) {
1471 			vsock_sofreelastref(so, 1);
1472 		}
1473 	}
1474 	lck_mtx_unlock(mutex_held);
1475 
1476 	return 0;
1477 }
1478 
1479 static struct protosw vsocksw[] = {
1480 	{
1481 		.pr_type =              SOCK_STREAM,
1482 		.pr_protocol =          0,
1483 		.pr_flags =             PR_CONNREQUIRED | PR_WANTRCVD,
1484 		.pr_init =              vsock_init,
1485 		.pr_unlock =            vsock_unlock,
1486 		.pr_usrreqs =           &vsock_usrreqs,
1487 	}
1488 };
1489 
1490 static const int vsock_proto_count = (sizeof(vsocksw) / sizeof(struct protosw));
1491 
1492 /* VSock Domain */
1493 
// The one vsock domain instance, recorded by vsock_dinit(); NULL until then
// (static storage is zero-initialized).
static struct domain *vsock_domain;
1495 
1496 static void
vsock_dinit(struct domain * dp)1497 vsock_dinit(struct domain *dp)
1498 {
1499 	// The VSock domain is initialized with a singleton pattern.
1500 	VERIFY(!(dp->dom_flags & DOM_INITIALIZED));
1501 	VERIFY(vsock_domain == NULL);
1502 	vsock_domain = dp;
1503 
1504 	// Add protocols and initialize.
1505 	for (int i = 0; i < vsock_proto_count; i++) {
1506 		net_add_proto((struct protosw *)&vsocksw[i], dp, 1);
1507 	}
1508 }
1509 
1510 struct domain vsockdomain_s = {
1511 	.dom_family =           PF_VSOCK,
1512 	.dom_name =             "vsock",
1513 	.dom_init =             vsock_dinit,
1514 	.dom_maxrtkey =         sizeof(struct sockaddr_vm),
1515 	.dom_protohdrlen =      sizeof(struct sockaddr_vm),
1516 };
1517