xref: /xnu-12377.41.6/bsd/kern/vsock_domain.c (revision bbb1b6f9e71b8cdde6e5cd6f4841f207dee3d828)
1 /*
2  * Copyright (c) 2020 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <sys/domain.h>
30 #include <sys/socket.h>
31 #include <sys/protosw.h>
32 #include <sys/mcache.h>
33 #include <sys/systm.h>
34 #include <sys/sysctl.h>
35 #include <sys/random.h>
36 #include <sys/mbuf.h>
37 #include <sys/vsock_domain.h>
38 #include <sys/vsock_transport.h>
39 #include <kern/task.h>
40 #include <kern/zalloc.h>
41 #include <kern/locks.h>
42 #include <machine/atomic.h>
43 #include <IOKit/IOBSD.h>
44 
/* Fetch the vsock protocol control block stored in a socket's so_pcb field. */
#define sotovsockpcb(so) ((struct vsockpcb *)(so)->so_pcb)

/* Explicit binds to ports below this value require superuser (or the kernel task). */
#define VSOCK_PORT_RESERVED 1024
/* Entitlement a non-kernel process must hold to use VSOCK_PROTO_PRIVATE. */
#define VSOCK_PRIVATE_ENTITLEMENT "com.apple.private.vsock"

/* VSock Protocol Globals */

/* Registered transport per protocol; NULL while no transport is registered. */
static struct vsock_transport * _Atomic the_vsock_transport[VSOCK_PROTO_MAX];
/* Zone from which every struct vsockpcb is allocated. */
static ZONE_DEFINE_TYPE(vsockpcb_zone, "vsockpcbzone", struct vsockpcb, ZC_NONE);
/* Per-protocol bookkeeping: bound list, list of all pcbs, their locks and counters. */
static struct vsockpcbinfo vsockinfo[VSOCK_PROTO_MAX];

/* Per-protocol buffer sizes handed to soreserve() at attach; attach fails while either is 0. */
static uint32_t vsock_sendspace[VSOCK_PROTO_MAX];
static uint32_t vsock_recvspace[VSOCK_PROTO_MAX];
58 
59 /* VSock Private Entitlements */
60 
61 static errno_t
vsock_validate_entitlements(uint16_t protocol,struct proc * p)62 vsock_validate_entitlements(uint16_t protocol, struct proc *p)
63 {
64 	if (protocol != VSOCK_PROTO_PRIVATE) {
65 		return 0;
66 	}
67 
68 	if (!p) {
69 		p = current_proc();
70 	}
71 
72 	if (p == kernproc) {
73 		// Assume kernel callers are entitled.
74 		return 0;
75 	}
76 
77 	if (!IOTaskHasEntitlement(proc_task(p), VSOCK_PRIVATE_ENTITLEMENT)) {
78 		return EPERM;
79 	}
80 
81 	return 0;
82 }
83 
84 /* VSock PCB Helpers */
85 
86 static uint32_t
vsock_get_peer_space(struct vsockpcb * _Nonnull pcb)87 vsock_get_peer_space(struct vsockpcb *_Nonnull pcb)
88 {
89 	VERIFY(pcb != NULL);
90 	return pcb->peer_buf_alloc - (pcb->tx_cnt - pcb->peer_fwd_cnt);
91 }
92 
93 static struct vsockpcb *
vsock_get_matching_pcb(struct vsock_address src,struct vsock_address dst,uint16_t protocol)94 vsock_get_matching_pcb(struct vsock_address src, struct vsock_address dst, uint16_t protocol)
95 {
96 	struct vsockpcb *preferred = NULL;
97 	struct vsockpcb *match = NULL;
98 	struct vsockpcb *pcb = NULL;
99 
100 	lck_rw_lock_shared(&vsockinfo[protocol].bound_lock);
101 	LIST_FOREACH(pcb, &vsockinfo[protocol].bound, bound) {
102 		// Source cid and port must match. Only destination port must match. (Allows for a changing CID during migration)
103 		socket_lock(pcb->so, 1);
104 		if ((pcb->so->so_state & SS_ISCONNECTED || pcb->so->so_state & SS_ISCONNECTING) &&
105 		    pcb->local_address.cid == src.cid && pcb->local_address.port == src.port &&
106 		    pcb->remote_address.port == dst.port) {
107 			preferred = pcb;
108 			break;
109 		} else if ((pcb->local_address.cid == src.cid || pcb->local_address.cid == VMADDR_CID_ANY) &&
110 		    pcb->local_address.port == src.port) {
111 			match = pcb;
112 		}
113 		socket_unlock(pcb->so, 1);
114 	}
115 	if (!preferred && match) {
116 		socket_lock(match->so, 1);
117 		preferred = match;
118 	}
119 	lck_rw_done(&vsockinfo[protocol].bound_lock);
120 
121 	return preferred;
122 }
123 
124 static errno_t
vsock_bind_address_if_free(struct vsockpcb * _Nonnull pcb,uint32_t local_cid,uint32_t local_port,uint32_t remote_cid,uint32_t remote_port)125 vsock_bind_address_if_free(struct vsockpcb *_Nonnull pcb, uint32_t local_cid, uint32_t local_port, uint32_t remote_cid, uint32_t remote_port)
126 {
127 	VERIFY(pcb != NULL);
128 	socket_lock_assert_owned(pcb->so);
129 
130 	// Privileged ports.
131 	if (local_port != VMADDR_PORT_ANY && local_port < VSOCK_PORT_RESERVED &&
132 	    current_task() != kernel_task && proc_suser(current_proc()) != 0) {
133 		return EACCES;
134 	}
135 
136 	bool taken = false;
137 	const bool check_remote = (remote_cid != VMADDR_CID_ANY && remote_port != VMADDR_PORT_ANY);
138 	const uint16_t protocol = pcb->so->so_protocol;
139 
140 	struct vsockpcb *pcb_match = NULL;
141 
142 	socket_unlock(pcb->so, 0);
143 	lck_rw_lock_exclusive(&vsockinfo[protocol].bound_lock);
144 	LIST_FOREACH(pcb_match, &vsockinfo[protocol].bound, bound) {
145 		socket_lock(pcb_match->so, 1);
146 		if (pcb == pcb_match ||
147 		    (!check_remote && pcb_match->local_address.port == local_port) ||
148 		    (check_remote && pcb_match->local_address.port == local_port &&
149 		    pcb_match->remote_address.cid == remote_cid && pcb_match->remote_address.port == remote_port)) {
150 			socket_unlock(pcb_match->so, 1);
151 			taken = true;
152 			break;
153 		}
154 		socket_unlock(pcb_match->so, 1);
155 	}
156 	socket_lock(pcb->so, 0);
157 	if (!taken) {
158 		pcb->local_address = (struct vsock_address) { .cid = local_cid, .port = local_port };
159 		pcb->remote_address = (struct vsock_address) { .cid = remote_cid, .port = remote_port };
160 		LIST_INSERT_HEAD(&vsockinfo[protocol].bound, pcb, bound);
161 	}
162 	lck_rw_done(&vsockinfo[protocol].bound_lock);
163 
164 	return taken ? EADDRINUSE : 0;
165 }
166 
167 static errno_t
vsock_bind_address(struct vsockpcb * pcb,struct vsock_address laddr,struct vsock_address raddr)168 vsock_bind_address(struct vsockpcb *pcb, struct vsock_address laddr, struct vsock_address raddr)
169 {
170 	if (!pcb) {
171 		return EINVAL;
172 	}
173 
174 	socket_lock_assert_owned(pcb->so);
175 
176 	// Certain CIDs are reserved.
177 	if (laddr.cid == VMADDR_CID_HYPERVISOR || laddr.cid == VMADDR_CID_RESERVED || laddr.cid == VMADDR_CID_HOST) {
178 		return EADDRNOTAVAIL;
179 	}
180 
181 	// Remote address must be fully specified or not specified at all.
182 	if ((raddr.cid == VMADDR_CID_ANY) ^ (raddr.port == VMADDR_PORT_ANY)) {
183 		return EINVAL;
184 	}
185 
186 	// Cannot bind if already bound.
187 	if (pcb->local_address.port != VMADDR_PORT_ANY) {
188 		return EINVAL;
189 	}
190 
191 	uint32_t transport_cid;
192 	struct vsock_transport *transport = pcb->transport;
193 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
194 	if (error) {
195 		return error;
196 	}
197 
198 	// Local CID must be this transport's CID or any.
199 	if (laddr.cid != transport_cid && laddr.cid != VMADDR_CID_ANY) {
200 		return EINVAL;
201 	}
202 
203 	if (laddr.port != VMADDR_PORT_ANY) {
204 		error = vsock_bind_address_if_free(pcb, laddr.cid, laddr.port, raddr.cid, raddr.port);
205 	} else {
206 		const uint16_t protocol = pcb->so->so_protocol;
207 
208 		socket_unlock(pcb->so, 0);
209 		lck_mtx_lock(&vsockinfo[protocol].port_lock);
210 		socket_lock(pcb->so, 0);
211 
212 		const uint32_t first = VSOCK_PORT_RESERVED;
213 		const uint32_t last = VMADDR_PORT_ANY - 1;
214 		uint32_t count = last - first + 1;
215 		uint32_t *last_port = &vsockinfo[protocol].last_port;
216 
217 		if (pcb->so->so_flags & SOF_BINDRANDOMPORT) {
218 			uint32_t random = 0;
219 			read_frandom(&random, sizeof(random));
220 			*last_port = first + (random % count);
221 		}
222 
223 		do {
224 			if (count == 0) {
225 				lck_mtx_unlock(&vsockinfo[protocol].port_lock);
226 				return EADDRNOTAVAIL;
227 			}
228 			count--;
229 
230 			++*last_port;
231 			if (*last_port < first || *last_port > last) {
232 				*last_port = first;
233 			}
234 
235 			error = vsock_bind_address_if_free(pcb, laddr.cid, *last_port, raddr.cid, raddr.port);
236 		} while (error);
237 
238 		lck_mtx_unlock(&vsockinfo[protocol].port_lock);
239 	}
240 
241 	return error;
242 }
243 
244 static void
vsock_unbind_pcb_locked(struct vsockpcb * pcb,bool is_locked)245 vsock_unbind_pcb_locked(struct vsockpcb *pcb, bool is_locked)
246 {
247 	if (!pcb) {
248 		return;
249 	}
250 
251 	struct socket *so = pcb->so;
252 	socket_lock_assert_owned(so);
253 
254 	// Bail if disconnect and already unbound.
255 	if (so->so_state & SS_ISDISCONNECTED) {
256 		assert(pcb->bound.le_next == NULL);
257 		assert(pcb->bound.le_prev == NULL);
258 		return;
259 	}
260 
261 	const uint16_t protocol = so->so_protocol;
262 
263 	if (!is_locked) {
264 		socket_unlock(so, 0);
265 		lck_rw_lock_exclusive(&vsockinfo[protocol].bound_lock);
266 		socket_lock(so, 0);
267 
268 		// Case where some other thread also called unbind() on this socket while waiting to acquire its lock.
269 		if (!pcb->bound.le_prev) {
270 			soisdisconnected(so);
271 			lck_rw_done(&vsockinfo[protocol].bound_lock);
272 			return;
273 		}
274 	}
275 
276 	soisdisconnected(so);
277 
278 	LIST_REMOVE(pcb, bound);
279 	pcb->bound.le_next = NULL;
280 	pcb->bound.le_prev = NULL;
281 
282 	if (!is_locked) {
283 		lck_rw_done(&vsockinfo[protocol].bound_lock);
284 	}
285 }
286 
287 static void
vsock_unbind_pcb(struct vsockpcb * pcb)288 vsock_unbind_pcb(struct vsockpcb *pcb)
289 {
290 	vsock_unbind_pcb_locked(pcb, false);
291 }
292 
293 static struct sockaddr *
vsock_new_sockaddr(struct vsock_address * address)294 vsock_new_sockaddr(struct vsock_address *address)
295 {
296 	if (!address) {
297 		return NULL;
298 	}
299 
300 	struct sockaddr_vm *addr;
301 	addr = (struct sockaddr_vm *)alloc_sockaddr(sizeof(*addr),
302 	    Z_WAITOK | Z_NOFAIL);
303 
304 	addr->svm_family = AF_VSOCK;
305 	addr->svm_port = address->port;
306 	addr->svm_cid = address->cid;
307 
308 	return (struct sockaddr *)addr;
309 }
310 
311 static errno_t
vsock_pcb_send_message(struct vsockpcb * pcb,enum vsock_operation operation,mbuf_t m)312 vsock_pcb_send_message(struct vsockpcb *pcb, enum vsock_operation operation, mbuf_t m)
313 {
314 	if (!pcb) {
315 		if (m != NULL) {
316 			mbuf_freem_list(m);
317 		}
318 		return EINVAL;
319 	}
320 
321 	socket_lock_assert_owned(pcb->so);
322 
323 	errno_t error;
324 
325 	struct vsock_address dst = pcb->remote_address;
326 	if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
327 		if (m != NULL) {
328 			mbuf_freem_list(m);
329 		}
330 		return EINVAL;
331 	}
332 
333 	struct vsock_address src = pcb->local_address;
334 	if (src.cid == VMADDR_CID_ANY) {
335 		uint32_t transport_cid;
336 		struct vsock_transport *transport = pcb->transport;
337 		error = transport->get_cid(transport->provider, &transport_cid);
338 		if (error) {
339 			if (m != NULL) {
340 				mbuf_freem_list(m);
341 			}
342 			return error;
343 		}
344 		src.cid = transport_cid;
345 	}
346 
347 	const uint16_t protocol = pcb->so->so_protocol;
348 	const uint32_t buf_alloc = pcb->so->so_rcv.sb_hiwat;
349 	const uint32_t fwd_cnt = pcb->fwd_cnt;
350 
351 	if (src.cid == dst.cid) {
352 		pcb->last_buf_alloc = buf_alloc;
353 		pcb->last_fwd_cnt = fwd_cnt;
354 
355 		socket_unlock(pcb->so, 0);
356 		error = vsock_put_message(src, dst, operation, buf_alloc, fwd_cnt, m, protocol);
357 		socket_lock(pcb->so, 0);
358 	} else {
359 		struct vsock_transport *transport = pcb->transport;
360 		error = transport->put_message(transport->provider, src, dst, operation, buf_alloc, fwd_cnt, m);
361 
362 		if (!error) {
363 			pcb->last_buf_alloc = buf_alloc;
364 			pcb->last_fwd_cnt = fwd_cnt;
365 		}
366 	}
367 
368 	return error;
369 }
370 
371 static errno_t
vsock_pcb_reset_address(struct vsock_address src,struct vsock_address dst,uint16_t protocol)372 vsock_pcb_reset_address(struct vsock_address src, struct vsock_address dst, uint16_t protocol)
373 {
374 	if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
375 		return EINVAL;
376 	}
377 
378 	errno_t error = 0;
379 	struct vsock_transport *transport = NULL;
380 
381 	if (src.cid == VMADDR_CID_ANY) {
382 		transport = os_atomic_load(&the_vsock_transport[protocol], relaxed);
383 		if (transport == NULL) {
384 			return ENODEV;
385 		}
386 
387 		uint32_t transport_cid;
388 		error = transport->get_cid(transport->provider, &transport_cid);
389 		if (error) {
390 			return error;
391 		}
392 		src.cid = transport_cid;
393 	}
394 
395 	if (src.cid == dst.cid) {
396 		// Reset both sockets.
397 		struct vsockpcb *pcb = vsock_get_matching_pcb(src, dst, protocol);
398 		if (pcb) {
399 			socket_lock_assert_owned(pcb->so);
400 			vsock_unbind_pcb(pcb);
401 			socket_unlock(pcb->so, 1);
402 		}
403 	} else {
404 		if (!transport) {
405 			transport = os_atomic_load(&the_vsock_transport[protocol], relaxed);
406 			if (transport == NULL) {
407 				return ENODEV;
408 			}
409 		}
410 		error = transport->put_message(transport->provider, src, dst, VSOCK_RESET, 0, 0, NULL);
411 	}
412 
413 	return error;
414 }
415 
416 static errno_t
vsock_pcb_safe_reset_address(struct vsockpcb * pcb,struct vsock_address src,struct vsock_address dst,uint16_t protocol)417 vsock_pcb_safe_reset_address(struct vsockpcb *pcb, struct vsock_address src, struct vsock_address dst, uint16_t protocol)
418 {
419 	if (pcb) {
420 		socket_lock_assert_owned(pcb->so);
421 		socket_unlock(pcb->so, 0);
422 	}
423 	errno_t error = vsock_pcb_reset_address(src, dst, protocol);
424 	if (pcb) {
425 		socket_lock(pcb->so, 0);
426 	}
427 	return error;
428 }
429 
430 static errno_t
vsock_pcb_connect(struct vsockpcb * pcb)431 vsock_pcb_connect(struct vsockpcb *pcb)
432 {
433 	return vsock_pcb_send_message(pcb, VSOCK_REQUEST, NULL);
434 }
435 
436 static errno_t
vsock_pcb_respond(struct vsockpcb * pcb)437 vsock_pcb_respond(struct vsockpcb *pcb)
438 {
439 	return vsock_pcb_send_message(pcb, VSOCK_RESPONSE, NULL);
440 }
441 
442 static errno_t
vsock_pcb_send(struct vsockpcb * pcb,mbuf_t m)443 vsock_pcb_send(struct vsockpcb *pcb, mbuf_t m)
444 {
445 	return vsock_pcb_send_message(pcb, VSOCK_PAYLOAD, m);
446 }
447 
448 static errno_t
vsock_pcb_shutdown_send(struct vsockpcb * pcb)449 vsock_pcb_shutdown_send(struct vsockpcb *pcb)
450 {
451 	return vsock_pcb_send_message(pcb, VSOCK_SHUTDOWN_SEND, NULL);
452 }
453 
454 static errno_t
vsock_pcb_reset(struct vsockpcb * pcb)455 vsock_pcb_reset(struct vsockpcb *pcb)
456 {
457 	return vsock_pcb_send_message(pcb, VSOCK_RESET, NULL);
458 }
459 
460 static errno_t
vsock_pcb_credit_update(struct vsockpcb * pcb)461 vsock_pcb_credit_update(struct vsockpcb *pcb)
462 {
463 	return vsock_pcb_send_message(pcb, VSOCK_CREDIT_UPDATE, NULL);
464 }
465 
466 static errno_t
vsock_pcb_credit_update_if_needed(struct vsockpcb * _Nonnull pcb)467 vsock_pcb_credit_update_if_needed(struct vsockpcb *_Nonnull pcb)
468 {
469 	VERIFY(pcb != NULL);
470 
471 	// Sends a credit update if the credit values have changed since the last sent message.
472 	if (pcb->so->so_rcv.sb_hiwat != pcb->last_buf_alloc || pcb->fwd_cnt != pcb->last_fwd_cnt) {
473 		return vsock_pcb_credit_update(pcb);
474 	}
475 	return 0;
476 }
477 
478 static errno_t
vsock_pcb_credit_request(struct vsockpcb * pcb)479 vsock_pcb_credit_request(struct vsockpcb *pcb)
480 {
481 	return vsock_pcb_send_message(pcb, VSOCK_CREDIT_REQUEST, NULL);
482 }
483 
484 static errno_t
vsock_disconnect_pcb_common(struct vsockpcb * pcb,bool is_locked)485 vsock_disconnect_pcb_common(struct vsockpcb *pcb, bool is_locked)
486 {
487 	socket_lock_assert_owned(pcb->so);
488 	vsock_unbind_pcb_locked(pcb, is_locked);
489 	return vsock_pcb_reset(pcb);
490 }
491 
492 static errno_t
vsock_disconnect_pcb_locked(struct vsockpcb * pcb)493 vsock_disconnect_pcb_locked(struct vsockpcb *pcb)
494 {
495 	return vsock_disconnect_pcb_common(pcb, true);
496 }
497 
498 static errno_t
vsock_disconnect_pcb(struct vsockpcb * pcb)499 vsock_disconnect_pcb(struct vsockpcb *pcb)
500 {
501 	return vsock_disconnect_pcb_common(pcb, false);
502 }
503 
504 static errno_t
vsock_sockaddr_vm_validate(struct vsockpcb * pcb,struct sockaddr_vm * addr,struct proc * p)505 vsock_sockaddr_vm_validate(struct vsockpcb *pcb, struct sockaddr_vm *addr, struct proc *p)
506 {
507 	if (!pcb || !pcb->so || !addr) {
508 		return EINVAL;
509 	}
510 
511 	// Validate address length.
512 	if (addr->svm_len < sizeof(struct sockaddr_vm)) {
513 		return EINVAL;
514 	}
515 
516 	// Validate address family.
517 	if (addr->svm_family != AF_UNSPEC && addr->svm_family != AF_VSOCK) {
518 		return EAFNOSUPPORT;
519 	}
520 
521 	// Only stream is supported currently.
522 	if (pcb->so->so_type != SOCK_STREAM) {
523 		return EAFNOSUPPORT;
524 	}
525 
526 	errno_t error = vsock_validate_entitlements(pcb->so->so_protocol, p);
527 	if (error) {
528 		return error;
529 	}
530 
531 	return 0;
532 }
533 
534 /* VSock Receive Handlers */
535 
536 static errno_t
vsock_put_message_connected(struct vsockpcb * _Nonnull pcb,enum vsock_operation op,mbuf_t m)537 vsock_put_message_connected(struct vsockpcb *_Nonnull pcb, enum vsock_operation op, mbuf_t m)
538 {
539 	VERIFY(pcb != NULL);
540 	socket_lock_assert_owned(pcb->so);
541 
542 	errno_t error = 0;
543 
544 	switch (op) {
545 	case VSOCK_SHUTDOWN:
546 		socantsendmore(pcb->so);
547 		socantrcvmore(pcb->so);
548 		break;
549 	case VSOCK_SHUTDOWN_RECEIVE:
550 		socantsendmore(pcb->so);
551 		break;
552 	case VSOCK_SHUTDOWN_SEND:
553 		socantrcvmore(pcb->so);
554 		break;
555 	case VSOCK_PAYLOAD:
556 		// Add data to the receive queue then wakeup any reading threads.
557 		error = !sbappendstream(&pcb->so->so_rcv, m);
558 		if (!error) {
559 			sorwakeup(pcb->so);
560 		}
561 		break;
562 	case VSOCK_RESET:
563 		vsock_unbind_pcb(pcb);
564 		break;
565 	default:
566 		error = ENOTSUP;
567 		break;
568 	}
569 
570 	return error;
571 }
572 
573 static errno_t
vsock_put_message_connecting(struct vsockpcb * _Nonnull pcb,enum vsock_operation op)574 vsock_put_message_connecting(struct vsockpcb *_Nonnull pcb, enum vsock_operation op)
575 {
576 	VERIFY(pcb != NULL);
577 	socket_lock_assert_owned(pcb->so);
578 
579 	errno_t error = 0;
580 
581 	switch (op) {
582 	case VSOCK_RESPONSE:
583 		soisconnected(pcb->so);
584 		break;
585 	case VSOCK_RESET:
586 		pcb->so->so_error = EAGAIN;
587 		error = vsock_disconnect_pcb(pcb);
588 		break;
589 	default:
590 		vsock_disconnect_pcb(pcb);
591 		error = ENOTSUP;
592 		break;
593 	}
594 
595 	return error;
596 }
597 
598 static errno_t
vsock_put_message_listening(struct vsockpcb * _Nonnull pcb,enum vsock_operation op,struct vsock_address src,struct vsock_address dst)599 vsock_put_message_listening(struct vsockpcb *_Nonnull pcb, enum vsock_operation op, struct vsock_address src, struct vsock_address dst)
600 {
601 	VERIFY(pcb != NULL);
602 	socket_lock_assert_owned(pcb->so);
603 
604 	struct sockaddr_vm addr;
605 	struct socket *so2 = NULL;
606 	struct vsockpcb *pcb2 = NULL;
607 
608 	const uint16_t protocol = pcb->so->so_protocol;
609 
610 	errno_t error = 0;
611 
612 	switch (op) {
613 	case VSOCK_REQUEST:
614 		addr = (struct sockaddr_vm) {
615 			.svm_len = sizeof(addr),
616 			.svm_family = AF_VSOCK,
617 			.svm_reserved1 = 0,
618 			.svm_port = pcb->local_address.port,
619 			.svm_cid = pcb->local_address.cid
620 		};
621 		so2 = sonewconn(pcb->so, 0, (struct sockaddr *)&addr);
622 		if (!so2) {
623 			// It is likely that the backlog is full. Deny this request.
624 			vsock_pcb_safe_reset_address(pcb, dst, src, protocol);
625 			error = ECONNREFUSED;
626 			break;
627 		}
628 
629 		pcb2 = sotovsockpcb(so2);
630 		if (!pcb2) {
631 			error = EINVAL;
632 			goto done;
633 		}
634 
635 		error = vsock_bind_address(pcb2, dst, src);
636 		if (error) {
637 			goto done;
638 		}
639 
640 		error = vsock_pcb_respond(pcb2);
641 		if (error) {
642 			goto done;
643 		}
644 
645 		soisconnected(so2);
646 
647 done:
648 		if (error) {
649 			if (pcb2) {
650 				vsock_unbind_pcb(pcb2);
651 			} else {
652 				soisdisconnected(so2);
653 			}
654 			socket_unlock(so2, 1);
655 			vsock_pcb_reset_address(dst, src, protocol);
656 		} else {
657 			socket_unlock(so2, 0);
658 		}
659 		socket_lock(pcb->so, 0);
660 
661 		break;
662 	case VSOCK_RESET:
663 		error = vsock_pcb_safe_reset_address(pcb, dst, src, protocol);
664 		break;
665 	default:
666 		vsock_pcb_safe_reset_address(pcb, dst, src, protocol);
667 		error = ENOTSUP;
668 		break;
669 	}
670 
671 	return error;
672 }
673 
674 /* VSock Transport */
675 
676 errno_t
vsock_add_transport(struct vsock_transport * transport)677 vsock_add_transport(struct vsock_transport *transport)
678 {
679 	if (transport == NULL || transport->provider == NULL || transport->protocol >= VSOCK_PROTO_MAX) {
680 		return EINVAL;
681 	}
682 	if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport[transport->protocol], NULL, transport, acq_rel)) {
683 		return EEXIST;
684 	}
685 	return 0;
686 }
687 
688 errno_t
vsock_remove_transport(struct vsock_transport * transport)689 vsock_remove_transport(struct vsock_transport *transport)
690 {
691 	if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport[transport->protocol], transport, NULL, acq_rel)) {
692 		return ENODEV;
693 	}
694 	return 0;
695 }
696 
697 errno_t
vsock_reset_transport(struct vsock_transport * transport)698 vsock_reset_transport(struct vsock_transport *transport)
699 {
700 	if (transport == NULL) {
701 		return EINVAL;
702 	}
703 
704 	errno_t error = 0;
705 	struct vsockpcb *pcb = NULL;
706 	struct vsockpcb *tmp_pcb = NULL;
707 
708 	lck_rw_lock_exclusive(&vsockinfo[transport->protocol].bound_lock);
709 	LIST_FOREACH_SAFE(pcb, &vsockinfo[transport->protocol].bound, bound, tmp_pcb) {
710 		// Disconnect this transport's sockets. Listen and bind sockets must stay alive.
711 		socket_lock(pcb->so, 1);
712 		if (pcb->transport == transport && pcb->so->so_state & (SS_ISCONNECTED | SS_ISCONNECTING | SS_ISDISCONNECTING)) {
713 			errno_t dc_error = vsock_disconnect_pcb_locked(pcb);
714 			if (dc_error && !error) {
715 				error = dc_error;
716 			}
717 		}
718 		socket_unlock(pcb->so, 1);
719 	}
720 	lck_rw_done(&vsockinfo[transport->protocol].bound_lock);
721 
722 	return error;
723 }
724 
725 errno_t
vsock_put_message(struct vsock_address src,struct vsock_address dst,enum vsock_operation op,uint32_t buf_alloc,uint32_t fwd_cnt,mbuf_t m,uint16_t protocol)726 vsock_put_message(struct vsock_address src, struct vsock_address dst, enum vsock_operation op, uint32_t buf_alloc, uint32_t fwd_cnt, mbuf_t m, uint16_t protocol)
727 {
728 	struct vsockpcb *pcb = vsock_get_matching_pcb(dst, src, protocol);
729 	if (!pcb) {
730 		if (op != VSOCK_RESET) {
731 			vsock_pcb_reset_address(dst, src, protocol);
732 		}
733 		if (m != NULL) {
734 			mbuf_freem_list(m);
735 		}
736 		return EINVAL;
737 	}
738 
739 	socket_lock_assert_owned(pcb->so);
740 
741 	struct socket *so = pcb->so;
742 	errno_t error = 0;
743 
744 	// Check if the peer's buffer has changed. Update our view of the peer's forwarded bytes.
745 	int buffers_changed = (pcb->peer_buf_alloc != buf_alloc) || (pcb->peer_fwd_cnt) != fwd_cnt;
746 	pcb->peer_buf_alloc = buf_alloc;
747 	pcb->peer_fwd_cnt = fwd_cnt;
748 
749 	// Peer's buffer has enough space for the next packet. Notify any threads waiting for space.
750 	if (buffers_changed && vsock_get_peer_space(pcb) >= pcb->waiting_send_size) {
751 		sowwakeup(so);
752 	}
753 
754 	switch (op) {
755 	case VSOCK_CREDIT_REQUEST:
756 		error = vsock_pcb_credit_update(pcb);
757 		break;
758 	case VSOCK_CREDIT_UPDATE:
759 		break;
760 	default:
761 		if (so->so_state & SS_ISCONNECTED) {
762 			error = vsock_put_message_connected(pcb, op, m);
763 			m = NULL;
764 		} else if (so->so_state & SS_ISCONNECTING) {
765 			error = vsock_put_message_connecting(pcb, op);
766 		} else if (so->so_options & SO_ACCEPTCONN) {
767 			error = vsock_put_message_listening(pcb, op, src, dst);
768 		} else {
769 			// Reset the connection for other states such as 'disconnecting'.
770 			error = vsock_disconnect_pcb(pcb);
771 			if (!error) {
772 				error = ENODEV;
773 			}
774 		}
775 		break;
776 	}
777 	socket_unlock(so, 1);
778 
779 	if (m != NULL) {
780 		mbuf_freem_list(m);
781 	}
782 
783 	return error;
784 }
785 
786 /* VSock Sysctl */
787 
788 static int
common_vsock_pcblist(struct sysctl_oid * oidp __unused,void * arg1,int arg2 __unused,struct sysctl_req * _Nonnull req,uint16_t protocol)789 common_vsock_pcblist(struct sysctl_oid *oidp __unused, void *arg1, int arg2 __unused, struct sysctl_req *_Nonnull req, uint16_t protocol)
790 {
791     #pragma unused(oidp,arg2)
792 	VERIFY(req != NULL);
793 
794 	int error;
795 
796 	// Only stream is supported.
797 	if ((intptr_t)arg1 != SOCK_STREAM) {
798 		return EINVAL;
799 	}
800 
801 	// Get the generation count and the count of all vsock sockets.
802 	lck_rw_lock_shared(&vsockinfo[protocol].all_lock);
803 	uint64_t n = vsockinfo[protocol].all_pcb_count;
804 	vsock_gen_t gen_count = vsockinfo[protocol].vsock_gencnt;
805 	lck_rw_done(&vsockinfo[protocol].all_lock);
806 
807 	const size_t xpcb_len = sizeof(struct xvsockpcb);
808 	struct xvsockpgen xvg;
809 
810 	/*
811 	 * The process of preparing the PCB list is too time-consuming and
812 	 * resource-intensive to repeat twice on every request.
813 	 */
814 	if (req->oldptr == USER_ADDR_NULL) {
815 		req->oldidx = (size_t)(2 * sizeof(xvg) + (n + n / 8) * xpcb_len);
816 		return 0;
817 	}
818 
819 	if (req->newptr != USER_ADDR_NULL) {
820 		return EPERM;
821 	}
822 
823 	bzero(&xvg, sizeof(xvg));
824 	xvg.xvg_len = sizeof(xvg);
825 	xvg.xvg_count = n;
826 	xvg.xvg_gen = gen_count;
827 	xvg.xvg_sogen = so_gencnt;
828 	error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
829 	if (error) {
830 		return error;
831 	}
832 
833 	// Return if no sockets exist.
834 	if (n == 0) {
835 		return 0;
836 	}
837 
838 	lck_rw_lock_shared(&vsockinfo[protocol].all_lock);
839 
840 	n = 0;
841 	struct vsockpcb *pcb = NULL;
842 	TAILQ_FOREACH(pcb, &vsockinfo[protocol].all, all) {
843 		// Bail if there is not enough user buffer for this next socket.
844 		if (req->oldlen - req->oldidx - sizeof(xvg) < xpcb_len) {
845 			break;
846 		}
847 
848 		// Populate the socket structure.
849 		socket_lock(pcb->so, 1);
850 		if (pcb->vsock_gencnt <= gen_count) {
851 			struct xvsockpcb xpcb;
852 			bzero(&xpcb, xpcb_len);
853 			xpcb.xv_len = xpcb_len;
854 			xpcb.xv_vsockpp = (uint64_t)VM_KERNEL_ADDRHASH(pcb);
855 			xpcb.xvp_local_cid = pcb->local_address.cid;
856 			xpcb.xvp_local_port = pcb->local_address.port;
857 			xpcb.xvp_remote_cid = pcb->remote_address.cid;
858 			xpcb.xvp_remote_port = pcb->remote_address.port;
859 			xpcb.xvp_rxcnt = pcb->fwd_cnt;
860 			xpcb.xvp_txcnt = pcb->tx_cnt;
861 			xpcb.xvp_peer_rxhiwat = pcb->peer_buf_alloc;
862 			xpcb.xvp_peer_rxcnt = pcb->peer_fwd_cnt;
863 			xpcb.xvp_last_pid = pcb->so->last_pid;
864 			xpcb.xvp_gencnt = pcb->vsock_gencnt;
865 			if (pcb->so) {
866 				sotoxsocket(pcb->so, &xpcb.xv_socket);
867 			}
868 			socket_unlock(pcb->so, 1);
869 
870 			error = SYSCTL_OUT(req, &xpcb, xpcb_len);
871 			if (error != 0) {
872 				break;
873 			}
874 			n++;
875 		} else {
876 			socket_unlock(pcb->so, 1);
877 		}
878 	}
879 
880 	// Update the generation count to match the sockets being returned.
881 	gen_count = vsockinfo[protocol].vsock_gencnt;
882 
883 	lck_rw_done(&vsockinfo[protocol].all_lock);
884 
885 	if (!error) {
886 		/*
887 		 * Give the user an updated idea of our state.
888 		 * If the generation differs from what we told
889 		 * her before, she knows that something happened
890 		 * while we were processing this request, and it
891 		 * might be necessary to retry.
892 		 */
893 		bzero(&xvg, sizeof(xvg));
894 		xvg.xvg_len = sizeof(xvg);
895 		xvg.xvg_count = n;
896 		xvg.xvg_gen = gen_count;
897 		xvg.xvg_sogen = so_gencnt;
898 		error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
899 	}
900 
901 	return error;
902 }
903 
904 static int
905 vsock_pcblist SYSCTL_HANDLER_ARGS
906 {
907 	return common_vsock_pcblist(oidp, arg1, arg2, req, VSOCK_PROTO_STANDARD);
908 }
909 
910 static int
911 vsock_private_pcblist SYSCTL_HANDLER_ARGS
912 {
913 	return common_vsock_pcblist(oidp, arg1, arg2, req, VSOCK_PROTO_PRIVATE);
914 }
915 
#ifdef SYSCTL_DECL
// Standard namespace.
SYSCTL_NODE(_net, OID_AUTO, vsock, CTLFLAG_RW | CTLFLAG_LOCKED, 0, "vsock");
// Read-write tunables backing the buffer sizes reserved for each new standard socket.
SYSCTL_UINT(_net_vsock, OID_AUTO, sendspace, CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_sendspace[VSOCK_PROTO_STANDARD], 0, "Maximum outgoing vsock datagram size");
SYSCTL_UINT(_net_vsock, OID_AUTO, recvspace, CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_recvspace[VSOCK_PROTO_STANDARD], 0, "Maximum incoming vsock datagram size");
// Read-only socket list; arg1 carries SOCK_STREAM, the only type the handler accepts.
SYSCTL_PROC(_net_vsock, OID_AUTO, pcblist,
    CTLTYPE_STRUCT | CTLFLAG_RD | CTLFLAG_LOCKED,
    __unsafe_forge_single(caddr_t, SOCK_STREAM), 0, vsock_pcblist, "S,xvsockpcb",
    "List of active vsock sockets");
// Read-only count of all standard vsock sockets.
SYSCTL_UINT(_net_vsock, OID_AUTO, pcbcount, CTLFLAG_RD | CTLFLAG_LOCKED,
    (u_int *)&vsockinfo[VSOCK_PROTO_STANDARD].all_pcb_count, 0, "");

// Private namespace.
SYSCTL_NODE(_net, OID_AUTO, vsock_private, CTLFLAG_RW | CTLFLAG_LOCKED, 0, "vsock_private");
// Read-only socket list for the private protocol.
SYSCTL_PROC(_net_vsock_private, OID_AUTO, pcblist,
    CTLTYPE_STRUCT | CTLFLAG_RD | CTLFLAG_LOCKED,
    __unsafe_forge_single(caddr_t, SOCK_STREAM), 0, vsock_private_pcblist, "S,xvsockpcb",
    "List of active private vsock sockets");
// Read-only count of all private vsock sockets.
SYSCTL_UINT(_net_vsock_private, OID_AUTO, pcbcount, CTLFLAG_RD | CTLFLAG_LOCKED,
    (u_int *)&vsockinfo[VSOCK_PROTO_PRIVATE].all_pcb_count, 0, "");
#endif
939 
940 /* VSock Protocol */
941 
942 static int
vsock_attach(struct socket * _Nonnull so,int proto,struct proc * p)943 vsock_attach(struct socket *_Nonnull so, int proto, struct proc *p)
944 {
945 	#pragma unused(proto, p)
946 	VERIFY(so != NULL);
947 
948 	const uint16_t protocol = so->so_protocol;
949 	if (protocol >= VSOCK_PROTO_MAX) {
950 		return EINVAL;
951 	}
952 
953 	errno_t error = vsock_validate_entitlements(protocol, p);
954 	if (error) {
955 		return error;
956 	}
957 
958 	const uint32_t send_space = vsock_sendspace[protocol];
959 	const uint32_t receive_space = vsock_recvspace[protocol];
960 	if (send_space == 0 || receive_space == 0) {
961 		return ENOMEM;
962 	}
963 
964 	// Reserve send and receive buffers.
965 	error = soreserve(so, send_space, receive_space);
966 	if (error) {
967 		return error;
968 	}
969 
970 	// Attach should only be run once per socket.
971 	struct vsockpcb *pcb = sotovsockpcb(so);
972 	if (pcb) {
973 		return EINVAL;
974 	}
975 
976 	// Get the transport for this socket.
977 	struct vsock_transport *transport = os_atomic_load(&the_vsock_transport[protocol], relaxed);
978 	if (transport == NULL) {
979 		return ENODEV;
980 	}
981 
982 	// Initialize the vsock protocol control block.
983 	pcb = zalloc_flags(vsockpcb_zone, Z_WAITOK | Z_ZERO | Z_NOFAIL);
984 	pcb->so = so;
985 	pcb->transport = transport;
986 	pcb->local_address = (struct vsock_address) {
987 		.cid = VMADDR_CID_ANY,
988 		.port = VMADDR_PORT_ANY
989 	};
990 	pcb->remote_address = (struct vsock_address) {
991 		.cid = VMADDR_CID_ANY,
992 		.port = VMADDR_PORT_ANY
993 	};
994 	so->so_pcb = pcb;
995 
996 	// Tell the transport that this socket has attached.
997 	error = transport->attach_socket(transport->provider);
998 	if (error) {
999 		zfree(vsockpcb_zone, pcb);
1000 		so->so_pcb = NULL;
1001 		return error;
1002 	}
1003 
1004 	// Add to the list of all vsock sockets.
1005 	lck_rw_lock_exclusive(&vsockinfo[protocol].all_lock);
1006 	TAILQ_INSERT_TAIL(&vsockinfo[protocol].all, pcb, all);
1007 	vsockinfo[protocol].all_pcb_count++;
1008 	pcb->vsock_gencnt = ++vsockinfo[protocol].vsock_gencnt;
1009 	lck_rw_done(&vsockinfo[protocol].all_lock);
1010 
1011 	return 0;
1012 }
1013 
1014 static int
vsock_control(struct socket * so,u_long cmd,caddr_t __sized_by (IOCPARM_LEN (cmd))data,struct ifnet * ifp,struct proc * p)1015 vsock_control(struct socket *so, u_long cmd, caddr_t __sized_by(IOCPARM_LEN(cmd)) data, struct ifnet *ifp, struct proc *p)
1016 {
1017 	#pragma unused(ifp, p)
1018 
1019 	VERIFY(so != NULL);
1020 
1021 	if (cmd != IOCTL_VM_SOCKETS_GET_LOCAL_CID) {
1022 		return EINVAL;
1023 	}
1024 
1025 	if (so == NULL) {
1026 		return EINVAL;
1027 	}
1028 
1029 	struct vsockpcb *pcb = sotovsockpcb(so);
1030 	if (pcb == NULL) {
1031 		return EINVAL;
1032 	}
1033 
1034 	struct vsock_transport *transport = pcb->transport;
1035 	if (transport == NULL) {
1036 		return ENODEV;
1037 	}
1038 
1039 	uint32_t transport_cid;
1040 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
1041 	if (error) {
1042 		return error;
1043 	}
1044 
1045 	memcpy(data, &transport_cid, sizeof(transport_cid));
1046 
1047 	return 0;
1048 }
1049 
1050 static int
vsock_detach(struct socket * so)1051 vsock_detach(struct socket *so)
1052 {
1053 	struct vsockpcb *pcb = sotovsockpcb(so);
1054 	if (pcb == NULL) {
1055 		return EINVAL;
1056 	}
1057 
1058 	vsock_unbind_pcb(pcb);
1059 
1060 	// Tell the transport that this socket has detached.
1061 	struct vsock_transport *transport = pcb->transport;
1062 	errno_t error = transport->detach_socket(transport->provider);
1063 	if (error) {
1064 		return error;
1065 	}
1066 
1067 	const uint16_t protocol = so->so_protocol;
1068 
1069 	// Mark this socket for deallocation.
1070 	so->so_flags |= SOF_PCBCLEARING;
1071 
1072 	// Reorder locks.
1073 	socket_unlock(so, 0);
1074 	lck_rw_lock_exclusive(&vsockinfo[protocol].all_lock);
1075 	socket_lock(so, 0);
1076 
1077 	// Remove from the list of all vsock sockets.
1078 	TAILQ_REMOVE(&vsockinfo[protocol].all, pcb, all);
1079 	pcb->all.tqe_next = NULL;
1080 	pcb->all.tqe_prev = NULL;
1081 	vsockinfo[protocol].all_pcb_count--;
1082 	vsockinfo[protocol].vsock_gencnt++;
1083 	lck_rw_done(&vsockinfo[protocol].all_lock);
1084 
1085 	return 0;
1086 }
1087 
/*
 * pru_abort handler: aborting a vsock socket is the same as detaching it,
 * so this simply forwards to vsock_detach() and returns its result.
 */
static int
vsock_abort(struct socket *so)
{
	const int detach_result = vsock_detach(so);

	return detach_result;
}
1093 
1094 static int
vsock_bind(struct socket * so,struct sockaddr * nam,struct proc * p)1095 vsock_bind(struct socket *so, struct sockaddr *nam, struct proc *p)
1096 {
1097 	#pragma unused(p)
1098 
1099 	struct vsockpcb *pcb = sotovsockpcb(so);
1100 	if (pcb == NULL) {
1101 		return EINVAL;
1102 	}
1103 
1104 	struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
1105 
1106 	errno_t error = vsock_sockaddr_vm_validate(pcb, addr, p);
1107 	if (error) {
1108 		return error;
1109 	}
1110 
1111 	struct vsock_address laddr = (struct vsock_address) {
1112 		.cid = addr->svm_cid,
1113 		.port = addr->svm_port,
1114 	};
1115 
1116 	struct vsock_address raddr = (struct vsock_address) {
1117 		.cid = VMADDR_CID_ANY,
1118 		.port = VMADDR_PORT_ANY,
1119 	};
1120 
1121 	error = vsock_bind_address(pcb, laddr, raddr);
1122 	if (error) {
1123 		return error;
1124 	}
1125 
1126 	return 0;
1127 }
1128 
1129 static int
vsock_listen(struct socket * so,struct proc * p)1130 vsock_listen(struct socket *so, struct proc *p)
1131 {
1132 	#pragma unused(p)
1133 
1134 	struct vsockpcb *pcb = sotovsockpcb(so);
1135 	if (pcb == NULL) {
1136 		return EINVAL;
1137 	}
1138 
1139 	// Only stream is supported currently.
1140 	if (so->so_type != SOCK_STREAM) {
1141 		return EAFNOSUPPORT;
1142 	}
1143 
1144 	struct vsock_address *addr = &pcb->local_address;
1145 
1146 	if (addr->port == VMADDR_CID_ANY) {
1147 		return EFAULT;
1148 	}
1149 
1150 	struct vsock_transport *transport = pcb->transport;
1151 	uint32_t transport_cid;
1152 	errno_t error = transport->get_cid(transport->provider, &transport_cid);
1153 	if (error) {
1154 		return error;
1155 	}
1156 
1157 	// Can listen on the transport's cid or any.
1158 	if (addr->cid != transport_cid && addr->cid != VMADDR_CID_ANY) {
1159 		return EFAULT;
1160 	}
1161 
1162 	return 0;
1163 }
1164 
1165 static int
vsock_accept(struct socket * so,struct sockaddr ** nam)1166 vsock_accept(struct socket *so, struct sockaddr **nam)
1167 {
1168 	struct vsockpcb *pcb = sotovsockpcb(so);
1169 	if (pcb == NULL) {
1170 		return EINVAL;
1171 	}
1172 
1173 	// Do not accept disconnected sockets.
1174 	if (so->so_state & SS_ISDISCONNECTED) {
1175 		return ECONNABORTED;
1176 	}
1177 
1178 	*nam = vsock_new_sockaddr(&pcb->remote_address);
1179 
1180 	return 0;
1181 }
1182 
1183 static int
vsock_connect(struct socket * so,struct sockaddr * nam,struct proc * p)1184 vsock_connect(struct socket *so, struct sockaddr *nam, struct proc *p)
1185 {
1186 	#pragma unused(p)
1187 
1188 	struct vsockpcb *pcb = sotovsockpcb(so);
1189 	if (pcb == NULL) {
1190 		return EINVAL;
1191 	}
1192 
1193 	struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
1194 
1195 	errno_t error = vsock_sockaddr_vm_validate(pcb, addr, p);
1196 	if (error) {
1197 		return error;
1198 	}
1199 
1200 	uint32_t transport_cid;
1201 	struct vsock_transport *transport = pcb->transport;
1202 	error = transport->get_cid(transport->provider, &transport_cid);
1203 	if (error) {
1204 		return error;
1205 	}
1206 
1207 	// Only supporting connections to the host, hypervisor, or self for now.
1208 	if (addr->svm_cid != VMADDR_CID_HOST &&
1209 	    addr->svm_cid != VMADDR_CID_HYPERVISOR &&
1210 	    addr->svm_cid != transport_cid) {
1211 		return EFAULT;
1212 	}
1213 
1214 	soisconnecting(so);
1215 
1216 	// Set the remote and local address.
1217 	struct vsock_address remote_addr = (struct vsock_address) {
1218 		.cid = addr->svm_cid,
1219 		.port = addr->svm_port,
1220 	};
1221 
1222 	struct vsock_address local_addr = (struct vsock_address) {
1223 		.cid = transport_cid,
1224 		.port = VMADDR_PORT_ANY,
1225 	};
1226 
1227 	// Bind to the address.
1228 	error = vsock_bind_address(pcb, local_addr, remote_addr);
1229 	if (error) {
1230 		goto cleanup;
1231 	}
1232 
1233 	// Attempt a connection using the socket's transport.
1234 	error = vsock_pcb_connect(pcb);
1235 	if (error) {
1236 		goto cleanup;
1237 	}
1238 
1239 	if ((so->so_state & SS_ISCONNECTED) == 0) {
1240 		// Don't wait for peer's response if non-blocking.
1241 		if (so->so_state & SS_NBIO) {
1242 			goto done;
1243 		}
1244 
1245 		struct timespec ts = (struct timespec) {
1246 			.tv_sec = so->so_snd.sb_timeo.tv_sec,
1247 			.tv_nsec = so->so_snd.sb_timeo.tv_usec * 1000,
1248 		};
1249 
1250 		lck_mtx_t *mutex_held;
1251 		if (so->so_proto->pr_getlock != NULL) {
1252 			mutex_held = (*so->so_proto->pr_getlock)(so, PR_F_WILLUNLOCK);
1253 		} else {
1254 			mutex_held = so->so_proto->pr_domain->dom_mtx;
1255 		}
1256 
1257 		// Wait until we receive a response to the connect request.
1258 		error = msleep((caddr_t)&so->so_timeo, mutex_held, PSOCK | PCATCH, "vsock_connect", &ts);
1259 		if (error) {
1260 			if (error == EAGAIN) {
1261 				error = ETIMEDOUT;
1262 			}
1263 			goto cleanup;
1264 		}
1265 	}
1266 
1267 cleanup:
1268 	if (so->so_error && !error) {
1269 		error = so->so_error;
1270 		so->so_error = 0;
1271 	}
1272 	if (!error) {
1273 		error = !(so->so_state & SS_ISCONNECTED);
1274 	}
1275 	if (error) {
1276 		vsock_unbind_pcb(pcb);
1277 	}
1278 
1279 done:
1280 	return error;
1281 }
1282 
/*
 * pru_disconnect handler: forwards to vsock_disconnect_pcb() for the socket's
 * pcb and returns its result, or EINVAL when the socket has no pcb.
 */
static int
vsock_disconnect(struct socket *so)
{
	struct vsockpcb *pcb = sotovsockpcb(so);

	if (!pcb) {
		return EINVAL;
	}

	const int disconnect_result = vsock_disconnect_pcb(pcb);
	return disconnect_result;
}
1293 
1294 static int
vsock_sockaddr(struct socket * so,struct sockaddr ** nam)1295 vsock_sockaddr(struct socket *so, struct sockaddr **nam)
1296 {
1297 	struct vsockpcb *pcb = sotovsockpcb(so);
1298 	if (pcb == NULL) {
1299 		return EINVAL;
1300 	}
1301 
1302 	*nam = vsock_new_sockaddr(&pcb->local_address);
1303 
1304 	return 0;
1305 }
1306 
1307 static int
vsock_peeraddr(struct socket * so,struct sockaddr ** nam)1308 vsock_peeraddr(struct socket *so, struct sockaddr **nam)
1309 {
1310 	struct vsockpcb *pcb = sotovsockpcb(so);
1311 	if (pcb == NULL) {
1312 		return EINVAL;
1313 	}
1314 
1315 	*nam = vsock_new_sockaddr(&pcb->remote_address);
1316 
1317 	return 0;
1318 }
1319 
1320 static int
vsock_send(struct socket * so,int flags,struct mbuf * m,struct sockaddr * nam,struct mbuf * control,proc_t p)1321 vsock_send(struct socket *so, int flags, struct mbuf *m, struct sockaddr *nam, struct mbuf *control, proc_t p)
1322 {
1323 	#pragma unused(flags, nam, p)
1324 
1325 	errno_t error = 0;
1326 	struct vsockpcb *pcb = sotovsockpcb(so);
1327 	if (pcb == NULL || m == NULL) {
1328 		error = EINVAL;
1329 		goto out;
1330 	}
1331 
1332 	if (control != NULL) {
1333 		error = EOPNOTSUPP;
1334 		goto out;
1335 	}
1336 
1337 	// Ensure this socket is connected.
1338 	if ((so->so_state & SS_ISCONNECTED) == 0) {
1339 		error = EPERM;
1340 		goto out;
1341 	}
1342 
1343 	// rdar://84098487 (SEED: Web: Virtio-socket sent data lost after 128KB)
1344 	// For writes larger than the default `sosendmaxchain` of 65536, vsock_send() is called multiple times per write().
1345 	// Only the first call to vsock_send() is passed a valid mbuf packet, while subsequent calls are not marked as a packet
1346 	// with a valid length. We should mark all mbufs as a packet and set the correct packet length so that the downstream
1347 	// socket transport layer can correctly generate physical segments.
1348 	if (!(mbuf_flags(m) & MBUF_PKTHDR)) {
1349 		if (!(mbuf_flags(m) & M_EXT)) {
1350 			struct mbuf *header = NULL;
1351 			MGETHDR(header, M_WAITOK, MT_HEADER);
1352 			if (header == NULL) {
1353 				error = ENOBUFS;
1354 				goto out;
1355 			}
1356 			header->m_next = m;
1357 			m = header;
1358 		} else {
1359 			mbuf_setflags(m, mbuf_flags(m) | MBUF_PKTHDR);
1360 		}
1361 
1362 		size_t len = 0;
1363 		struct mbuf *next = m;
1364 		while (next) {
1365 			len += mbuf_len(next);
1366 			next = mbuf_next(next);
1367 		}
1368 		mbuf_pkthdr_setlen(m, len);
1369 	}
1370 
1371 	const size_t len = mbuf_pkthdr_len(m);
1372 	uint32_t free_space = vsock_get_peer_space(pcb);
1373 
1374 	// Ensure the peer has enough space in their receive buffer.
1375 	while (len > free_space) {
1376 		// Record the number of free peer bytes necessary before we can send.
1377 		if (len > pcb->waiting_send_size) {
1378 			pcb->waiting_send_size = len;
1379 		}
1380 
1381 		// Send a credit request.
1382 		error = vsock_pcb_credit_request(pcb);
1383 		if (error) {
1384 			goto out;
1385 		}
1386 
1387 		// Check again in case free space was automatically updated in loopback case.
1388 		free_space = vsock_get_peer_space(pcb);
1389 		if (len <= free_space) {
1390 			pcb->waiting_send_size = 0;
1391 			break;
1392 		}
1393 
1394 		// Bail if this is a non-blocking socket.
1395 		if (so->so_state & SS_NBIO) {
1396 			error = EWOULDBLOCK;
1397 			goto out;
1398 		}
1399 
1400 		// Wait until our peer has enough free space in their receive buffer.
1401 		error = sbwait(&so->so_snd);
1402 		pcb->waiting_send_size = 0;
1403 		if (error) {
1404 			goto out;
1405 		}
1406 
1407 		// Bail if an error occured or we can't send more.
1408 		if (so->so_state & SS_CANTSENDMORE) {
1409 			error = EPIPE;
1410 			goto out;
1411 		} else if (so->so_error) {
1412 			error = so->so_error;
1413 			so->so_error = 0;
1414 			goto out;
1415 		}
1416 
1417 		free_space = vsock_get_peer_space(pcb);
1418 	}
1419 
1420 	// Send a payload over the transport.
1421 	error = vsock_pcb_send(pcb, m);
1422 	if (error) {
1423 		return error;
1424 	}
1425 
1426 	pcb->tx_cnt += len;
1427 
1428 	return 0;
1429 
1430 out:
1431 	if (control != NULL) {
1432 		m_freem(control);
1433 	}
1434 	if (m != NULL) {
1435 		mbuf_freem_list(m);
1436 	}
1437 	return error;
1438 }
1439 
1440 static int
vsock_shutdown(struct socket * so)1441 vsock_shutdown(struct socket *so)
1442 {
1443 	struct vsockpcb *pcb = sotovsockpcb(so);
1444 	if (pcb == NULL) {
1445 		return EINVAL;
1446 	}
1447 
1448 	socantsendmore(so);
1449 
1450 	// Tell peer we will no longer send.
1451 	errno_t error = vsock_pcb_shutdown_send(pcb);
1452 	if (error) {
1453 		return error;
1454 	}
1455 
1456 	return 0;
1457 }
1458 
1459 static int
vsock_soreceive(struct socket * so,struct sockaddr ** psa,struct uio * uio,struct mbuf ** mp0,struct mbuf ** controlp,int * flagsp)1460 vsock_soreceive(struct socket *so, struct sockaddr **psa, struct uio *uio,
1461     struct mbuf **mp0, struct mbuf **controlp, int *flagsp)
1462 {
1463 	struct vsockpcb *pcb = sotovsockpcb(so);
1464 	if (pcb == NULL) {
1465 		return EINVAL;
1466 	}
1467 
1468 	user_ssize_t length = uio_resid(uio);
1469 	int result = soreceive(so, psa, uio, mp0, controlp, flagsp);
1470 	length -= uio_resid(uio);
1471 
1472 	socket_lock(so, 1);
1473 
1474 	pcb->fwd_cnt += length;
1475 
1476 	const uint32_t threshold = VSOCK_MAX_PACKET_SIZE;
1477 
1478 	// Send a credit update if it is possible that the peer will no longer send.
1479 	if ((pcb->fwd_cnt - pcb->last_fwd_cnt + threshold) >= pcb->last_buf_alloc) {
1480 		errno_t error = vsock_pcb_credit_update_if_needed(pcb);
1481 		if (!result && error) {
1482 			result = error;
1483 		}
1484 	}
1485 
1486 	socket_unlock(so, 1);
1487 
1488 	return result;
1489 }
1490 
1491 static struct pr_usrreqs vsock_usrreqs = {
1492 	.pru_abort =            vsock_abort,
1493 	.pru_attach =           vsock_attach,
1494 	.pru_control =          vsock_control,
1495 	.pru_detach =           vsock_detach,
1496 	.pru_bind =             vsock_bind,
1497 	.pru_listen =           vsock_listen,
1498 	.pru_accept =           vsock_accept,
1499 	.pru_connect =          vsock_connect,
1500 	.pru_disconnect =       vsock_disconnect,
1501 	.pru_send =             vsock_send,
1502 	.pru_shutdown =         vsock_shutdown,
1503 	.pru_sockaddr =         vsock_sockaddr,
1504 	.pru_peeraddr =         vsock_peeraddr,
1505 	.pru_sosend =           sosend,
1506 	.pru_soreceive =        vsock_soreceive,
1507 };
1508 
1509 static void
common_vsock_init(struct protosw * pp,struct domain * dp,uint16_t protocol,lck_grp_t * lock_group)1510 common_vsock_init(struct protosw *pp, struct domain *dp, uint16_t protocol, lck_grp_t *lock_group)
1511 {
1512 	#pragma unused(dp)
1513 
1514 	static int vsock_initialized[VSOCK_PROTO_MAX] = {0};
1515 	VERIFY((pp->pr_flags & (PR_INITIALIZED | PR_ATTACHED)) == PR_ATTACHED);
1516 	if (!os_atomic_cmpxchg((volatile int *)&vsock_initialized[protocol], 0, 1, acq_rel)) {
1517 		return;
1518 	}
1519 
1520 	// Setup VSock protocol info struct.
1521 	lck_rw_init(&vsockinfo[protocol].all_lock, lock_group, LCK_ATTR_NULL);
1522 	lck_rw_init(&vsockinfo[protocol].bound_lock, lock_group, LCK_ATTR_NULL);
1523 	lck_mtx_init(&vsockinfo[protocol].port_lock, lock_group, LCK_ATTR_NULL);
1524 	TAILQ_INIT(&vsockinfo[protocol].all);
1525 	LIST_INIT(&vsockinfo[protocol].bound);
1526 	vsockinfo[protocol].last_port = VMADDR_PORT_ANY;
1527 }
1528 
1529 static void
vsock_init(struct protosw * pp,struct domain * dp)1530 vsock_init(struct protosw *pp, struct domain *dp)
1531 {
1532 	static LCK_GRP_DECLARE(vsock_lock_grp, "vsock");
1533 	common_vsock_init(pp, dp, VSOCK_PROTO_STANDARD, &vsock_lock_grp);
1534 }
1535 
1536 static void
vsock_private_init(struct protosw * pp,struct domain * dp)1537 vsock_private_init(struct protosw *pp, struct domain *dp)
1538 {
1539 	static LCK_GRP_DECLARE(vsock_private_lock_grp, "vsock_private");
1540 	common_vsock_init(pp, dp, VSOCK_PROTO_PRIVATE, &vsock_private_lock_grp);
1541 }
1542 
1543 static int
vsock_sofreelastref(struct socket * so,int dealloc)1544 vsock_sofreelastref(struct socket *so, int dealloc)
1545 {
1546 	socket_lock_assert_owned(so);
1547 
1548 	struct vsockpcb *pcb = sotovsockpcb(so);
1549 	if (pcb != NULL) {
1550 		zfree(vsockpcb_zone, pcb);
1551 	}
1552 
1553 	so->so_pcb = NULL;
1554 	sofreelastref(so, dealloc);
1555 
1556 	return 0;
1557 }
1558 
1559 static int
vsock_unlock(struct socket * _Nonnull so,int refcount,void * lr_saved)1560 vsock_unlock(struct socket *_Nonnull so, int refcount, void *lr_saved)
1561 {
1562 	VERIFY(so != NULL);
1563 
1564 	lck_mtx_t *mutex_held = so->so_proto->pr_domain->dom_mtx;
1565 #ifdef MORE_LOCKING_DEBUG
1566 	LCK_MTX_ASSERT(mutex_held, LCK_MTX_ASSERT_OWNED);
1567 #endif
1568 	so->unlock_lr[so->next_unlock_lr] = lr_saved;
1569 	so->next_unlock_lr = (so->next_unlock_lr + 1) % SO_LCKDBG_MAX;
1570 
1571 	if (refcount) {
1572 		if (so->so_usecount <= 0) {
1573 			panic("%s: bad refcount=%d so=%p (%d, %d, %d) "
1574 			    "lrh=%s", __func__, so->so_usecount, so,
1575 			    SOCK_DOM(so), so->so_type,
1576 			    SOCK_PROTO(so), solockhistory_nr(so));
1577 			/* NOTREACHED */
1578 		}
1579 
1580 		so->so_usecount--;
1581 		if (so->so_usecount == 0) {
1582 			vsock_sofreelastref(so, 1);
1583 		}
1584 	}
1585 	lck_mtx_unlock(mutex_held);
1586 
1587 	return 0;
1588 }
1589 
1590 static struct protosw vsocksw[VSOCK_PROTO_MAX] = {
1591 	{
1592 		.pr_type =              SOCK_STREAM,
1593 		.pr_protocol =          VSOCK_PROTO_STANDARD,
1594 		.pr_flags =             PR_CONNREQUIRED | PR_WANTRCVD,
1595 		.pr_init =              vsock_init,
1596 		.pr_unlock =            vsock_unlock,
1597 		.pr_usrreqs =           &vsock_usrreqs,
1598 	},
1599 	{
1600 		.pr_type =              SOCK_STREAM,
1601 		.pr_protocol =          VSOCK_PROTO_PRIVATE,
1602 		.pr_flags =             PR_CONNREQUIRED | PR_WANTRCVD,
1603 		.pr_init =              vsock_private_init,
1604 		.pr_unlock =            vsock_unlock,
1605 		.pr_usrreqs =           &vsock_usrreqs,
1606 	}
1607 };
1608 
1609 static const int vsock_proto_count = (sizeof(vsocksw) / sizeof(struct protosw));
1610 
1611 /* VSock Domain */
1612 
/* The vsock domain once initialised by vsock_dinit(); NULL until then. */
static struct domain *vsock_domain = NULL;
1614 
1615 static void
vsock_dinit(struct domain * _Nonnull dp)1616 vsock_dinit(struct domain *_Nonnull dp)
1617 {
1618 	// The VSock domain is initialized with a singleton pattern.
1619 	VERIFY(dp != NULL);
1620 	VERIFY(!(dp->dom_flags & DOM_INITIALIZED));
1621 	VERIFY(vsock_domain == NULL);
1622 	vsock_domain = dp;
1623 
1624 	const uint32_t default_buffer_size = VSOCK_MAX_PACKET_SIZE * 8;
1625 
1626 	// Add protocols and initialize.
1627 	for (int i = 0; i < vsock_proto_count; i++) {
1628 		vsock_sendspace[i] = default_buffer_size;
1629 		vsock_recvspace[i] = default_buffer_size;
1630 
1631 		net_add_proto((struct protosw *)&vsocksw[i], dp, 1);
1632 	}
1633 }
1634 
1635 struct domain vsockdomain_s = {
1636 	.dom_family =           PF_VSOCK,
1637 	.dom_name =             "vsock",
1638 	.dom_init =             vsock_dinit,
1639 	.dom_maxrtkey =         sizeof(struct sockaddr_vm),
1640 	.dom_protohdrlen =      sizeof(struct sockaddr_vm),
1641 };
1642