xref: /xnu-10063.141.1/bsd/kern/vsock_domain.c (revision d8b80295118ef25ac3a784134bcf95cd8e88109f)
1 /*
2  * Copyright (c) 2020 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <sys/domain.h>
30 #include <sys/socket.h>
31 #include <sys/protosw.h>
32 #include <sys/mcache.h>
33 #include <sys/systm.h>
34 #include <sys/sysctl.h>
35 #include <sys/random.h>
36 #include <sys/mbuf.h>
37 #include <sys/vsock_domain.h>
38 #include <sys/vsock_transport.h>
39 #include <kern/task.h>
40 #include <kern/zalloc.h>
41 #include <kern/locks.h>
42 #include <machine/atomic.h>
43 
44 #define sotovsockpcb(so) ((struct vsockpcb *)(so)->so_pcb)
45 
46 #define VSOCK_PORT_RESERVED 1024
47 
48 /* VSock Protocol Globals */
49 
50 static struct vsock_transport * _Atomic the_vsock_transport = NULL;
51 static ZONE_DEFINE(vsockpcb_zone, "vsockpcbzone",
52     sizeof(struct vsockpcb), ZC_NONE);
53 static LCK_GRP_DECLARE(vsock_lock_grp, "vsock");
54 static struct vsockpcbinfo vsockinfo;
55 
56 static uint32_t vsock_sendspace = VSOCK_MAX_PACKET_SIZE * 8;
57 static uint32_t vsock_recvspace = VSOCK_MAX_PACKET_SIZE * 8;
58 
59 /* VSock PCB Helpers */
60 
61 static uint32_t
vsock_get_peer_space(struct vsockpcb * pcb)62 vsock_get_peer_space(struct vsockpcb *pcb)
63 {
64 	return pcb->peer_buf_alloc - (pcb->tx_cnt - pcb->peer_fwd_cnt);
65 }
66 
67 static struct vsockpcb *
vsock_get_matching_pcb(struct vsock_address src,struct vsock_address dst)68 vsock_get_matching_pcb(struct vsock_address src, struct vsock_address dst)
69 {
70 	struct vsockpcb *preferred = NULL;
71 	struct vsockpcb *match = NULL;
72 	struct vsockpcb *pcb = NULL;
73 
74 	lck_rw_lock_shared(&vsockinfo.bound_lock);
75 	LIST_FOREACH(pcb, &vsockinfo.bound, bound) {
76 		// Source cid and port must match. Only destination port must match. (Allows for a changing CID during migration)
77 		socket_lock(pcb->so, 1);
78 		if ((pcb->so->so_state & SS_ISCONNECTED || pcb->so->so_state & SS_ISCONNECTING) &&
79 		    pcb->local_address.cid == src.cid && pcb->local_address.port == src.port &&
80 		    pcb->remote_address.port == dst.port) {
81 			preferred = pcb;
82 			break;
83 		} else if ((pcb->local_address.cid == src.cid || pcb->local_address.cid == VMADDR_CID_ANY) &&
84 		    pcb->local_address.port == src.port) {
85 			match = pcb;
86 		}
87 		socket_unlock(pcb->so, 1);
88 	}
89 	if (!preferred && match) {
90 		socket_lock(match->so, 1);
91 		preferred = match;
92 	}
93 	lck_rw_done(&vsockinfo.bound_lock);
94 
95 	return preferred;
96 }
97 
98 static errno_t
vsock_bind_address_if_free(struct vsockpcb * pcb,uint32_t local_cid,uint32_t local_port,uint32_t remote_cid,uint32_t remote_port)99 vsock_bind_address_if_free(struct vsockpcb *pcb, uint32_t local_cid, uint32_t local_port, uint32_t remote_cid, uint32_t remote_port)
100 {
101 	socket_lock_assert_owned(pcb->so);
102 
103 	// Privileged ports.
104 	if (local_port != VMADDR_PORT_ANY && local_port < VSOCK_PORT_RESERVED &&
105 	    current_task() != kernel_task && proc_suser(current_proc()) != 0) {
106 		return EACCES;
107 	}
108 
109 	bool taken = false;
110 	const bool check_remote = (remote_cid != VMADDR_CID_ANY && remote_port != VMADDR_PORT_ANY);
111 
112 	struct vsockpcb *pcb_match = NULL;
113 
114 	socket_unlock(pcb->so, 0);
115 	lck_rw_lock_exclusive(&vsockinfo.bound_lock);
116 	LIST_FOREACH(pcb_match, &vsockinfo.bound, bound) {
117 		socket_lock(pcb_match->so, 1);
118 		if (pcb == pcb_match ||
119 		    (!check_remote && pcb_match->local_address.port == local_port) ||
120 		    (check_remote && pcb_match->local_address.port == local_port &&
121 		    pcb_match->remote_address.cid == remote_cid && pcb_match->remote_address.port == remote_port)) {
122 			socket_unlock(pcb_match->so, 1);
123 			taken = true;
124 			break;
125 		}
126 		socket_unlock(pcb_match->so, 1);
127 	}
128 	socket_lock(pcb->so, 0);
129 	if (!taken) {
130 		pcb->local_address = (struct vsock_address) { .cid = local_cid, .port = local_port };
131 		pcb->remote_address = (struct vsock_address) { .cid = remote_cid, .port = remote_port };
132 		LIST_INSERT_HEAD(&vsockinfo.bound, pcb, bound);
133 	}
134 	lck_rw_done(&vsockinfo.bound_lock);
135 
136 	return taken ? EADDRINUSE : 0;
137 }
138 
139 static errno_t
vsock_bind_address(struct vsockpcb * pcb,struct vsock_address laddr,struct vsock_address raddr)140 vsock_bind_address(struct vsockpcb *pcb, struct vsock_address laddr, struct vsock_address raddr)
141 {
142 	if (!pcb) {
143 		return EINVAL;
144 	}
145 
146 	socket_lock_assert_owned(pcb->so);
147 
148 	// Certain CIDs are reserved.
149 	if (laddr.cid == VMADDR_CID_HYPERVISOR || laddr.cid == VMADDR_CID_RESERVED || laddr.cid == VMADDR_CID_HOST) {
150 		return EADDRNOTAVAIL;
151 	}
152 
153 	// Remote address must be fully specified or not specified at all.
154 	if ((raddr.cid == VMADDR_CID_ANY) ^ (raddr.port == VMADDR_PORT_ANY)) {
155 		return EINVAL;
156 	}
157 
158 	// Cannot bind if already bound.
159 	if (pcb->local_address.port != VMADDR_PORT_ANY) {
160 		return EINVAL;
161 	}
162 
163 	uint32_t transport_cid;
164 	struct vsock_transport *transport = pcb->transport;
165 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
166 	if (error) {
167 		return error;
168 	}
169 
170 	// Local CID must be this transport's CID or any.
171 	if (laddr.cid != transport_cid && laddr.cid != VMADDR_CID_ANY) {
172 		return EINVAL;
173 	}
174 
175 	if (laddr.port != VMADDR_PORT_ANY) {
176 		error = vsock_bind_address_if_free(pcb, laddr.cid, laddr.port, raddr.cid, raddr.port);
177 	} else {
178 		socket_unlock(pcb->so, 0);
179 		lck_mtx_lock(&vsockinfo.port_lock);
180 		socket_lock(pcb->so, 0);
181 
182 		const uint32_t first = VSOCK_PORT_RESERVED;
183 		const uint32_t last = VMADDR_PORT_ANY - 1;
184 		uint32_t count = last - first + 1;
185 		uint32_t *last_port = &vsockinfo.last_port;
186 
187 		if (pcb->so->so_flags & SOF_BINDRANDOMPORT) {
188 			uint32_t random = 0;
189 			read_frandom(&random, sizeof(random));
190 			*last_port = first + (random % count);
191 		}
192 
193 		do {
194 			if (count == 0) {
195 				lck_mtx_unlock(&vsockinfo.port_lock);
196 				return EADDRNOTAVAIL;
197 			}
198 			count--;
199 
200 			++*last_port;
201 			if (*last_port < first || *last_port > last) {
202 				*last_port = first;
203 			}
204 
205 			error = vsock_bind_address_if_free(pcb, laddr.cid, *last_port, raddr.cid, raddr.port);
206 		} while (error);
207 
208 		lck_mtx_unlock(&vsockinfo.port_lock);
209 	}
210 
211 	return error;
212 }
213 
214 static void
vsock_unbind_pcb_locked(struct vsockpcb * pcb,bool is_locked)215 vsock_unbind_pcb_locked(struct vsockpcb *pcb, bool is_locked)
216 {
217 	if (!pcb) {
218 		return;
219 	}
220 
221 	struct socket *so = pcb->so;
222 	socket_lock_assert_owned(so);
223 
224 	// Bail if disconnect and already unbound.
225 	if (so->so_state & SS_ISDISCONNECTED) {
226 		assert(pcb->bound.le_next == NULL);
227 		assert(pcb->bound.le_prev == NULL);
228 		return;
229 	}
230 
231 	if (!is_locked) {
232 		socket_unlock(so, 0);
233 		lck_rw_lock_exclusive(&vsockinfo.bound_lock);
234 		socket_lock(so, 0);
235 
236 		// Case where some other thread also called unbind() on this socket while waiting to acquire its lock.
237 		if (!pcb->bound.le_prev) {
238 			soisdisconnected(so);
239 			lck_rw_done(&vsockinfo.bound_lock);
240 			return;
241 		}
242 	}
243 
244 	soisdisconnected(so);
245 
246 	LIST_REMOVE(pcb, bound);
247 	pcb->bound.le_next = NULL;
248 	pcb->bound.le_prev = NULL;
249 
250 	if (!is_locked) {
251 		lck_rw_done(&vsockinfo.bound_lock);
252 	}
253 }
254 
255 static void
vsock_unbind_pcb(struct vsockpcb * pcb)256 vsock_unbind_pcb(struct vsockpcb *pcb)
257 {
258 	vsock_unbind_pcb_locked(pcb, false);
259 }
260 
261 static struct sockaddr *
vsock_new_sockaddr(struct vsock_address * address)262 vsock_new_sockaddr(struct vsock_address *address)
263 {
264 	if (!address) {
265 		return NULL;
266 	}
267 
268 	struct sockaddr_vm *addr;
269 	addr = (struct sockaddr_vm *)alloc_sockaddr(sizeof(*addr),
270 	    Z_WAITOK | Z_NOFAIL);
271 
272 	addr->svm_family = AF_VSOCK;
273 	addr->svm_port = address->port;
274 	addr->svm_cid = address->cid;
275 
276 	return (struct sockaddr *)addr;
277 }
278 
279 static errno_t
vsock_pcb_send_message(struct vsockpcb * pcb,enum vsock_operation operation,mbuf_t m)280 vsock_pcb_send_message(struct vsockpcb *pcb, enum vsock_operation operation, mbuf_t m)
281 {
282 	if (!pcb) {
283 		if (m != NULL) {
284 			mbuf_freem_list(m);
285 		}
286 		return EINVAL;
287 	}
288 
289 	socket_lock_assert_owned(pcb->so);
290 
291 	errno_t error;
292 
293 	struct vsock_address dst = pcb->remote_address;
294 	if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
295 		if (m != NULL) {
296 			mbuf_freem_list(m);
297 		}
298 		return EINVAL;
299 	}
300 
301 	struct vsock_address src = pcb->local_address;
302 	if (src.cid == VMADDR_CID_ANY) {
303 		uint32_t transport_cid;
304 		struct vsock_transport *transport = pcb->transport;
305 		error = transport->get_cid(transport->provider, &transport_cid);
306 		if (error) {
307 			if (m != NULL) {
308 				mbuf_freem_list(m);
309 			}
310 			return error;
311 		}
312 		src.cid = transport_cid;
313 	}
314 
315 	uint32_t buf_alloc = pcb->so->so_rcv.sb_hiwat;
316 	uint32_t fwd_cnt = pcb->fwd_cnt;
317 
318 	if (src.cid == dst.cid) {
319 		pcb->last_buf_alloc = buf_alloc;
320 		pcb->last_fwd_cnt = fwd_cnt;
321 
322 		socket_unlock(pcb->so, 0);
323 		error = vsock_put_message(src, dst, operation, buf_alloc, fwd_cnt, m);
324 		socket_lock(pcb->so, 0);
325 	} else {
326 		struct vsock_transport *transport = pcb->transport;
327 		error = transport->put_message(transport->provider, src, dst, operation, buf_alloc, fwd_cnt, m);
328 
329 		if (!error) {
330 			pcb->last_buf_alloc = buf_alloc;
331 			pcb->last_fwd_cnt = fwd_cnt;
332 		}
333 	}
334 
335 	return error;
336 }
337 
338 static errno_t
vsock_pcb_reset_address(struct vsock_address src,struct vsock_address dst)339 vsock_pcb_reset_address(struct vsock_address src, struct vsock_address dst)
340 {
341 	if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
342 		return EINVAL;
343 	}
344 
345 	errno_t error = 0;
346 	struct vsock_transport *transport = NULL;
347 
348 	if (src.cid == VMADDR_CID_ANY) {
349 		transport = os_atomic_load(&the_vsock_transport, relaxed);
350 		if (transport == NULL) {
351 			return ENODEV;
352 		}
353 
354 		uint32_t transport_cid;
355 		error = transport->get_cid(transport->provider, &transport_cid);
356 		if (error) {
357 			return error;
358 		}
359 		src.cid = transport_cid;
360 	}
361 
362 	if (src.cid == dst.cid) {
363 		// Reset both sockets.
364 		struct vsockpcb *pcb = vsock_get_matching_pcb(src, dst);
365 		if (pcb) {
366 			socket_lock_assert_owned(pcb->so);
367 			vsock_unbind_pcb(pcb);
368 			socket_unlock(pcb->so, 1);
369 		}
370 	} else {
371 		if (!transport) {
372 			transport = os_atomic_load(&the_vsock_transport, relaxed);
373 			if (transport == NULL) {
374 				return ENODEV;
375 			}
376 		}
377 		error = transport->put_message(transport->provider, src, dst, VSOCK_RESET, 0, 0, NULL);
378 	}
379 
380 	return error;
381 }
382 
383 static errno_t
vsock_pcb_safe_reset_address(struct vsockpcb * pcb,struct vsock_address src,struct vsock_address dst)384 vsock_pcb_safe_reset_address(struct vsockpcb *pcb, struct vsock_address src, struct vsock_address dst)
385 {
386 	if (pcb) {
387 		socket_lock_assert_owned(pcb->so);
388 		socket_unlock(pcb->so, 0);
389 	}
390 	errno_t error = vsock_pcb_reset_address(src, dst);
391 	if (pcb) {
392 		socket_lock(pcb->so, 0);
393 	}
394 	return error;
395 }
396 
397 static errno_t
vsock_pcb_connect(struct vsockpcb * pcb)398 vsock_pcb_connect(struct vsockpcb *pcb)
399 {
400 	return vsock_pcb_send_message(pcb, VSOCK_REQUEST, NULL);
401 }
402 
403 static errno_t
vsock_pcb_respond(struct vsockpcb * pcb)404 vsock_pcb_respond(struct vsockpcb *pcb)
405 {
406 	return vsock_pcb_send_message(pcb, VSOCK_RESPONSE, NULL);
407 }
408 
409 static errno_t
vsock_pcb_send(struct vsockpcb * pcb,mbuf_t m)410 vsock_pcb_send(struct vsockpcb *pcb, mbuf_t m)
411 {
412 	return vsock_pcb_send_message(pcb, VSOCK_PAYLOAD, m);
413 }
414 
415 static errno_t
vsock_pcb_shutdown_send(struct vsockpcb * pcb)416 vsock_pcb_shutdown_send(struct vsockpcb *pcb)
417 {
418 	return vsock_pcb_send_message(pcb, VSOCK_SHUTDOWN_SEND, NULL);
419 }
420 
421 static errno_t
vsock_pcb_reset(struct vsockpcb * pcb)422 vsock_pcb_reset(struct vsockpcb *pcb)
423 {
424 	return vsock_pcb_send_message(pcb, VSOCK_RESET, NULL);
425 }
426 
427 static errno_t
vsock_pcb_credit_update(struct vsockpcb * pcb)428 vsock_pcb_credit_update(struct vsockpcb *pcb)
429 {
430 	return vsock_pcb_send_message(pcb, VSOCK_CREDIT_UPDATE, NULL);
431 }
432 
433 static errno_t
vsock_pcb_credit_request(struct vsockpcb * pcb)434 vsock_pcb_credit_request(struct vsockpcb *pcb)
435 {
436 	return vsock_pcb_send_message(pcb, VSOCK_CREDIT_REQUEST, NULL);
437 }
438 
439 static errno_t
vsock_disconnect_pcb_common(struct vsockpcb * pcb,bool is_locked)440 vsock_disconnect_pcb_common(struct vsockpcb *pcb, bool is_locked)
441 {
442 	socket_lock_assert_owned(pcb->so);
443 	vsock_unbind_pcb_locked(pcb, is_locked);
444 	return vsock_pcb_reset(pcb);
445 }
446 
447 static errno_t
vsock_disconnect_pcb_locked(struct vsockpcb * pcb)448 vsock_disconnect_pcb_locked(struct vsockpcb *pcb)
449 {
450 	return vsock_disconnect_pcb_common(pcb, true);
451 }
452 
453 static errno_t
vsock_disconnect_pcb(struct vsockpcb * pcb)454 vsock_disconnect_pcb(struct vsockpcb *pcb)
455 {
456 	return vsock_disconnect_pcb_common(pcb, false);
457 }
458 
459 static errno_t
vsock_sockaddr_vm_validate(struct vsockpcb * pcb,struct sockaddr_vm * addr)460 vsock_sockaddr_vm_validate(struct vsockpcb *pcb, struct sockaddr_vm *addr)
461 {
462 	if (!pcb || !pcb->so || !addr) {
463 		return EINVAL;
464 	}
465 
466 	// Validate address length.
467 	if (addr->svm_len < sizeof(struct sockaddr_vm)) {
468 		return EINVAL;
469 	}
470 
471 	// Validate address family.
472 	if (addr->svm_family != AF_UNSPEC && addr->svm_family != AF_VSOCK) {
473 		return EAFNOSUPPORT;
474 	}
475 
476 	// Only stream is supported currently.
477 	if (pcb->so->so_type != SOCK_STREAM) {
478 		return EAFNOSUPPORT;
479 	}
480 
481 	return 0;
482 }
483 
484 /* VSock Receive Handlers */
485 
486 static errno_t
vsock_put_message_connected(struct vsockpcb * pcb,enum vsock_operation op,mbuf_t m)487 vsock_put_message_connected(struct vsockpcb *pcb, enum vsock_operation op, mbuf_t m)
488 {
489 	socket_lock_assert_owned(pcb->so);
490 
491 	errno_t error = 0;
492 
493 	switch (op) {
494 	case VSOCK_SHUTDOWN:
495 		socantsendmore(pcb->so);
496 		socantrcvmore(pcb->so);
497 		break;
498 	case VSOCK_SHUTDOWN_RECEIVE:
499 		socantsendmore(pcb->so);
500 		break;
501 	case VSOCK_SHUTDOWN_SEND:
502 		socantrcvmore(pcb->so);
503 		break;
504 	case VSOCK_PAYLOAD:
505 		// Add data to the receive queue then wakeup any reading threads.
506 		error = !sbappendstream(&pcb->so->so_rcv, m);
507 		if (!error) {
508 			sorwakeup(pcb->so);
509 		}
510 		break;
511 	case VSOCK_RESET:
512 		vsock_unbind_pcb(pcb);
513 		break;
514 	default:
515 		error = ENOTSUP;
516 		break;
517 	}
518 
519 	return error;
520 }
521 
522 static errno_t
vsock_put_message_connecting(struct vsockpcb * pcb,enum vsock_operation op)523 vsock_put_message_connecting(struct vsockpcb *pcb, enum vsock_operation op)
524 {
525 	socket_lock_assert_owned(pcb->so);
526 
527 	errno_t error = 0;
528 
529 	switch (op) {
530 	case VSOCK_RESPONSE:
531 		soisconnected(pcb->so);
532 		break;
533 	case VSOCK_RESET:
534 		pcb->so->so_error = EAGAIN;
535 		error = vsock_disconnect_pcb(pcb);
536 		break;
537 	default:
538 		vsock_disconnect_pcb(pcb);
539 		error = ENOTSUP;
540 		break;
541 	}
542 
543 	return error;
544 }
545 
/*
 * Handle a message received for a listening socket.
 *
 * `src` is the remote sender and `dst` the local address the message was
 * sent to. Only VSOCK_REQUEST is meaningful here: a new socket is created
 * with sonewconn(), bound to (dst, src), answered with a VSOCK_RESPONSE and
 * marked connected. For every other operation the sender is reset; the
 * result of that reset is returned for VSOCK_RESET and ENOTSUP otherwise.
 *
 * The caller must hold the listener's socket lock, and it is held again on
 * return. NOTE(review): after a successful sonewconn() this code unlocks the
 * new socket and re-locks the listener, so sonewconn() presumably returns
 * with the new socket locked and the listener unlocked -- confirm.
 *
 * Returns 0 on success, ECONNREFUSED when no new socket could be created,
 * or the first error hit while setting the new socket up.
 */
static errno_t
vsock_put_message_listening(struct vsockpcb *pcb, enum vsock_operation op, struct vsock_address src, struct vsock_address dst)
{
	socket_lock_assert_owned(pcb->so);

	if (op != VSOCK_REQUEST) {
		const errno_t reset_error = vsock_pcb_safe_reset_address(pcb, dst, src);
		return (op == VSOCK_RESET) ? reset_error : ENOTSUP;
	}

	struct sockaddr_vm listen_addr = {
		.svm_len = sizeof(listen_addr),
		.svm_family = AF_VSOCK,
		.svm_reserved1 = 0,
		.svm_port = pcb->local_address.port,
		.svm_cid = pcb->local_address.cid
	};

	struct socket *child_so = sonewconn(pcb->so, 0, (struct sockaddr *)&listen_addr);
	if (child_so == NULL) {
		// It is likely that the backlog is full. Deny this request.
		vsock_pcb_safe_reset_address(pcb, dst, src);
		return ECONNREFUSED;
	}

	errno_t error;
	struct vsockpcb *child_pcb = sotovsockpcb(child_so);

	if (child_pcb == NULL) {
		error = EINVAL;
	} else {
		error = vsock_bind_address(child_pcb, dst, src);
		if (!error) {
			error = vsock_pcb_respond(child_pcb);
		}
		if (!error) {
			soisconnected(child_so);
		}
	}

	if (error) {
		// Setup failed: tear the new socket down and tell the sender.
		if (child_pcb != NULL) {
			vsock_unbind_pcb(child_pcb);
		} else {
			soisdisconnected(child_so);
		}
		socket_unlock(child_so, 1);
		vsock_pcb_reset_address(dst, src);
	} else {
		socket_unlock(child_so, 0);
	}

	// Give the listener's lock back to the caller.
	socket_lock(pcb->so, 0);

	return error;
}
618 
619 /* VSock Transport */
620 
621 errno_t
vsock_add_transport(struct vsock_transport * transport)622 vsock_add_transport(struct vsock_transport *transport)
623 {
624 	if (transport == NULL || transport->provider == NULL) {
625 		return EINVAL;
626 	}
627 	if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport, NULL, transport, acq_rel)) {
628 		return EEXIST;
629 	}
630 	return 0;
631 }
632 
633 errno_t
vsock_remove_transport(struct vsock_transport * transport)634 vsock_remove_transport(struct vsock_transport *transport)
635 {
636 	if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport, transport, NULL, acq_rel)) {
637 		return ENODEV;
638 	}
639 	return 0;
640 }
641 
642 errno_t
vsock_reset_transport(struct vsock_transport * transport)643 vsock_reset_transport(struct vsock_transport *transport)
644 {
645 	if (transport == NULL) {
646 		return EINVAL;
647 	}
648 
649 	errno_t error = 0;
650 	struct vsockpcb *pcb = NULL;
651 	struct vsockpcb *tmp_pcb = NULL;
652 
653 	lck_rw_lock_exclusive(&vsockinfo.bound_lock);
654 	LIST_FOREACH_SAFE(pcb, &vsockinfo.bound, bound, tmp_pcb) {
655 		// Disconnect this transport's sockets. Listen and bind sockets must stay alive.
656 		socket_lock(pcb->so, 1);
657 		if (pcb->transport == transport && pcb->so->so_state & (SS_ISCONNECTED | SS_ISCONNECTING | SS_ISDISCONNECTING)) {
658 			errno_t dc_error = vsock_disconnect_pcb_locked(pcb);
659 			if (dc_error && !error) {
660 				error = dc_error;
661 			}
662 		}
663 		socket_unlock(pcb->so, 1);
664 	}
665 	lck_rw_done(&vsockinfo.bound_lock);
666 
667 	return error;
668 }
669 
670 errno_t
vsock_put_message(struct vsock_address src,struct vsock_address dst,enum vsock_operation op,uint32_t buf_alloc,uint32_t fwd_cnt,mbuf_t m)671 vsock_put_message(struct vsock_address src, struct vsock_address dst, enum vsock_operation op, uint32_t buf_alloc, uint32_t fwd_cnt, mbuf_t m)
672 {
673 	struct vsockpcb *pcb = vsock_get_matching_pcb(dst, src);
674 	if (!pcb) {
675 		if (op != VSOCK_RESET) {
676 			vsock_pcb_reset_address(dst, src);
677 		}
678 		if (m != NULL) {
679 			mbuf_freem_list(m);
680 		}
681 		return EINVAL;
682 	}
683 
684 	socket_lock_assert_owned(pcb->so);
685 
686 	struct socket *so = pcb->so;
687 	errno_t error = 0;
688 
689 	// Check if the peer's buffer has changed. Update our view of the peer's forwarded bytes.
690 	int buffers_changed = (pcb->peer_buf_alloc != buf_alloc) || (pcb->peer_fwd_cnt) != fwd_cnt;
691 	pcb->peer_buf_alloc = buf_alloc;
692 	pcb->peer_fwd_cnt = fwd_cnt;
693 
694 	// Peer's buffer has enough space for the next packet. Notify any threads waiting for space.
695 	if (buffers_changed && vsock_get_peer_space(pcb) >= pcb->waiting_send_size) {
696 		sowwakeup(so);
697 	}
698 
699 	switch (op) {
700 	case VSOCK_CREDIT_REQUEST:
701 		error = vsock_pcb_credit_update(pcb);
702 		break;
703 	case VSOCK_CREDIT_UPDATE:
704 		break;
705 	default:
706 		if (so->so_state & SS_ISCONNECTED) {
707 			error = vsock_put_message_connected(pcb, op, m);
708 			m = NULL;
709 		} else if (so->so_state & SS_ISCONNECTING) {
710 			error = vsock_put_message_connecting(pcb, op);
711 		} else if (so->so_options & SO_ACCEPTCONN) {
712 			error = vsock_put_message_listening(pcb, op, src, dst);
713 		} else {
714 			// Reset the connection for other states such as 'disconnecting'.
715 			error = vsock_disconnect_pcb(pcb);
716 			if (!error) {
717 				error = ENODEV;
718 			}
719 		}
720 		break;
721 	}
722 	socket_unlock(so, 1);
723 
724 	if (m != NULL) {
725 		mbuf_freem_list(m);
726 	}
727 
728 	return error;
729 }
730 
731 /* VSock Sysctl */
732 
733 static int
734 vsock_pcblist SYSCTL_HANDLER_ARGS
735 {
736 #pragma unused(oidp,arg2)
737 
738 	int error;
739 
740 	// Only stream is supported.
741 	if ((intptr_t)arg1 != SOCK_STREAM) {
742 		return EINVAL;
743 	}
744 
745 	// Get the generation count and the count of all vsock sockets.
746 	lck_rw_lock_shared(&vsockinfo.all_lock);
747 	uint64_t n = vsockinfo.all_pcb_count;
748 	vsock_gen_t gen_count = vsockinfo.vsock_gencnt;
749 	lck_rw_done(&vsockinfo.all_lock);
750 
751 	const size_t xpcb_len = sizeof(struct xvsockpcb);
752 	struct xvsockpgen xvg;
753 
754 	/*
755 	 * The process of preparing the PCB list is too time-consuming and
756 	 * resource-intensive to repeat twice on every request.
757 	 */
758 	if (req->oldptr == USER_ADDR_NULL) {
759 		req->oldidx = (size_t)(2 * sizeof(xvg) + (n + n / 8) * xpcb_len);
760 		return 0;
761 	}
762 
763 	if (req->newptr != USER_ADDR_NULL) {
764 		return EPERM;
765 	}
766 
767 	bzero(&xvg, sizeof(xvg));
768 	xvg.xvg_len = sizeof(xvg);
769 	xvg.xvg_count = n;
770 	xvg.xvg_gen = gen_count;
771 	xvg.xvg_sogen = so_gencnt;
772 	error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
773 	if (error) {
774 		return error;
775 	}
776 
777 	// Return if no sockets exist.
778 	if (n == 0) {
779 		return 0;
780 	}
781 
782 	lck_rw_lock_shared(&vsockinfo.all_lock);
783 
784 	n = 0;
785 	struct vsockpcb *pcb = NULL;
786 	TAILQ_FOREACH(pcb, &vsockinfo.all, all) {
787 		// Bail if there is not enough user buffer for this next socket.
788 		if (req->oldlen - req->oldidx - sizeof(xvg) < xpcb_len) {
789 			break;
790 		}
791 
792 		// Populate the socket structure.
793 		socket_lock(pcb->so, 1);
794 		if (pcb->vsock_gencnt <= gen_count) {
795 			struct xvsockpcb xpcb;
796 			bzero(&xpcb, xpcb_len);
797 			xpcb.xv_len = xpcb_len;
798 			xpcb.xv_vsockpp = (uint64_t)VM_KERNEL_ADDRHASH(pcb);
799 			xpcb.xvp_local_cid = pcb->local_address.cid;
800 			xpcb.xvp_local_port = pcb->local_address.port;
801 			xpcb.xvp_remote_cid = pcb->remote_address.cid;
802 			xpcb.xvp_remote_port = pcb->remote_address.port;
803 			xpcb.xvp_rxcnt = pcb->fwd_cnt;
804 			xpcb.xvp_txcnt = pcb->tx_cnt;
805 			xpcb.xvp_peer_rxhiwat = pcb->peer_buf_alloc;
806 			xpcb.xvp_peer_rxcnt = pcb->peer_fwd_cnt;
807 			xpcb.xvp_last_pid = pcb->so->last_pid;
808 			xpcb.xvp_gencnt = pcb->vsock_gencnt;
809 			if (pcb->so) {
810 				sotoxsocket(pcb->so, &xpcb.xv_socket);
811 			}
812 			socket_unlock(pcb->so, 1);
813 
814 			error = SYSCTL_OUT(req, &xpcb, xpcb_len);
815 			if (error != 0) {
816 				break;
817 			}
818 			n++;
819 		} else {
820 			socket_unlock(pcb->so, 1);
821 		}
822 	}
823 
824 	// Update the generation count to match the sockets being returned.
825 	gen_count = vsockinfo.vsock_gencnt;
826 
827 	lck_rw_done(&vsockinfo.all_lock);
828 
829 	if (!error) {
830 		/*
831 		 * Give the user an updated idea of our state.
832 		 * If the generation differs from what we told
833 		 * her before, she knows that something happened
834 		 * while we were processing this request, and it
835 		 * might be necessary to retry.
836 		 */
837 		bzero(&xvg, sizeof(xvg));
838 		xvg.xvg_len = sizeof(xvg);
839 		xvg.xvg_count = n;
840 		xvg.xvg_gen = gen_count;
841 		xvg.xvg_sogen = so_gencnt;
842 		error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
843 	}
844 
845 	return error;
846 }
847 
#ifdef SYSCTL_DECL
/* net.vsock: parent node for the vsock tunables and the socket list. */
SYSCTL_NODE(_net, OID_AUTO, vsock, CTLFLAG_RW | CTLFLAG_LOCKED, 0, "vsock");

/* net.vsock.sendspace: read/write access to vsock_sendspace. */
SYSCTL_UINT(_net_vsock, OID_AUTO, sendspace,
    CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_sendspace, 0, "Maximum outgoing vsock datagram size");

/* net.vsock.recvspace: read/write access to vsock_recvspace. */
SYSCTL_UINT(_net_vsock, OID_AUTO, recvspace,
    CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_recvspace, 0, "Maximum incoming vsock datagram size");

/*
 * net.vsock.pcblist: read-only list of sockets produced by vsock_pcblist().
 * SOCK_STREAM is passed as arg1, the only type the handler accepts.
 */
SYSCTL_PROC(_net_vsock, OID_AUTO, pcblist,
    CTLTYPE_STRUCT | CTLFLAG_RD | CTLFLAG_LOCKED,
    (caddr_t)(long)SOCK_STREAM, 0, vsock_pcblist, "S,xvsockpcb",
    "List of active vsock sockets");
#endif
859 
860 /* VSock Protocol */
861 
862 static int
vsock_attach(struct socket * so,int proto,struct proc * p)863 vsock_attach(struct socket *so, int proto, struct proc *p)
864 {
865 	#pragma unused(proto, p)
866 
867 	// Reserve send and receive buffers.
868 	errno_t error = soreserve(so, vsock_sendspace, vsock_recvspace);
869 	if (error) {
870 		return error;
871 	}
872 
873 	// Attach should only be run once per socket.
874 	struct vsockpcb *pcb = sotovsockpcb(so);
875 	if (pcb) {
876 		return EINVAL;
877 	}
878 
879 	// Get the transport for this socket.
880 	struct vsock_transport *transport = os_atomic_load(&the_vsock_transport, relaxed);
881 	if (transport == NULL) {
882 		return ENODEV;
883 	}
884 
885 	// Initialize the vsock protocol control block.
886 	pcb = zalloc_flags(vsockpcb_zone, Z_WAITOK | Z_ZERO | Z_NOFAIL);
887 	pcb->so = so;
888 	pcb->transport = transport;
889 	pcb->local_address = (struct vsock_address) {
890 		.cid = VMADDR_CID_ANY,
891 		.port = VMADDR_PORT_ANY
892 	};
893 	pcb->remote_address = (struct vsock_address) {
894 		.cid = VMADDR_CID_ANY,
895 		.port = VMADDR_PORT_ANY
896 	};
897 	so->so_pcb = pcb;
898 
899 	// Tell the transport that this socket has attached.
900 	error = transport->attach_socket(transport->provider);
901 	if (error) {
902 		zfree(vsockpcb_zone, pcb);
903 		so->so_pcb = NULL;
904 		return error;
905 	}
906 
907 	// Add to the list of all vsock sockets.
908 	lck_rw_lock_exclusive(&vsockinfo.all_lock);
909 	TAILQ_INSERT_TAIL(&vsockinfo.all, pcb, all);
910 	vsockinfo.all_pcb_count++;
911 	pcb->vsock_gencnt = ++vsockinfo.vsock_gencnt;
912 	lck_rw_done(&vsockinfo.all_lock);
913 
914 	return 0;
915 }
916 
917 static int
vsock_control(struct socket * so,u_long cmd,caddr_t data,struct ifnet * ifp,struct proc * p)918 vsock_control(struct socket *so, u_long cmd, caddr_t data, struct ifnet *ifp, struct proc *p)
919 {
920 	#pragma unused(ifp)
921 
922 	VERIFY(so != NULL || p == kernproc);
923 
924 	if (cmd != IOCTL_VM_SOCKETS_GET_LOCAL_CID) {
925 		return EINVAL;
926 	}
927 
928 	struct vsock_transport *transport;
929 	if (so) {
930 		struct vsockpcb *pcb = sotovsockpcb(so);
931 		if (pcb == NULL) {
932 			return EINVAL;
933 		}
934 		transport = pcb->transport;
935 	} else {
936 		transport = os_atomic_load(&the_vsock_transport, relaxed);
937 	}
938 
939 	if (transport == NULL) {
940 		return ENODEV;
941 	}
942 
943 	uint32_t transport_cid;
944 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
945 	if (error) {
946 		return error;
947 	}
948 
949 	memcpy(data, &transport_cid, sizeof(transport_cid));
950 
951 	return 0;
952 }
953 
954 static int
vsock_detach(struct socket * so)955 vsock_detach(struct socket *so)
956 {
957 	struct vsockpcb *pcb = sotovsockpcb(so);
958 	if (pcb == NULL) {
959 		return EINVAL;
960 	}
961 
962 	vsock_unbind_pcb(pcb);
963 
964 	// Tell the transport that this socket has detached.
965 	struct vsock_transport *transport = pcb->transport;
966 	errno_t error = transport->detach_socket(transport->provider);
967 	if (error) {
968 		return error;
969 	}
970 
971 	// Remove from the list of all vsock sockets.
972 	lck_rw_lock_exclusive(&vsockinfo.all_lock);
973 	TAILQ_REMOVE(&vsockinfo.all, pcb, all);
974 	pcb->all.tqe_next = NULL;
975 	pcb->all.tqe_prev = NULL;
976 	vsockinfo.all_pcb_count--;
977 	vsockinfo.vsock_gencnt++;
978 	lck_rw_done(&vsockinfo.all_lock);
979 
980 	// Mark this socket for deallocation.
981 	so->so_flags |= SOF_PCBCLEARING;
982 
983 	return 0;
984 }
985 
/* Abort is handled exactly like a detach. */
static int
vsock_abort(struct socket *so)
{
	const int result = vsock_detach(so);

	return result;
}
991 
992 static int
vsock_bind(struct socket * so,struct sockaddr * nam,struct proc * p)993 vsock_bind(struct socket *so, struct sockaddr *nam, struct proc *p)
994 {
995 	#pragma unused(p)
996 
997 	struct vsockpcb *pcb = sotovsockpcb(so);
998 	if (pcb == NULL) {
999 		return EINVAL;
1000 	}
1001 
1002 	struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
1003 
1004 	errno_t error = vsock_sockaddr_vm_validate(pcb, addr);
1005 	if (error) {
1006 		return error;
1007 	}
1008 
1009 	struct vsock_address laddr = (struct vsock_address) {
1010 		.cid = addr->svm_cid,
1011 		.port = addr->svm_port,
1012 	};
1013 
1014 	struct vsock_address raddr = (struct vsock_address) {
1015 		.cid = VMADDR_CID_ANY,
1016 		.port = VMADDR_PORT_ANY,
1017 	};
1018 
1019 	error = vsock_bind_address(pcb, laddr, raddr);
1020 	if (error) {
1021 		return error;
1022 	}
1023 
1024 	return 0;
1025 }
1026 
1027 static int
vsock_listen(struct socket * so,struct proc * p)1028 vsock_listen(struct socket *so, struct proc *p)
1029 {
1030 	#pragma unused(p)
1031 
1032 	struct vsockpcb *pcb = sotovsockpcb(so);
1033 	if (pcb == NULL) {
1034 		return EINVAL;
1035 	}
1036 
1037 	// Only stream is supported currently.
1038 	if (so->so_type != SOCK_STREAM) {
1039 		return EAFNOSUPPORT;
1040 	}
1041 
1042 	struct vsock_address *addr = &pcb->local_address;
1043 
1044 	if (addr->port == VMADDR_CID_ANY) {
1045 		return EFAULT;
1046 	}
1047 
1048 	struct vsock_transport *transport = pcb->transport;
1049 	uint32_t transport_cid;
1050 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
1051 	if (error) {
1052 		return error;
1053 	}
1054 
1055 	// Can listen on the transport's cid or any.
1056 	if (addr->cid != transport_cid && addr->cid != VMADDR_CID_ANY) {
1057 		return EFAULT;
1058 	}
1059 
1060 	return 0;
1061 }
1062 
1063 static int
vsock_accept(struct socket * so,struct sockaddr ** nam)1064 vsock_accept(struct socket *so, struct sockaddr **nam)
1065 {
1066 	struct vsockpcb *pcb = sotovsockpcb(so);
1067 	if (pcb == NULL) {
1068 		return EINVAL;
1069 	}
1070 
1071 	// Do not accept disconnected sockets.
1072 	if (so->so_state & SS_ISDISCONNECTED) {
1073 		return ECONNABORTED;
1074 	}
1075 
1076 	*nam = vsock_new_sockaddr(&pcb->remote_address);
1077 
1078 	return 0;
1079 }
1080 
/*
 * pru_connect handler: connect the socket to the vsock address in `nam`.
 *
 * Steps, in order:
 *   1. Validate the destination and fetch the transport's own cid.
 *   2. Reject destinations other than the host, the hypervisor, or the
 *      transport's own cid (EFAULT).
 *   3. Mark the socket as connecting and bind it to
 *      (transport cid, VMADDR_PORT_ANY) -> (destination cid, port).
 *   4. Issue the connect request via vsock_pcb_connect().
 *   5. If the socket is not yet SS_ISCONNECTED: return EINPROGRESS for
 *      non-blocking sockets, otherwise sleep on so_timeo until woken or
 *      until the send-buffer timeout elapses.
 *
 * Every failure after step 3, except the EINPROGRESS case, goes through
 * `cleanup`, which unbinds the PCB again.
 *
 * Returns 0 on success, EINVAL when the socket has no PCB, EFAULT for an
 * unsupported destination cid, EINPROGRESS for a pending non-blocking
 * connect, ETIMEDOUT when the sleep reports EAGAIN, or any error reported by
 * the helpers / pending in so_error.
 */
static int
vsock_connect(struct socket *so, struct sockaddr *nam, struct proc *p)
{
	#pragma unused(p)

	struct vsockpcb *pcb = sotovsockpcb(so);
	if (pcb == NULL) {
		return EINVAL;
	}

	struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;

	errno_t error = vsock_sockaddr_vm_validate(pcb, addr);
	if (error) {
		return error;
	}

	uint32_t transport_cid;
	struct vsock_transport *transport = pcb->transport;
	error = transport->get_cid(transport->provider, &transport_cid);
	if (error) {
		return error;
	}

	// Only supporting connections to the host, hypervisor, or self for now.
	if (addr->svm_cid != VMADDR_CID_HOST &&
	    addr->svm_cid != VMADDR_CID_HYPERVISOR &&
	    addr->svm_cid != transport_cid) {
		return EFAULT;
	}

	// From here on, failures must go through `cleanup` so the PCB is unbound.
	soisconnecting(so);

	// Set the remote and local address.
	struct vsock_address remote_addr = (struct vsock_address) {
		.cid = addr->svm_cid,
		.port = addr->svm_port,
	};

	// The local port is left as a wildcard for vsock_bind_address().
	struct vsock_address local_addr = (struct vsock_address) {
		.cid = transport_cid,
		.port = VMADDR_PORT_ANY,
	};

	// Bind to the address.
	error = vsock_bind_address(pcb, local_addr, remote_addr);
	if (error) {
		goto cleanup;
	}

	// Attempt a connection using the socket's transport.
	error = vsock_pcb_connect(pcb);
	if (error) {
		goto cleanup;
	}

	if ((so->so_state & SS_ISCONNECTED) == 0) {
		// Don't wait for peer's response if non-blocking.
		// Jumping to `done` skips `cleanup`, so the binding is kept.
		if (so->so_state & SS_NBIO) {
			error = EINPROGRESS;
			goto done;
		}

		// Sleep timeout taken from the send buffer timeout (usec -> nsec).
		struct timespec ts = (struct timespec) {
			.tv_sec = so->so_snd.sb_timeo.tv_sec,
			.tv_nsec = so->so_snd.sb_timeo.tv_usec * 1000,
		};

		// Pick the mutex to sleep with: the protocol's own lock when it
		// provides pr_getlock, otherwise the domain mutex.
		lck_mtx_t *mutex_held;
		if (so->so_proto->pr_getlock != NULL) {
			mutex_held = (*so->so_proto->pr_getlock)(so, PR_F_WILLUNLOCK);
		} else {
			mutex_held = so->so_proto->pr_domain->dom_mtx;
		}

		// Wait until we receive a response to the connect request.
		error = msleep((caddr_t)&so->so_timeo, mutex_held, PSOCK | PCATCH, "vsock_connect", &ts);
		if (error) {
			// EAGAIN from the sleep is reported to the caller as a timeout.
			if (error == EAGAIN) {
				error = ETIMEDOUT;
			}
			goto cleanup;
		}
	}

cleanup:
	// With no local error, a pending socket error is consumed and returned.
	if (so->so_error && !error) {
		error = so->so_error;
		so->so_error = 0;
	}
	// Still no error: succeed only if the socket actually became connected.
	// NOTE(review): the negation yields the literal value 1 as the error
	// code when the socket is not connected - confirm this is the intended
	// errno.
	if (!error) {
		error = !(so->so_state & SS_ISCONNECTED);
	}
	// Undo the binding made above on any failure.
	if (error) {
		vsock_unbind_pcb(pcb);
	}

done:
	return error;
}
1181 
1182 static int
vsock_disconnect(struct socket * so)1183 vsock_disconnect(struct socket *so)
1184 {
1185 	struct vsockpcb *pcb = sotovsockpcb(so);
1186 	if (pcb == NULL) {
1187 		return EINVAL;
1188 	}
1189 
1190 	return vsock_disconnect_pcb(pcb);
1191 }
1192 
1193 static int
vsock_sockaddr(struct socket * so,struct sockaddr ** nam)1194 vsock_sockaddr(struct socket *so, struct sockaddr **nam)
1195 {
1196 	struct vsockpcb *pcb = sotovsockpcb(so);
1197 	if (pcb == NULL) {
1198 		return EINVAL;
1199 	}
1200 
1201 	*nam = vsock_new_sockaddr(&pcb->local_address);
1202 
1203 	return 0;
1204 }
1205 
1206 static int
vsock_peeraddr(struct socket * so,struct sockaddr ** nam)1207 vsock_peeraddr(struct socket *so, struct sockaddr **nam)
1208 {
1209 	struct vsockpcb *pcb = sotovsockpcb(so);
1210 	if (pcb == NULL) {
1211 		return EINVAL;
1212 	}
1213 
1214 	*nam = vsock_new_sockaddr(&pcb->remote_address);
1215 
1216 	return 0;
1217 }
1218 
1219 static int
vsock_send(struct socket * so,int flags,struct mbuf * m,struct sockaddr * nam,struct mbuf * control,proc_t p)1220 vsock_send(struct socket *so, int flags, struct mbuf *m, struct sockaddr *nam, struct mbuf *control, proc_t p)
1221 {
1222 	#pragma unused(flags, nam, p)
1223 
1224 	struct vsockpcb *pcb = sotovsockpcb(so);
1225 	if (pcb == NULL || m == NULL) {
1226 		return EINVAL;
1227 	}
1228 
1229 	if (control != NULL) {
1230 		m_freem(control);
1231 		return EOPNOTSUPP;
1232 	}
1233 
1234 	// Ensure this socket is connected.
1235 	if ((so->so_state & SS_ISCONNECTED) == 0) {
1236 		if (m != NULL) {
1237 			mbuf_freem_list(m);
1238 		}
1239 		return EPERM;
1240 	}
1241 
1242 	errno_t error;
1243 
1244 	// rdar://84098487 (SEED: Web: Virtio-socket sent data lost after 128KB)
1245 	// For writes larger than the default `sosendmaxchain` of 65536, vsock_send() is called multiple times per write().
1246 	// Only the first call to vsock_send() is passed a valid mbuf packet, while subsequent calls are not marked as a packet
1247 	// with a valid length. We should mark all mbufs as a packet and set the correct packet length so that the downstream
1248 	// socket transport layer can correctly generate physical segments.
1249 	if (!(mbuf_flags(m) & MBUF_PKTHDR)) {
1250 		if (!(mbuf_flags(m) & M_EXT)) {
1251 			struct mbuf *header = NULL;
1252 			MGETHDR(header, M_WAITOK, MT_HEADER);
1253 			if (header == NULL) {
1254 				if (m != NULL) {
1255 					mbuf_freem_list(m);
1256 				}
1257 				return ENOBUFS;
1258 			}
1259 			header->m_next = m;
1260 			m = header;
1261 		} else {
1262 			mbuf_setflags(m, mbuf_flags(m) | MBUF_PKTHDR);
1263 		}
1264 
1265 		size_t len = 0;
1266 		struct mbuf *next = m;
1267 		while (next) {
1268 			len += mbuf_len(next);
1269 			next = mbuf_next(next);
1270 		}
1271 		mbuf_pkthdr_setlen(m, len);
1272 	}
1273 
1274 	const size_t len = mbuf_pkthdr_len(m);
1275 	uint32_t free_space = vsock_get_peer_space(pcb);
1276 
1277 	// Ensure the peer has enough space in their receive buffer.
1278 	while (len > free_space) {
1279 		// Record the number of free peer bytes necessary before we can send.
1280 		if (len > pcb->waiting_send_size) {
1281 			pcb->waiting_send_size = len;
1282 		}
1283 
1284 		// Send a credit request.
1285 		error = vsock_pcb_credit_request(pcb);
1286 		if (error) {
1287 			if (m != NULL) {
1288 				mbuf_freem_list(m);
1289 			}
1290 			return error;
1291 		}
1292 
1293 		// Check again in case free space was automatically updated in loopback case.
1294 		free_space = vsock_get_peer_space(pcb);
1295 		if (len <= free_space) {
1296 			pcb->waiting_send_size = 0;
1297 			break;
1298 		}
1299 
1300 		// Bail if this is a non-blocking socket.
1301 		if (so->so_state & SS_NBIO) {
1302 			if (m != NULL) {
1303 				mbuf_freem_list(m);
1304 			}
1305 			return EWOULDBLOCK;
1306 		}
1307 
1308 		// Wait until our peer has enough free space in their receive buffer.
1309 		error = sbwait(&so->so_snd);
1310 		pcb->waiting_send_size = 0;
1311 		if (error) {
1312 			if (m != NULL) {
1313 				mbuf_freem_list(m);
1314 			}
1315 			return error;
1316 		}
1317 
1318 		// Bail if an error occured or we can't send more.
1319 		if (so->so_state & SS_CANTSENDMORE) {
1320 			if (m != NULL) {
1321 				mbuf_freem_list(m);
1322 			}
1323 			return EPIPE;
1324 		} else if (so->so_error) {
1325 			error = so->so_error;
1326 			so->so_error = 0;
1327 			if (m != NULL) {
1328 				mbuf_freem_list(m);
1329 			}
1330 			return error;
1331 		}
1332 
1333 		free_space = vsock_get_peer_space(pcb);
1334 	}
1335 
1336 	// Send a payload over the transport.
1337 	error = vsock_pcb_send(pcb, m);
1338 	if (error) {
1339 		return error;
1340 	}
1341 
1342 	pcb->tx_cnt += len;
1343 
1344 	return 0;
1345 }
1346 
1347 static int
vsock_shutdown(struct socket * so)1348 vsock_shutdown(struct socket *so)
1349 {
1350 	struct vsockpcb *pcb = sotovsockpcb(so);
1351 	if (pcb == NULL) {
1352 		return EINVAL;
1353 	}
1354 
1355 	socantsendmore(so);
1356 
1357 	// Tell peer we will no longer send.
1358 	errno_t error = vsock_pcb_shutdown_send(pcb);
1359 	if (error) {
1360 		return error;
1361 	}
1362 
1363 	return 0;
1364 }
1365 
1366 static int
vsock_soreceive(struct socket * so,struct sockaddr ** psa,struct uio * uio,struct mbuf ** mp0,struct mbuf ** controlp,int * flagsp)1367 vsock_soreceive(struct socket *so, struct sockaddr **psa, struct uio *uio,
1368     struct mbuf **mp0, struct mbuf **controlp, int *flagsp)
1369 {
1370 	struct vsockpcb *pcb = sotovsockpcb(so);
1371 	if (pcb == NULL) {
1372 		return EINVAL;
1373 	}
1374 
1375 	user_ssize_t length = uio_resid(uio);
1376 	int result = soreceive(so, psa, uio, mp0, controlp, flagsp);
1377 	length -= uio_resid(uio);
1378 
1379 	socket_lock(so, 1);
1380 
1381 	pcb->fwd_cnt += length;
1382 
1383 	const uint32_t threshold = VSOCK_MAX_PACKET_SIZE;
1384 
1385 	// Send a credit update if is possible that the peer will no longer send.
1386 	if ((pcb->fwd_cnt - pcb->last_fwd_cnt + threshold) >= pcb->last_buf_alloc) {
1387 		errno_t error = vsock_pcb_credit_update(pcb);
1388 		if (!result && error) {
1389 			result = error;
1390 		}
1391 	}
1392 
1393 	socket_unlock(so, 1);
1394 
1395 	return result;
1396 }
1397 
1398 static struct pr_usrreqs vsock_usrreqs = {
1399 	.pru_abort =            vsock_abort,
1400 	.pru_attach =           vsock_attach,
1401 	.pru_control =          vsock_control,
1402 	.pru_detach =           vsock_detach,
1403 	.pru_bind =             vsock_bind,
1404 	.pru_listen =           vsock_listen,
1405 	.pru_accept =           vsock_accept,
1406 	.pru_connect =          vsock_connect,
1407 	.pru_disconnect =       vsock_disconnect,
1408 	.pru_send =             vsock_send,
1409 	.pru_shutdown =         vsock_shutdown,
1410 	.pru_sockaddr =         vsock_sockaddr,
1411 	.pru_peeraddr =         vsock_peeraddr,
1412 	.pru_sosend =           sosend,
1413 	.pru_soreceive =        vsock_soreceive,
1414 };
1415 
1416 static void
vsock_init(struct protosw * pp,struct domain * dp)1417 vsock_init(struct protosw *pp, struct domain *dp)
1418 {
1419 	#pragma unused(dp)
1420 
1421 	static int vsock_initialized = 0;
1422 	VERIFY((pp->pr_flags & (PR_INITIALIZED | PR_ATTACHED)) == PR_ATTACHED);
1423 	if (!os_atomic_cmpxchg((volatile int *)&vsock_initialized, 0, 1, acq_rel)) {
1424 		return;
1425 	}
1426 
1427 	// Setup VSock protocol info struct.
1428 	lck_rw_init(&vsockinfo.all_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1429 	lck_rw_init(&vsockinfo.bound_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1430 	lck_mtx_init(&vsockinfo.port_lock, &vsock_lock_grp, LCK_ATTR_NULL);
1431 	TAILQ_INIT(&vsockinfo.all);
1432 	LIST_INIT(&vsockinfo.bound);
1433 	vsockinfo.last_port = VMADDR_PORT_ANY;
1434 }
1435 
1436 static int
vsock_sofreelastref(struct socket * so,int dealloc)1437 vsock_sofreelastref(struct socket *so, int dealloc)
1438 {
1439 	socket_lock_assert_owned(so);
1440 
1441 	struct vsockpcb *pcb = sotovsockpcb(so);
1442 	if (pcb != NULL) {
1443 		zfree(vsockpcb_zone, pcb);
1444 	}
1445 
1446 	so->so_pcb = NULL;
1447 	sofreelastref(so, dealloc);
1448 
1449 	return 0;
1450 }
1451 
1452 static int
vsock_unlock(struct socket * so,int refcount,void * lr_saved)1453 vsock_unlock(struct socket *so, int refcount, void *lr_saved)
1454 {
1455 	lck_mtx_t *mutex_held = so->so_proto->pr_domain->dom_mtx;
1456 #ifdef MORE_LOCKING_DEBUG
1457 	LCK_MTX_ASSERT(mutex_held, LCK_MTX_ASSERT_OWNED);
1458 #endif
1459 	so->unlock_lr[so->next_unlock_lr] = lr_saved;
1460 	so->next_unlock_lr = (so->next_unlock_lr + 1) % SO_LCKDBG_MAX;
1461 
1462 	if (refcount) {
1463 		if (so->so_usecount <= 0) {
1464 			panic("%s: bad refcount=%d so=%p (%d, %d, %d) "
1465 			    "lrh=%s", __func__, so->so_usecount, so,
1466 			    SOCK_DOM(so), so->so_type,
1467 			    SOCK_PROTO(so), solockhistory_nr(so));
1468 			/* NOTREACHED */
1469 		}
1470 
1471 		so->so_usecount--;
1472 		if (so->so_usecount == 0) {
1473 			vsock_sofreelastref(so, 1);
1474 		}
1475 	}
1476 	lck_mtx_unlock(mutex_held);
1477 
1478 	return 0;
1479 }
1480 
1481 static struct protosw vsocksw[] = {
1482 	{
1483 		.pr_type =              SOCK_STREAM,
1484 		.pr_protocol =          0,
1485 		.pr_flags =             PR_CONNREQUIRED | PR_WANTRCVD,
1486 		.pr_init =              vsock_init,
1487 		.pr_unlock =            vsock_unlock,
1488 		.pr_usrreqs =           &vsock_usrreqs,
1489 	}
1490 };
1491 
1492 static const int vsock_proto_count = (sizeof(vsocksw) / sizeof(struct protosw));
1493 
1494 /* VSock Domain */
1495 
1496 static struct domain *vsock_domain = NULL;
1497 
1498 static void
vsock_dinit(struct domain * dp)1499 vsock_dinit(struct domain *dp)
1500 {
1501 	// The VSock domain is initialized with a singleton pattern.
1502 	VERIFY(!(dp->dom_flags & DOM_INITIALIZED));
1503 	VERIFY(vsock_domain == NULL);
1504 	vsock_domain = dp;
1505 
1506 	// Add protocols and initialize.
1507 	for (int i = 0; i < vsock_proto_count; i++) {
1508 		net_add_proto((struct protosw *)&vsocksw[i], dp, 1);
1509 	}
1510 }
1511 
1512 struct domain vsockdomain_s = {
1513 	.dom_family =           PF_VSOCK,
1514 	.dom_name =             "vsock",
1515 	.dom_init =             vsock_dinit,
1516 	.dom_maxrtkey =         sizeof(struct sockaddr_vm),
1517 	.dom_protohdrlen =      sizeof(struct sockaddr_vm),
1518 };
1519