1 /*
2 * Copyright (c) 2020 Apple Inc. All rights reserved.
3 *
4 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. The rights granted to you under the License
10 * may not be used to create, or enable the creation or redistribution of,
11 * unlawful or unlicensed copies of an Apple operating system, or to
12 * circumvent, violate, or enable the circumvention or violation of, any
13 * terms of an Apple operating system software license agreement.
14 *
15 * Please obtain a copy of the License at
16 * http://www.opensource.apple.com/apsl/ and read it before using this file.
17 *
18 * The Original Code and all software distributed under the License are
19 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23 * Please see the License for the specific language governing rights and
24 * limitations under the License.
25 *
26 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27 */
28
29 #include <sys/domain.h>
30 #include <sys/socket.h>
31 #include <sys/protosw.h>
32 #include <sys/mcache.h>
33 #include <sys/systm.h>
34 #include <sys/sysctl.h>
35 #include <sys/random.h>
36 #include <sys/mbuf.h>
37 #include <sys/vsock_domain.h>
38 #include <sys/vsock_transport.h>
39 #include <kern/task.h>
40 #include <kern/zalloc.h>
41 #include <kern/locks.h>
42 #include <machine/atomic.h>
43 #include <IOKit/IOBSD.h>
44
45 #define sotovsockpcb(so) ((struct vsockpcb *)(so)->so_pcb)
46
47 #define VSOCK_PORT_RESERVED 1024
48 #define VSOCK_PRIVATE_ENTITLEMENT "com.apple.private.vsock"
49
50 /* VSock Protocol Globals */
51
52 static struct vsock_transport * _Atomic the_vsock_transport[VSOCK_PROTO_MAX];
53 static ZONE_DEFINE_TYPE(vsockpcb_zone, "vsockpcbzone", struct vsockpcb, ZC_NONE);
54 static struct vsockpcbinfo vsockinfo[VSOCK_PROTO_MAX];
55
56 static uint32_t vsock_sendspace[VSOCK_PROTO_MAX];
57 static uint32_t vsock_recvspace[VSOCK_PROTO_MAX];
58
59 /* VSock Private Entitlements */
60
61 static errno_t
vsock_validate_entitlements(uint16_t protocol,struct proc * p)62 vsock_validate_entitlements(uint16_t protocol, struct proc *p)
63 {
64 if (protocol != VSOCK_PROTO_PRIVATE) {
65 return 0;
66 }
67
68 if (!p) {
69 p = current_proc();
70 }
71
72 if (p == kernproc) {
73 // Assume kernel callers are entitled.
74 return 0;
75 }
76
77 if (!IOTaskHasEntitlement(proc_task(p), VSOCK_PRIVATE_ENTITLEMENT)) {
78 return EPERM;
79 }
80
81 return 0;
82 }
83
84 /* VSock PCB Helpers */
85
86 static uint32_t
vsock_get_peer_space(struct vsockpcb * _Nonnull pcb)87 vsock_get_peer_space(struct vsockpcb *_Nonnull pcb)
88 {
89 VERIFY(pcb != NULL);
90 return pcb->peer_buf_alloc - (pcb->tx_cnt - pcb->peer_fwd_cnt);
91 }
92
93 static struct vsockpcb *
vsock_get_matching_pcb(struct vsock_address src,struct vsock_address dst,uint16_t protocol)94 vsock_get_matching_pcb(struct vsock_address src, struct vsock_address dst, uint16_t protocol)
95 {
96 struct vsockpcb *preferred = NULL;
97 struct vsockpcb *match = NULL;
98 struct vsockpcb *pcb = NULL;
99
100 lck_rw_lock_shared(&vsockinfo[protocol].bound_lock);
101 LIST_FOREACH(pcb, &vsockinfo[protocol].bound, bound) {
102 // Source cid and port must match. Only destination port must match. (Allows for a changing CID during migration)
103 socket_lock(pcb->so, 1);
104 if ((pcb->so->so_state & SS_ISCONNECTED || pcb->so->so_state & SS_ISCONNECTING) &&
105 pcb->local_address.cid == src.cid && pcb->local_address.port == src.port &&
106 pcb->remote_address.port == dst.port) {
107 preferred = pcb;
108 break;
109 } else if ((pcb->local_address.cid == src.cid || pcb->local_address.cid == VMADDR_CID_ANY) &&
110 pcb->local_address.port == src.port) {
111 match = pcb;
112 }
113 socket_unlock(pcb->so, 1);
114 }
115 if (!preferred && match) {
116 socket_lock(match->so, 1);
117 preferred = match;
118 }
119 lck_rw_done(&vsockinfo[protocol].bound_lock);
120
121 return preferred;
122 }
123
124 static errno_t
vsock_bind_address_if_free(struct vsockpcb * _Nonnull pcb,uint32_t local_cid,uint32_t local_port,uint32_t remote_cid,uint32_t remote_port)125 vsock_bind_address_if_free(struct vsockpcb *_Nonnull pcb, uint32_t local_cid, uint32_t local_port, uint32_t remote_cid, uint32_t remote_port)
126 {
127 VERIFY(pcb != NULL);
128 socket_lock_assert_owned(pcb->so);
129
130 // Privileged ports.
131 if (local_port != VMADDR_PORT_ANY && local_port < VSOCK_PORT_RESERVED &&
132 current_task() != kernel_task && proc_suser(current_proc()) != 0) {
133 return EACCES;
134 }
135
136 bool taken = false;
137 const bool check_remote = (remote_cid != VMADDR_CID_ANY && remote_port != VMADDR_PORT_ANY);
138 const uint16_t protocol = pcb->so->so_protocol;
139
140 struct vsockpcb *pcb_match = NULL;
141
142 socket_unlock(pcb->so, 0);
143 lck_rw_lock_exclusive(&vsockinfo[protocol].bound_lock);
144 LIST_FOREACH(pcb_match, &vsockinfo[protocol].bound, bound) {
145 socket_lock(pcb_match->so, 1);
146 if (pcb == pcb_match ||
147 (!check_remote && pcb_match->local_address.port == local_port) ||
148 (check_remote && pcb_match->local_address.port == local_port &&
149 pcb_match->remote_address.cid == remote_cid && pcb_match->remote_address.port == remote_port)) {
150 socket_unlock(pcb_match->so, 1);
151 taken = true;
152 break;
153 }
154 socket_unlock(pcb_match->so, 1);
155 }
156 socket_lock(pcb->so, 0);
157 if (!taken) {
158 pcb->local_address = (struct vsock_address) { .cid = local_cid, .port = local_port };
159 pcb->remote_address = (struct vsock_address) { .cid = remote_cid, .port = remote_port };
160 LIST_INSERT_HEAD(&vsockinfo[protocol].bound, pcb, bound);
161 }
162 lck_rw_done(&vsockinfo[protocol].bound_lock);
163
164 return taken ? EADDRINUSE : 0;
165 }
166
167 static errno_t
vsock_bind_address(struct vsockpcb * pcb,struct vsock_address laddr,struct vsock_address raddr)168 vsock_bind_address(struct vsockpcb *pcb, struct vsock_address laddr, struct vsock_address raddr)
169 {
170 if (!pcb) {
171 return EINVAL;
172 }
173
174 socket_lock_assert_owned(pcb->so);
175
176 // Certain CIDs are reserved.
177 if (laddr.cid == VMADDR_CID_HYPERVISOR || laddr.cid == VMADDR_CID_RESERVED || laddr.cid == VMADDR_CID_HOST) {
178 return EADDRNOTAVAIL;
179 }
180
181 // Remote address must be fully specified or not specified at all.
182 if ((raddr.cid == VMADDR_CID_ANY) ^ (raddr.port == VMADDR_PORT_ANY)) {
183 return EINVAL;
184 }
185
186 // Cannot bind if already bound.
187 if (pcb->local_address.port != VMADDR_PORT_ANY) {
188 return EINVAL;
189 }
190
191 uint32_t transport_cid;
192 struct vsock_transport *transport = pcb->transport;
193 errno_t error = transport->get_cid(transport->provider, &transport_cid);
194 if (error) {
195 return error;
196 }
197
198 // Local CID must be this transport's CID or any.
199 if (laddr.cid != transport_cid && laddr.cid != VMADDR_CID_ANY) {
200 return EINVAL;
201 }
202
203 if (laddr.port != VMADDR_PORT_ANY) {
204 error = vsock_bind_address_if_free(pcb, laddr.cid, laddr.port, raddr.cid, raddr.port);
205 } else {
206 const uint16_t protocol = pcb->so->so_protocol;
207
208 socket_unlock(pcb->so, 0);
209 lck_mtx_lock(&vsockinfo[protocol].port_lock);
210 socket_lock(pcb->so, 0);
211
212 const uint32_t first = VSOCK_PORT_RESERVED;
213 const uint32_t last = VMADDR_PORT_ANY - 1;
214 uint32_t count = last - first + 1;
215 uint32_t *last_port = &vsockinfo[protocol].last_port;
216
217 if (pcb->so->so_flags & SOF_BINDRANDOMPORT) {
218 uint32_t random = 0;
219 read_frandom(&random, sizeof(random));
220 *last_port = first + (random % count);
221 }
222
223 do {
224 if (count == 0) {
225 lck_mtx_unlock(&vsockinfo[protocol].port_lock);
226 return EADDRNOTAVAIL;
227 }
228 count--;
229
230 ++*last_port;
231 if (*last_port < first || *last_port > last) {
232 *last_port = first;
233 }
234
235 error = vsock_bind_address_if_free(pcb, laddr.cid, *last_port, raddr.cid, raddr.port);
236 } while (error);
237
238 lck_mtx_unlock(&vsockinfo[protocol].port_lock);
239 }
240
241 return error;
242 }
243
244 static void
vsock_unbind_pcb_locked(struct vsockpcb * pcb,bool is_locked)245 vsock_unbind_pcb_locked(struct vsockpcb *pcb, bool is_locked)
246 {
247 if (!pcb) {
248 return;
249 }
250
251 struct socket *so = pcb->so;
252 socket_lock_assert_owned(so);
253
254 // Bail if disconnect and already unbound.
255 if (so->so_state & SS_ISDISCONNECTED) {
256 assert(pcb->bound.le_next == NULL);
257 assert(pcb->bound.le_prev == NULL);
258 return;
259 }
260
261 const uint16_t protocol = so->so_protocol;
262
263 if (!is_locked) {
264 socket_unlock(so, 0);
265 lck_rw_lock_exclusive(&vsockinfo[protocol].bound_lock);
266 socket_lock(so, 0);
267
268 // Case where some other thread also called unbind() on this socket while waiting to acquire its lock.
269 if (!pcb->bound.le_prev) {
270 soisdisconnected(so);
271 lck_rw_done(&vsockinfo[protocol].bound_lock);
272 return;
273 }
274 }
275
276 soisdisconnected(so);
277
278 LIST_REMOVE(pcb, bound);
279 pcb->bound.le_next = NULL;
280 pcb->bound.le_prev = NULL;
281
282 if (!is_locked) {
283 lck_rw_done(&vsockinfo[protocol].bound_lock);
284 }
285 }
286
/* Unbind pcb, acquiring the protocol's bound_lock internally. */
static void
vsock_unbind_pcb(struct vsockpcb *pcb)
{
	const bool bound_lock_held = false;
	vsock_unbind_pcb_locked(pcb, bound_lock_held);
}
292
293 static struct sockaddr *
vsock_new_sockaddr(struct vsock_address * address)294 vsock_new_sockaddr(struct vsock_address *address)
295 {
296 if (!address) {
297 return NULL;
298 }
299
300 struct sockaddr_vm *addr;
301 addr = (struct sockaddr_vm *)alloc_sockaddr(sizeof(*addr),
302 Z_WAITOK | Z_NOFAIL);
303
304 addr->svm_family = AF_VSOCK;
305 addr->svm_port = address->port;
306 addr->svm_cid = address->cid;
307
308 return (struct sockaddr *)addr;
309 }
310
311 static errno_t
vsock_pcb_send_message(struct vsockpcb * pcb,enum vsock_operation operation,mbuf_t m)312 vsock_pcb_send_message(struct vsockpcb *pcb, enum vsock_operation operation, mbuf_t m)
313 {
314 if (!pcb) {
315 if (m != NULL) {
316 mbuf_freem_list(m);
317 }
318 return EINVAL;
319 }
320
321 socket_lock_assert_owned(pcb->so);
322
323 errno_t error;
324
325 struct vsock_address dst = pcb->remote_address;
326 if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
327 if (m != NULL) {
328 mbuf_freem_list(m);
329 }
330 return EINVAL;
331 }
332
333 struct vsock_address src = pcb->local_address;
334 if (src.cid == VMADDR_CID_ANY) {
335 uint32_t transport_cid;
336 struct vsock_transport *transport = pcb->transport;
337 error = transport->get_cid(transport->provider, &transport_cid);
338 if (error) {
339 if (m != NULL) {
340 mbuf_freem_list(m);
341 }
342 return error;
343 }
344 src.cid = transport_cid;
345 }
346
347 const uint16_t protocol = pcb->so->so_protocol;
348 const uint32_t buf_alloc = pcb->so->so_rcv.sb_hiwat;
349 const uint32_t fwd_cnt = pcb->fwd_cnt;
350
351 if (src.cid == dst.cid) {
352 pcb->last_buf_alloc = buf_alloc;
353 pcb->last_fwd_cnt = fwd_cnt;
354
355 socket_unlock(pcb->so, 0);
356 error = vsock_put_message(src, dst, operation, buf_alloc, fwd_cnt, m, protocol);
357 socket_lock(pcb->so, 0);
358 } else {
359 struct vsock_transport *transport = pcb->transport;
360 error = transport->put_message(transport->provider, src, dst, operation, buf_alloc, fwd_cnt, m);
361
362 if (!error) {
363 pcb->last_buf_alloc = buf_alloc;
364 pcb->last_fwd_cnt = fwd_cnt;
365 }
366 }
367
368 return error;
369 }
370
371 static errno_t
vsock_pcb_reset_address(struct vsock_address src,struct vsock_address dst,uint16_t protocol)372 vsock_pcb_reset_address(struct vsock_address src, struct vsock_address dst, uint16_t protocol)
373 {
374 if (dst.cid == VMADDR_CID_ANY || dst.port == VMADDR_PORT_ANY) {
375 return EINVAL;
376 }
377
378 errno_t error = 0;
379 struct vsock_transport *transport = NULL;
380
381 if (src.cid == VMADDR_CID_ANY) {
382 transport = os_atomic_load(&the_vsock_transport[protocol], relaxed);
383 if (transport == NULL) {
384 return ENODEV;
385 }
386
387 uint32_t transport_cid;
388 error = transport->get_cid(transport->provider, &transport_cid);
389 if (error) {
390 return error;
391 }
392 src.cid = transport_cid;
393 }
394
395 if (src.cid == dst.cid) {
396 // Reset both sockets.
397 struct vsockpcb *pcb = vsock_get_matching_pcb(src, dst, protocol);
398 if (pcb) {
399 socket_lock_assert_owned(pcb->so);
400 vsock_unbind_pcb(pcb);
401 socket_unlock(pcb->so, 1);
402 }
403 } else {
404 if (!transport) {
405 transport = os_atomic_load(&the_vsock_transport[protocol], relaxed);
406 if (transport == NULL) {
407 return ENODEV;
408 }
409 }
410 error = transport->put_message(transport->provider, src, dst, VSOCK_RESET, 0, 0, NULL);
411 }
412
413 return error;
414 }
415
416 static errno_t
vsock_pcb_safe_reset_address(struct vsockpcb * pcb,struct vsock_address src,struct vsock_address dst,uint16_t protocol)417 vsock_pcb_safe_reset_address(struct vsockpcb *pcb, struct vsock_address src, struct vsock_address dst, uint16_t protocol)
418 {
419 if (pcb) {
420 socket_lock_assert_owned(pcb->so);
421 socket_unlock(pcb->so, 0);
422 }
423 errno_t error = vsock_pcb_reset_address(src, dst, protocol);
424 if (pcb) {
425 socket_lock(pcb->so, 0);
426 }
427 return error;
428 }
429
430 static errno_t
vsock_pcb_connect(struct vsockpcb * pcb)431 vsock_pcb_connect(struct vsockpcb *pcb)
432 {
433 return vsock_pcb_send_message(pcb, VSOCK_REQUEST, NULL);
434 }
435
436 static errno_t
vsock_pcb_respond(struct vsockpcb * pcb)437 vsock_pcb_respond(struct vsockpcb *pcb)
438 {
439 return vsock_pcb_send_message(pcb, VSOCK_RESPONSE, NULL);
440 }
441
442 static errno_t
vsock_pcb_send(struct vsockpcb * pcb,mbuf_t m)443 vsock_pcb_send(struct vsockpcb *pcb, mbuf_t m)
444 {
445 return vsock_pcb_send_message(pcb, VSOCK_PAYLOAD, m);
446 }
447
448 static errno_t
vsock_pcb_shutdown_send(struct vsockpcb * pcb)449 vsock_pcb_shutdown_send(struct vsockpcb *pcb)
450 {
451 return vsock_pcb_send_message(pcb, VSOCK_SHUTDOWN_SEND, NULL);
452 }
453
454 static errno_t
vsock_pcb_reset(struct vsockpcb * pcb)455 vsock_pcb_reset(struct vsockpcb *pcb)
456 {
457 return vsock_pcb_send_message(pcb, VSOCK_RESET, NULL);
458 }
459
460 static errno_t
vsock_pcb_credit_update(struct vsockpcb * pcb)461 vsock_pcb_credit_update(struct vsockpcb *pcb)
462 {
463 return vsock_pcb_send_message(pcb, VSOCK_CREDIT_UPDATE, NULL);
464 }
465
466 static errno_t
vsock_pcb_credit_update_if_needed(struct vsockpcb * _Nonnull pcb)467 vsock_pcb_credit_update_if_needed(struct vsockpcb *_Nonnull pcb)
468 {
469 VERIFY(pcb != NULL);
470
471 // Sends a credit update if the credit values have changed since the last sent message.
472 if (pcb->so->so_rcv.sb_hiwat != pcb->last_buf_alloc || pcb->fwd_cnt != pcb->last_fwd_cnt) {
473 return vsock_pcb_credit_update(pcb);
474 }
475 return 0;
476 }
477
478 static errno_t
vsock_pcb_credit_request(struct vsockpcb * pcb)479 vsock_pcb_credit_request(struct vsockpcb *pcb)
480 {
481 return vsock_pcb_send_message(pcb, VSOCK_CREDIT_REQUEST, NULL);
482 }
483
484 static errno_t
vsock_disconnect_pcb_common(struct vsockpcb * pcb,bool is_locked)485 vsock_disconnect_pcb_common(struct vsockpcb *pcb, bool is_locked)
486 {
487 socket_lock_assert_owned(pcb->so);
488 vsock_unbind_pcb_locked(pcb, is_locked);
489 return vsock_pcb_reset(pcb);
490 }
491
492 static errno_t
vsock_disconnect_pcb_locked(struct vsockpcb * pcb)493 vsock_disconnect_pcb_locked(struct vsockpcb *pcb)
494 {
495 return vsock_disconnect_pcb_common(pcb, true);
496 }
497
498 static errno_t
vsock_disconnect_pcb(struct vsockpcb * pcb)499 vsock_disconnect_pcb(struct vsockpcb *pcb)
500 {
501 return vsock_disconnect_pcb_common(pcb, false);
502 }
503
504 static errno_t
vsock_sockaddr_vm_validate(struct vsockpcb * pcb,struct sockaddr_vm * addr,struct proc * p)505 vsock_sockaddr_vm_validate(struct vsockpcb *pcb, struct sockaddr_vm *addr, struct proc *p)
506 {
507 if (!pcb || !pcb->so || !addr) {
508 return EINVAL;
509 }
510
511 // Validate address length.
512 if (addr->svm_len < sizeof(struct sockaddr_vm)) {
513 return EINVAL;
514 }
515
516 // Validate address family.
517 if (addr->svm_family != AF_UNSPEC && addr->svm_family != AF_VSOCK) {
518 return EAFNOSUPPORT;
519 }
520
521 // Only stream is supported currently.
522 if (pcb->so->so_type != SOCK_STREAM) {
523 return EAFNOSUPPORT;
524 }
525
526 errno_t error = vsock_validate_entitlements(pcb->so->so_protocol, p);
527 if (error) {
528 return error;
529 }
530
531 return 0;
532 }
533
534 /* VSock Receive Handlers */
535
536 static errno_t
vsock_put_message_connected(struct vsockpcb * _Nonnull pcb,enum vsock_operation op,mbuf_t m)537 vsock_put_message_connected(struct vsockpcb *_Nonnull pcb, enum vsock_operation op, mbuf_t m)
538 {
539 VERIFY(pcb != NULL);
540 socket_lock_assert_owned(pcb->so);
541
542 errno_t error = 0;
543
544 switch (op) {
545 case VSOCK_SHUTDOWN:
546 socantsendmore(pcb->so);
547 socantrcvmore(pcb->so);
548 break;
549 case VSOCK_SHUTDOWN_RECEIVE:
550 socantsendmore(pcb->so);
551 break;
552 case VSOCK_SHUTDOWN_SEND:
553 socantrcvmore(pcb->so);
554 break;
555 case VSOCK_PAYLOAD:
556 // Add data to the receive queue then wakeup any reading threads.
557 error = !sbappendstream(&pcb->so->so_rcv, m);
558 if (!error) {
559 sorwakeup(pcb->so);
560 }
561 break;
562 case VSOCK_RESET:
563 vsock_unbind_pcb(pcb);
564 break;
565 default:
566 error = ENOTSUP;
567 break;
568 }
569
570 return error;
571 }
572
573 static errno_t
vsock_put_message_connecting(struct vsockpcb * _Nonnull pcb,enum vsock_operation op)574 vsock_put_message_connecting(struct vsockpcb *_Nonnull pcb, enum vsock_operation op)
575 {
576 VERIFY(pcb != NULL);
577 socket_lock_assert_owned(pcb->so);
578
579 errno_t error = 0;
580
581 switch (op) {
582 case VSOCK_RESPONSE:
583 soisconnected(pcb->so);
584 break;
585 case VSOCK_RESET:
586 pcb->so->so_error = EAGAIN;
587 error = vsock_disconnect_pcb(pcb);
588 break;
589 default:
590 vsock_disconnect_pcb(pcb);
591 error = ENOTSUP;
592 break;
593 }
594
595 return error;
596 }
597
/*
 * Handle a message addressed to a listening socket (SO_ACCEPTCONN set).
 *
 * `src` is the sender's address and `dst` is the local address the message
 * was sent to, so replies below go from dst back to src.
 *
 * VSOCK_REQUEST:
 *   - Create a connection socket with sonewconn().
 *   - Bind it to dst, with src as its remote address.
 *   - Send a VSOCK_RESPONSE and mark it connected.
 *   - If sonewconn() returns NULL, refuse with a reset and ECONNREFUSED.
 *   - On any later failure, tear the new socket down and send a reset to src.
 *
 * VSOCK_RESET: returns the result of
 *   vsock_pcb_safe_reset_address(pcb, dst, src, ...).
 *
 * Any other operation: sends a reset the same way and returns ENOTSUP.
 *
 * The caller holds pcb->so's lock; the function re-locks pcb->so before
 * returning.
 */
static errno_t
vsock_put_message_listening(struct vsockpcb *_Nonnull pcb, enum vsock_operation op, struct vsock_address src, struct vsock_address dst)
{
	VERIFY(pcb != NULL);
	socket_lock_assert_owned(pcb->so);

	struct sockaddr_vm addr;
	struct socket *so2 = NULL;
	struct vsockpcb *pcb2 = NULL;

	const uint16_t protocol = pcb->so->so_protocol;

	errno_t error = 0;

	switch (op) {
	case VSOCK_REQUEST:
		// Address handed to sonewconn(): the listener's local CID/port.
		addr = (struct sockaddr_vm) {
			.svm_len = sizeof(addr),
			.svm_family = AF_VSOCK,
			.svm_reserved1 = 0,
			.svm_port = pcb->local_address.port,
			.svm_cid = pcb->local_address.cid
		};
		so2 = sonewconn(pcb->so, 0, (struct sockaddr *)&addr);
		if (!so2) {
			// It is likely that the backlog is full. Deny this request.
			vsock_pcb_safe_reset_address(pcb, dst, src, protocol);
			error = ECONNREFUSED;
			break;
		}

		// NOTE(review): from here on the code unlocks so2, uses the
		// non-"safe" reset helper, and re-locks pcb->so before breaking.
		// That implies sonewconn() returns with so2 locked and the
		// listener unlocked on success — confirm against sonewconn().
		pcb2 = sotovsockpcb(so2);
		if (!pcb2) {
			error = EINVAL;
			goto done;
		}

		// The new socket's local address is dst; its peer is the requester.
		error = vsock_bind_address(pcb2, dst, src);
		if (error) {
			goto done;
		}

		// Accept the request.
		error = vsock_pcb_respond(pcb2);
		if (error) {
			goto done;
		}

		soisconnected(so2);

done:
		if (error) {
			// Tear down the half-created socket and reset the requester.
			if (pcb2) {
				vsock_unbind_pcb(pcb2);
			} else {
				soisdisconnected(so2);
			}
			socket_unlock(so2, 1);
			vsock_pcb_reset_address(dst, src, protocol);
		} else {
			socket_unlock(so2, 0);
		}
		socket_lock(pcb->so, 0);

		break;
	case VSOCK_RESET:
		error = vsock_pcb_safe_reset_address(pcb, dst, src, protocol);
		break;
	default:
		// Unexpected operation on a listener: reset the sender.
		vsock_pcb_safe_reset_address(pcb, dst, src, protocol);
		error = ENOTSUP;
		break;
	}

	return error;
}
673
674 /* VSock Transport */
675
676 errno_t
vsock_add_transport(struct vsock_transport * transport)677 vsock_add_transport(struct vsock_transport *transport)
678 {
679 if (transport == NULL || transport->provider == NULL || transport->protocol >= VSOCK_PROTO_MAX) {
680 return EINVAL;
681 }
682 if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport[transport->protocol], NULL, transport, acq_rel)) {
683 return EEXIST;
684 }
685 return 0;
686 }
687
688 errno_t
vsock_remove_transport(struct vsock_transport * transport)689 vsock_remove_transport(struct vsock_transport *transport)
690 {
691 if (!os_atomic_cmpxchg((void * volatile *)&the_vsock_transport[transport->protocol], transport, NULL, acq_rel)) {
692 return ENODEV;
693 }
694 return 0;
695 }
696
697 errno_t
vsock_reset_transport(struct vsock_transport * transport)698 vsock_reset_transport(struct vsock_transport *transport)
699 {
700 if (transport == NULL) {
701 return EINVAL;
702 }
703
704 errno_t error = 0;
705 struct vsockpcb *pcb = NULL;
706 struct vsockpcb *tmp_pcb = NULL;
707
708 lck_rw_lock_exclusive(&vsockinfo[transport->protocol].bound_lock);
709 LIST_FOREACH_SAFE(pcb, &vsockinfo[transport->protocol].bound, bound, tmp_pcb) {
710 // Disconnect this transport's sockets. Listen and bind sockets must stay alive.
711 socket_lock(pcb->so, 1);
712 if (pcb->transport == transport && pcb->so->so_state & (SS_ISCONNECTED | SS_ISCONNECTING | SS_ISDISCONNECTING)) {
713 errno_t dc_error = vsock_disconnect_pcb_locked(pcb);
714 if (dc_error && !error) {
715 error = dc_error;
716 }
717 }
718 socket_unlock(pcb->so, 1);
719 }
720 lck_rw_done(&vsockinfo[transport->protocol].bound_lock);
721
722 return error;
723 }
724
725 errno_t
vsock_put_message(struct vsock_address src,struct vsock_address dst,enum vsock_operation op,uint32_t buf_alloc,uint32_t fwd_cnt,mbuf_t m,uint16_t protocol)726 vsock_put_message(struct vsock_address src, struct vsock_address dst, enum vsock_operation op, uint32_t buf_alloc, uint32_t fwd_cnt, mbuf_t m, uint16_t protocol)
727 {
728 struct vsockpcb *pcb = vsock_get_matching_pcb(dst, src, protocol);
729 if (!pcb) {
730 if (op != VSOCK_RESET) {
731 vsock_pcb_reset_address(dst, src, protocol);
732 }
733 if (m != NULL) {
734 mbuf_freem_list(m);
735 }
736 return EINVAL;
737 }
738
739 socket_lock_assert_owned(pcb->so);
740
741 struct socket *so = pcb->so;
742 errno_t error = 0;
743
744 // Check if the peer's buffer has changed. Update our view of the peer's forwarded bytes.
745 int buffers_changed = (pcb->peer_buf_alloc != buf_alloc) || (pcb->peer_fwd_cnt) != fwd_cnt;
746 pcb->peer_buf_alloc = buf_alloc;
747 pcb->peer_fwd_cnt = fwd_cnt;
748
749 // Peer's buffer has enough space for the next packet. Notify any threads waiting for space.
750 if (buffers_changed && vsock_get_peer_space(pcb) >= pcb->waiting_send_size) {
751 sowwakeup(so);
752 }
753
754 switch (op) {
755 case VSOCK_CREDIT_REQUEST:
756 error = vsock_pcb_credit_update(pcb);
757 break;
758 case VSOCK_CREDIT_UPDATE:
759 break;
760 default:
761 if (so->so_state & SS_ISCONNECTED) {
762 error = vsock_put_message_connected(pcb, op, m);
763 m = NULL;
764 } else if (so->so_state & SS_ISCONNECTING) {
765 error = vsock_put_message_connecting(pcb, op);
766 } else if (so->so_options & SO_ACCEPTCONN) {
767 error = vsock_put_message_listening(pcb, op, src, dst);
768 } else {
769 // Reset the connection for other states such as 'disconnecting'.
770 error = vsock_disconnect_pcb(pcb);
771 if (!error) {
772 error = ENODEV;
773 }
774 }
775 break;
776 }
777 socket_unlock(so, 1);
778
779 if (m != NULL) {
780 mbuf_freem_list(m);
781 }
782
783 return error;
784 }
785
786 /* VSock Sysctl */
787
788 static int
common_vsock_pcblist(struct sysctl_oid * oidp __unused,void * arg1,int arg2 __unused,struct sysctl_req * _Nonnull req,uint16_t protocol)789 common_vsock_pcblist(struct sysctl_oid *oidp __unused, void *arg1, int arg2 __unused, struct sysctl_req *_Nonnull req, uint16_t protocol)
790 {
791 #pragma unused(oidp,arg2)
792 VERIFY(req != NULL);
793
794 int error;
795
796 // Only stream is supported.
797 if ((intptr_t)arg1 != SOCK_STREAM) {
798 return EINVAL;
799 }
800
801 // Get the generation count and the count of all vsock sockets.
802 lck_rw_lock_shared(&vsockinfo[protocol].all_lock);
803 uint64_t n = vsockinfo[protocol].all_pcb_count;
804 vsock_gen_t gen_count = vsockinfo[protocol].vsock_gencnt;
805 lck_rw_done(&vsockinfo[protocol].all_lock);
806
807 const size_t xpcb_len = sizeof(struct xvsockpcb);
808 struct xvsockpgen xvg;
809
810 /*
811 * The process of preparing the PCB list is too time-consuming and
812 * resource-intensive to repeat twice on every request.
813 */
814 if (req->oldptr == USER_ADDR_NULL) {
815 req->oldidx = (size_t)(2 * sizeof(xvg) + (n + n / 8) * xpcb_len);
816 return 0;
817 }
818
819 if (req->newptr != USER_ADDR_NULL) {
820 return EPERM;
821 }
822
823 bzero(&xvg, sizeof(xvg));
824 xvg.xvg_len = sizeof(xvg);
825 xvg.xvg_count = n;
826 xvg.xvg_gen = gen_count;
827 xvg.xvg_sogen = so_gencnt;
828 error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
829 if (error) {
830 return error;
831 }
832
833 // Return if no sockets exist.
834 if (n == 0) {
835 return 0;
836 }
837
838 lck_rw_lock_shared(&vsockinfo[protocol].all_lock);
839
840 n = 0;
841 struct vsockpcb *pcb = NULL;
842 TAILQ_FOREACH(pcb, &vsockinfo[protocol].all, all) {
843 // Bail if there is not enough user buffer for this next socket.
844 if (req->oldlen - req->oldidx - sizeof(xvg) < xpcb_len) {
845 break;
846 }
847
848 // Populate the socket structure.
849 socket_lock(pcb->so, 1);
850 if (pcb->vsock_gencnt <= gen_count) {
851 struct xvsockpcb xpcb;
852 bzero(&xpcb, xpcb_len);
853 xpcb.xv_len = xpcb_len;
854 xpcb.xv_vsockpp = (uint64_t)VM_KERNEL_ADDRHASH(pcb);
855 xpcb.xvp_local_cid = pcb->local_address.cid;
856 xpcb.xvp_local_port = pcb->local_address.port;
857 xpcb.xvp_remote_cid = pcb->remote_address.cid;
858 xpcb.xvp_remote_port = pcb->remote_address.port;
859 xpcb.xvp_rxcnt = pcb->fwd_cnt;
860 xpcb.xvp_txcnt = pcb->tx_cnt;
861 xpcb.xvp_peer_rxhiwat = pcb->peer_buf_alloc;
862 xpcb.xvp_peer_rxcnt = pcb->peer_fwd_cnt;
863 xpcb.xvp_last_pid = pcb->so->last_pid;
864 xpcb.xvp_gencnt = pcb->vsock_gencnt;
865 if (pcb->so) {
866 sotoxsocket(pcb->so, &xpcb.xv_socket);
867 }
868 socket_unlock(pcb->so, 1);
869
870 error = SYSCTL_OUT(req, &xpcb, xpcb_len);
871 if (error != 0) {
872 break;
873 }
874 n++;
875 } else {
876 socket_unlock(pcb->so, 1);
877 }
878 }
879
880 // Update the generation count to match the sockets being returned.
881 gen_count = vsockinfo[protocol].vsock_gencnt;
882
883 lck_rw_done(&vsockinfo[protocol].all_lock);
884
885 if (!error) {
886 /*
887 * Give the user an updated idea of our state.
888 * If the generation differs from what we told
889 * her before, she knows that something happened
890 * while we were processing this request, and it
891 * might be necessary to retry.
892 */
893 bzero(&xvg, sizeof(xvg));
894 xvg.xvg_len = sizeof(xvg);
895 xvg.xvg_count = n;
896 xvg.xvg_gen = gen_count;
897 xvg.xvg_sogen = so_gencnt;
898 error = SYSCTL_OUT(req, &xvg, sizeof(xvg));
899 }
900
901 return error;
902 }
903
904 static int
905 vsock_pcblist SYSCTL_HANDLER_ARGS
906 {
907 return common_vsock_pcblist(oidp, arg1, arg2, req, VSOCK_PROTO_STANDARD);
908 }
909
910 static int
911 vsock_private_pcblist SYSCTL_HANDLER_ARGS
912 {
913 return common_vsock_pcblist(oidp, arg1, arg2, req, VSOCK_PROTO_PRIVATE);
914 }
915
#ifdef SYSCTL_DECL
// Standard namespace.
// net.vsock.sendspace / recvspace are the VSOCK_PROTO_STANDARD buffer sizes
// that vsock_attach() passes to soreserve().
// NOTE(review): the descriptions say "datagram", but only SOCK_STREAM is
// supported and these are socket buffer sizes — consider rewording.
SYSCTL_NODE(_net, OID_AUTO, vsock, CTLFLAG_RW | CTLFLAG_LOCKED, 0, "vsock");
SYSCTL_UINT(_net_vsock, OID_AUTO, sendspace, CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_sendspace[VSOCK_PROTO_STANDARD], 0, "Maximum outgoing vsock datagram size");
SYSCTL_UINT(_net_vsock, OID_AUTO, recvspace, CTLFLAG_RW | CTLFLAG_LOCKED,
    &vsock_recvspace[VSOCK_PROTO_STANDARD], 0, "Maximum incoming vsock datagram size");
// arg1 = SOCK_STREAM is checked by common_vsock_pcblist().
SYSCTL_PROC(_net_vsock, OID_AUTO, pcblist,
    CTLTYPE_STRUCT | CTLFLAG_RD | CTLFLAG_LOCKED,
    __unsafe_forge_single(caddr_t, SOCK_STREAM), 0, vsock_pcblist, "S,xvsockpcb",
    "List of active vsock sockets");
// NOTE(review): exported through a u_int cast; if all_pcb_count is wider
// than u_int only part of it is visible — confirm its type.
SYSCTL_UINT(_net_vsock, OID_AUTO, pcbcount, CTLFLAG_RD | CTLFLAG_LOCKED,
    (u_int *)&vsockinfo[VSOCK_PROTO_STANDARD].all_pcb_count, 0, "");

// Private namespace.
SYSCTL_NODE(_net, OID_AUTO, vsock_private, CTLFLAG_RW | CTLFLAG_LOCKED, 0, "vsock_private");
SYSCTL_PROC(_net_vsock_private, OID_AUTO, pcblist,
    CTLTYPE_STRUCT | CTLFLAG_RD | CTLFLAG_LOCKED,
    __unsafe_forge_single(caddr_t, SOCK_STREAM), 0, vsock_private_pcblist, "S,xvsockpcb",
    "List of active private vsock sockets");
SYSCTL_UINT(_net_vsock_private, OID_AUTO, pcbcount, CTLFLAG_RD | CTLFLAG_LOCKED,
    (u_int *)&vsockinfo[VSOCK_PROTO_PRIVATE].all_pcb_count, 0, "");
#endif
939
940 /* VSock Protocol */
941
942 static int
vsock_attach(struct socket * _Nonnull so,int proto,struct proc * p)943 vsock_attach(struct socket *_Nonnull so, int proto, struct proc *p)
944 {
945 #pragma unused(proto, p)
946 VERIFY(so != NULL);
947
948 const uint16_t protocol = so->so_protocol;
949 if (protocol >= VSOCK_PROTO_MAX) {
950 return EINVAL;
951 }
952
953 errno_t error = vsock_validate_entitlements(protocol, p);
954 if (error) {
955 return error;
956 }
957
958 const uint32_t send_space = vsock_sendspace[protocol];
959 const uint32_t receive_space = vsock_recvspace[protocol];
960 if (send_space == 0 || receive_space == 0) {
961 return ENOMEM;
962 }
963
964 // Reserve send and receive buffers.
965 error = soreserve(so, send_space, receive_space);
966 if (error) {
967 return error;
968 }
969
970 // Attach should only be run once per socket.
971 struct vsockpcb *pcb = sotovsockpcb(so);
972 if (pcb) {
973 return EINVAL;
974 }
975
976 // Get the transport for this socket.
977 struct vsock_transport *transport = os_atomic_load(&the_vsock_transport[protocol], relaxed);
978 if (transport == NULL) {
979 return ENODEV;
980 }
981
982 // Initialize the vsock protocol control block.
983 pcb = zalloc_flags(vsockpcb_zone, Z_WAITOK | Z_ZERO | Z_NOFAIL);
984 pcb->so = so;
985 pcb->transport = transport;
986 pcb->local_address = (struct vsock_address) {
987 .cid = VMADDR_CID_ANY,
988 .port = VMADDR_PORT_ANY
989 };
990 pcb->remote_address = (struct vsock_address) {
991 .cid = VMADDR_CID_ANY,
992 .port = VMADDR_PORT_ANY
993 };
994 so->so_pcb = pcb;
995
996 // Tell the transport that this socket has attached.
997 error = transport->attach_socket(transport->provider);
998 if (error) {
999 zfree(vsockpcb_zone, pcb);
1000 so->so_pcb = NULL;
1001 return error;
1002 }
1003
1004 // Add to the list of all vsock sockets.
1005 lck_rw_lock_exclusive(&vsockinfo[protocol].all_lock);
1006 TAILQ_INSERT_TAIL(&vsockinfo[protocol].all, pcb, all);
1007 vsockinfo[protocol].all_pcb_count++;
1008 pcb->vsock_gencnt = ++vsockinfo[protocol].vsock_gencnt;
1009 lck_rw_done(&vsockinfo[protocol].all_lock);
1010
1011 return 0;
1012 }
1013
1014 static int
vsock_control(struct socket * so,u_long cmd,caddr_t __sized_by (IOCPARM_LEN (cmd))data,struct ifnet * ifp,struct proc * p)1015 vsock_control(struct socket *so, u_long cmd, caddr_t __sized_by(IOCPARM_LEN(cmd)) data, struct ifnet *ifp, struct proc *p)
1016 {
1017 #pragma unused(ifp, p)
1018
1019 VERIFY(so != NULL);
1020
1021 if (cmd != IOCTL_VM_SOCKETS_GET_LOCAL_CID) {
1022 return EINVAL;
1023 }
1024
1025 if (so == NULL) {
1026 return EINVAL;
1027 }
1028
1029 struct vsockpcb *pcb = sotovsockpcb(so);
1030 if (pcb == NULL) {
1031 return EINVAL;
1032 }
1033
1034 struct vsock_transport *transport = pcb->transport;
1035 if (transport == NULL) {
1036 return ENODEV;
1037 }
1038
1039 uint32_t transport_cid;
1040 errno_t error = transport->get_cid(transport->provider, &transport_cid);
1041 if (error) {
1042 return error;
1043 }
1044
1045 memcpy(data, &transport_cid, sizeof(transport_cid));
1046
1047 return 0;
1048 }
1049
1050 static int
vsock_detach(struct socket * so)1051 vsock_detach(struct socket *so)
1052 {
1053 struct vsockpcb *pcb = sotovsockpcb(so);
1054 if (pcb == NULL) {
1055 return EINVAL;
1056 }
1057
1058 vsock_unbind_pcb(pcb);
1059
1060 // Tell the transport that this socket has detached.
1061 struct vsock_transport *transport = pcb->transport;
1062 errno_t error = transport->detach_socket(transport->provider);
1063 if (error) {
1064 return error;
1065 }
1066
1067 const uint16_t protocol = so->so_protocol;
1068
1069 // Mark this socket for deallocation.
1070 so->so_flags |= SOF_PCBCLEARING;
1071
1072 // Reorder locks.
1073 socket_unlock(so, 0);
1074 lck_rw_lock_exclusive(&vsockinfo[protocol].all_lock);
1075 socket_lock(so, 0);
1076
1077 // Remove from the list of all vsock sockets.
1078 TAILQ_REMOVE(&vsockinfo[protocol].all, pcb, all);
1079 pcb->all.tqe_next = NULL;
1080 pcb->all.tqe_prev = NULL;
1081 vsockinfo[protocol].all_pcb_count--;
1082 vsockinfo[protocol].vsock_gencnt++;
1083 lck_rw_done(&vsockinfo[protocol].all_lock);
1084
1085 return 0;
1086 }
1087
/*
 * pru_abort handler: an aborted vsock socket is torn down exactly like a
 * detached one, so this forwards to vsock_detach() and returns its result.
 */
static int
vsock_abort(struct socket *so)
{
	const int detach_result = vsock_detach(so);
	return detach_result;
}
1093
1094 static int
vsock_bind(struct socket * so,struct sockaddr * nam,struct proc * p)1095 vsock_bind(struct socket *so, struct sockaddr *nam, struct proc *p)
1096 {
1097 #pragma unused(p)
1098
1099 struct vsockpcb *pcb = sotovsockpcb(so);
1100 if (pcb == NULL) {
1101 return EINVAL;
1102 }
1103
1104 struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
1105
1106 errno_t error = vsock_sockaddr_vm_validate(pcb, addr, p);
1107 if (error) {
1108 return error;
1109 }
1110
1111 struct vsock_address laddr = (struct vsock_address) {
1112 .cid = addr->svm_cid,
1113 .port = addr->svm_port,
1114 };
1115
1116 struct vsock_address raddr = (struct vsock_address) {
1117 .cid = VMADDR_CID_ANY,
1118 .port = VMADDR_PORT_ANY,
1119 };
1120
1121 error = vsock_bind_address(pcb, laddr, raddr);
1122 if (error) {
1123 return error;
1124 }
1125
1126 return 0;
1127 }
1128
1129 static int
vsock_listen(struct socket * so,struct proc * p)1130 vsock_listen(struct socket *so, struct proc *p)
1131 {
1132 #pragma unused(p)
1133
1134 struct vsockpcb *pcb = sotovsockpcb(so);
1135 if (pcb == NULL) {
1136 return EINVAL;
1137 }
1138
1139 // Only stream is supported currently.
1140 if (so->so_type != SOCK_STREAM) {
1141 return EAFNOSUPPORT;
1142 }
1143
1144 struct vsock_address *addr = &pcb->local_address;
1145
1146 if (addr->port == VMADDR_CID_ANY) {
1147 return EFAULT;
1148 }
1149
1150 struct vsock_transport *transport = pcb->transport;
1151 uint32_t transport_cid;
1152 errno_t error = transport->get_cid(transport->provider, &transport_cid);
1153 if (error) {
1154 return error;
1155 }
1156
1157 // Can listen on the transport's cid or any.
1158 if (addr->cid != transport_cid && addr->cid != VMADDR_CID_ANY) {
1159 return EFAULT;
1160 }
1161
1162 return 0;
1163 }
1164
1165 static int
vsock_accept(struct socket * so,struct sockaddr ** nam)1166 vsock_accept(struct socket *so, struct sockaddr **nam)
1167 {
1168 struct vsockpcb *pcb = sotovsockpcb(so);
1169 if (pcb == NULL) {
1170 return EINVAL;
1171 }
1172
1173 // Do not accept disconnected sockets.
1174 if (so->so_state & SS_ISDISCONNECTED) {
1175 return ECONNABORTED;
1176 }
1177
1178 *nam = vsock_new_sockaddr(&pcb->remote_address);
1179
1180 return 0;
1181 }
1182
1183 static int
vsock_connect(struct socket * so,struct sockaddr * nam,struct proc * p)1184 vsock_connect(struct socket *so, struct sockaddr *nam, struct proc *p)
1185 {
1186 #pragma unused(p)
1187
1188 struct vsockpcb *pcb = sotovsockpcb(so);
1189 if (pcb == NULL) {
1190 return EINVAL;
1191 }
1192
1193 struct sockaddr_vm *addr = (struct sockaddr_vm *)nam;
1194
1195 errno_t error = vsock_sockaddr_vm_validate(pcb, addr, p);
1196 if (error) {
1197 return error;
1198 }
1199
1200 uint32_t transport_cid;
1201 struct vsock_transport *transport = pcb->transport;
1202 error = transport->get_cid(transport->provider, &transport_cid);
1203 if (error) {
1204 return error;
1205 }
1206
1207 // Only supporting connections to the host, hypervisor, or self for now.
1208 if (addr->svm_cid != VMADDR_CID_HOST &&
1209 addr->svm_cid != VMADDR_CID_HYPERVISOR &&
1210 addr->svm_cid != transport_cid) {
1211 return EFAULT;
1212 }
1213
1214 soisconnecting(so);
1215
1216 // Set the remote and local address.
1217 struct vsock_address remote_addr = (struct vsock_address) {
1218 .cid = addr->svm_cid,
1219 .port = addr->svm_port,
1220 };
1221
1222 struct vsock_address local_addr = (struct vsock_address) {
1223 .cid = transport_cid,
1224 .port = VMADDR_PORT_ANY,
1225 };
1226
1227 // Bind to the address.
1228 error = vsock_bind_address(pcb, local_addr, remote_addr);
1229 if (error) {
1230 goto cleanup;
1231 }
1232
1233 // Attempt a connection using the socket's transport.
1234 error = vsock_pcb_connect(pcb);
1235 if (error) {
1236 goto cleanup;
1237 }
1238
1239 if ((so->so_state & SS_ISCONNECTED) == 0) {
1240 // Don't wait for peer's response if non-blocking.
1241 if (so->so_state & SS_NBIO) {
1242 goto done;
1243 }
1244
1245 struct timespec ts = (struct timespec) {
1246 .tv_sec = so->so_snd.sb_timeo.tv_sec,
1247 .tv_nsec = so->so_snd.sb_timeo.tv_usec * 1000,
1248 };
1249
1250 lck_mtx_t *mutex_held;
1251 if (so->so_proto->pr_getlock != NULL) {
1252 mutex_held = (*so->so_proto->pr_getlock)(so, PR_F_WILLUNLOCK);
1253 } else {
1254 mutex_held = so->so_proto->pr_domain->dom_mtx;
1255 }
1256
1257 // Wait until we receive a response to the connect request.
1258 error = msleep((caddr_t)&so->so_timeo, mutex_held, PSOCK | PCATCH, "vsock_connect", &ts);
1259 if (error) {
1260 if (error == EAGAIN) {
1261 error = ETIMEDOUT;
1262 }
1263 goto cleanup;
1264 }
1265 }
1266
1267 cleanup:
1268 if (so->so_error && !error) {
1269 error = so->so_error;
1270 so->so_error = 0;
1271 }
1272 if (!error) {
1273 error = !(so->so_state & SS_ISCONNECTED);
1274 }
1275 if (error) {
1276 vsock_unbind_pcb(pcb);
1277 }
1278
1279 done:
1280 return error;
1281 }
1282
1283 static int
vsock_disconnect(struct socket * so)1284 vsock_disconnect(struct socket *so)
1285 {
1286 struct vsockpcb *pcb = sotovsockpcb(so);
1287 if (pcb == NULL) {
1288 return EINVAL;
1289 }
1290
1291 return vsock_disconnect_pcb(pcb);
1292 }
1293
1294 static int
vsock_sockaddr(struct socket * so,struct sockaddr ** nam)1295 vsock_sockaddr(struct socket *so, struct sockaddr **nam)
1296 {
1297 struct vsockpcb *pcb = sotovsockpcb(so);
1298 if (pcb == NULL) {
1299 return EINVAL;
1300 }
1301
1302 *nam = vsock_new_sockaddr(&pcb->local_address);
1303
1304 return 0;
1305 }
1306
1307 static int
vsock_peeraddr(struct socket * so,struct sockaddr ** nam)1308 vsock_peeraddr(struct socket *so, struct sockaddr **nam)
1309 {
1310 struct vsockpcb *pcb = sotovsockpcb(so);
1311 if (pcb == NULL) {
1312 return EINVAL;
1313 }
1314
1315 *nam = vsock_new_sockaddr(&pcb->remote_address);
1316
1317 return 0;
1318 }
1319
/*
 * pru_send handler: transmit one mbuf chain `m` to the connected peer.
 *
 * Makes sure `m` carries a packet header with the chain's total length. It
 * then waits, for blocking sockets only, until the peer has advertised enough
 * receive space for the whole chain before handing it to vsock_pcb_send().
 *
 * Every failure that jumps to `out` frees both `control` and `m`. When
 * vsock_pcb_send() fails, the error is returned directly without freeing `m`
 * here. NOTE(review): presumably vsock_pcb_send() owns `m` from that point;
 * confirm against its implementation.
 *
 * Returns 0 on success, or an errno:
 *   EINVAL       no PCB, or no data
 *   EOPNOTSUPP   control data supplied
 *   EPERM        socket not connected
 *   ENOBUFS      packet-header allocation failed
 *   EWOULDBLOCK  non-blocking socket and the peer lacks space
 *   EPIPE        sending was shut down while waiting
 *   otherwise, the error from the credit request, sbwait(), so_error,
 *   or vsock_pcb_send().
 */
static int
vsock_send(struct socket *so, int flags, struct mbuf *m, struct sockaddr *nam, struct mbuf *control, proc_t p)
{
#pragma unused(flags, nam, p)

	errno_t error = 0;
	struct vsockpcb *pcb = sotovsockpcb(so);
	if (pcb == NULL || m == NULL) {
		error = EINVAL;
		goto out;
	}

	// Ancillary (control) data is not supported on vsock.
	if (control != NULL) {
		error = EOPNOTSUPP;
		goto out;
	}

	// Ensure this socket is connected.
	if ((so->so_state & SS_ISCONNECTED) == 0) {
		error = EPERM;
		goto out;
	}

	// rdar://84098487 (SEED: Web: Virtio-socket sent data lost after 128KB)
	// For writes larger than the default `sosendmaxchain` of 65536, vsock_send() is called multiple times per write().
	// Only the first call to vsock_send() is passed a valid mbuf packet, while subsequent calls are not marked as a packet
	// with a valid length. We should mark all mbufs as a packet and set the correct packet length so that the downstream
	// socket transport layer can correctly generate physical segments.
	if (!(mbuf_flags(m) & MBUF_PKTHDR)) {
		if (!(mbuf_flags(m) & M_EXT)) {
			// Not cluster-backed: prepend a freshly allocated header mbuf.
			struct mbuf *header = NULL;
			MGETHDR(header, M_WAITOK, MT_HEADER);
			if (header == NULL) {
				error = ENOBUFS;
				goto out;
			}
			header->m_next = m;
			m = header;
		} else {
			// Cluster-backed (M_EXT): set MBUF_PKTHDR on the mbuf itself.
			mbuf_setflags(m, mbuf_flags(m) | MBUF_PKTHDR);
		}

		// Packet length is the sum of the data lengths along the m_next chain.
		size_t len = 0;
		struct mbuf *next = m;
		while (next) {
			len += mbuf_len(next);
			next = mbuf_next(next);
		}
		mbuf_pkthdr_setlen(m, len);
	}

	const size_t len = mbuf_pkthdr_len(m);
	uint32_t free_space = vsock_get_peer_space(pcb);

	// Ensure the peer has enough space in their receive buffer.
	while (len > free_space) {
		// Record the number of free peer bytes necessary before we can send.
		// (NOTE(review): presumably consulted when peer credit updates arrive.)
		if (len > pcb->waiting_send_size) {
			pcb->waiting_send_size = len;
		}

		// Send a credit request.
		error = vsock_pcb_credit_request(pcb);
		if (error) {
			goto out;
		}

		// Check again in case free space was automatically updated in loopback case.
		free_space = vsock_get_peer_space(pcb);
		if (len <= free_space) {
			pcb->waiting_send_size = 0;
			break;
		}

		// Bail if this is a non-blocking socket.
		if (so->so_state & SS_NBIO) {
			error = EWOULDBLOCK;
			goto out;
		}

		// Wait until our peer has enough free space in their receive buffer.
		error = sbwait(&so->so_snd);
		pcb->waiting_send_size = 0;
		if (error) {
			goto out;
		}

		// Bail if an error occured or we can't send more.
		if (so->so_state & SS_CANTSENDMORE) {
			error = EPIPE;
			goto out;
		} else if (so->so_error) {
			error = so->so_error;
			so->so_error = 0;
			goto out;
		}

		free_space = vsock_get_peer_space(pcb);
	}

	// Send a payload over the transport.
	error = vsock_pcb_send(pcb, m);
	if (error) {
		return error;
	}

	// Account the bytes handed to the transport.
	pcb->tx_cnt += len;

	return 0;

out:
	// Common failure exit: release everything this call was given.
	if (control != NULL) {
		m_freem(control);
	}
	if (m != NULL) {
		mbuf_freem_list(m);
	}
	return error;
}
1439
1440 static int
vsock_shutdown(struct socket * so)1441 vsock_shutdown(struct socket *so)
1442 {
1443 struct vsockpcb *pcb = sotovsockpcb(so);
1444 if (pcb == NULL) {
1445 return EINVAL;
1446 }
1447
1448 socantsendmore(so);
1449
1450 // Tell peer we will no longer send.
1451 errno_t error = vsock_pcb_shutdown_send(pcb);
1452 if (error) {
1453 return error;
1454 }
1455
1456 return 0;
1457 }
1458
1459 static int
vsock_soreceive(struct socket * so,struct sockaddr ** psa,struct uio * uio,struct mbuf ** mp0,struct mbuf ** controlp,int * flagsp)1460 vsock_soreceive(struct socket *so, struct sockaddr **psa, struct uio *uio,
1461 struct mbuf **mp0, struct mbuf **controlp, int *flagsp)
1462 {
1463 struct vsockpcb *pcb = sotovsockpcb(so);
1464 if (pcb == NULL) {
1465 return EINVAL;
1466 }
1467
1468 user_ssize_t length = uio_resid(uio);
1469 int result = soreceive(so, psa, uio, mp0, controlp, flagsp);
1470 length -= uio_resid(uio);
1471
1472 socket_lock(so, 1);
1473
1474 pcb->fwd_cnt += length;
1475
1476 const uint32_t threshold = VSOCK_MAX_PACKET_SIZE;
1477
1478 // Send a credit update if it is possible that the peer will no longer send.
1479 if ((pcb->fwd_cnt - pcb->last_fwd_cnt + threshold) >= pcb->last_buf_alloc) {
1480 errno_t error = vsock_pcb_credit_update_if_needed(pcb);
1481 if (!result && error) {
1482 result = error;
1483 }
1484 }
1485
1486 socket_unlock(so, 1);
1487
1488 return result;
1489 }
1490
1491 static struct pr_usrreqs vsock_usrreqs = {
1492 .pru_abort = vsock_abort,
1493 .pru_attach = vsock_attach,
1494 .pru_control = vsock_control,
1495 .pru_detach = vsock_detach,
1496 .pru_bind = vsock_bind,
1497 .pru_listen = vsock_listen,
1498 .pru_accept = vsock_accept,
1499 .pru_connect = vsock_connect,
1500 .pru_disconnect = vsock_disconnect,
1501 .pru_send = vsock_send,
1502 .pru_shutdown = vsock_shutdown,
1503 .pru_sockaddr = vsock_sockaddr,
1504 .pru_peeraddr = vsock_peeraddr,
1505 .pru_sosend = sosend,
1506 .pru_soreceive = vsock_soreceive,
1507 };
1508
1509 static void
common_vsock_init(struct protosw * pp,struct domain * dp,uint16_t protocol,lck_grp_t * lock_group)1510 common_vsock_init(struct protosw *pp, struct domain *dp, uint16_t protocol, lck_grp_t *lock_group)
1511 {
1512 #pragma unused(dp)
1513
1514 static int vsock_initialized[VSOCK_PROTO_MAX] = {0};
1515 VERIFY((pp->pr_flags & (PR_INITIALIZED | PR_ATTACHED)) == PR_ATTACHED);
1516 if (!os_atomic_cmpxchg((volatile int *)&vsock_initialized[protocol], 0, 1, acq_rel)) {
1517 return;
1518 }
1519
1520 // Setup VSock protocol info struct.
1521 lck_rw_init(&vsockinfo[protocol].all_lock, lock_group, LCK_ATTR_NULL);
1522 lck_rw_init(&vsockinfo[protocol].bound_lock, lock_group, LCK_ATTR_NULL);
1523 lck_mtx_init(&vsockinfo[protocol].port_lock, lock_group, LCK_ATTR_NULL);
1524 TAILQ_INIT(&vsockinfo[protocol].all);
1525 LIST_INIT(&vsockinfo[protocol].bound);
1526 vsockinfo[protocol].last_port = VMADDR_PORT_ANY;
1527 }
1528
/*
 * pr_init for the standard vsock protocol (VSOCK_PROTO_STANDARD). Runs the
 * shared per-protocol setup with a lock group named "vsock".
 */
static void
vsock_init(struct protosw *pp, struct domain *dp)
{
	// Lock group for this protocol's vsockinfo locks.
	static LCK_GRP_DECLARE(vsock_lock_grp, "vsock");
	common_vsock_init(pp, dp, VSOCK_PROTO_STANDARD, &vsock_lock_grp);
}
1535
/*
 * pr_init for the private vsock protocol (VSOCK_PROTO_PRIVATE). Runs the
 * shared per-protocol setup with a lock group named "vsock_private".
 */
static void
vsock_private_init(struct protosw *pp, struct domain *dp)
{
	// Lock group for this protocol's vsockinfo locks.
	static LCK_GRP_DECLARE(vsock_private_lock_grp, "vsock_private");
	common_vsock_init(pp, dp, VSOCK_PROTO_PRIVATE, &vsock_private_lock_grp);
}
1542
1543 static int
vsock_sofreelastref(struct socket * so,int dealloc)1544 vsock_sofreelastref(struct socket *so, int dealloc)
1545 {
1546 socket_lock_assert_owned(so);
1547
1548 struct vsockpcb *pcb = sotovsockpcb(so);
1549 if (pcb != NULL) {
1550 zfree(vsockpcb_zone, pcb);
1551 }
1552
1553 so->so_pcb = NULL;
1554 sofreelastref(so, dealloc);
1555
1556 return 0;
1557 }
1558
/*
 * pr_unlock handler for vsock sockets.
 *
 * Releases the domain mutex (so->so_proto->pr_domain->dom_mtx) and records
 * the caller's return address `lr_saved` in the socket's unlock history
 * ring (unlock_lr / next_unlock_lr).
 *
 * When `refcount` is non-zero, one use count is dropped first. Dropping the
 * last one hands the socket to vsock_sofreelastref() while the mutex is
 * still held. It panics if asked to drop a reference that does not exist.
 *
 * Always returns 0.
 */
static int
vsock_unlock(struct socket *_Nonnull so, int refcount, void *lr_saved)
{
	VERIFY(so != NULL);

	lck_mtx_t *mutex_held = so->so_proto->pr_domain->dom_mtx;
#ifdef MORE_LOCKING_DEBUG
	LCK_MTX_ASSERT(mutex_held, LCK_MTX_ASSERT_OWNED);
#endif
	// Lock-debugging history: remember who unlocked, in a ring of SO_LCKDBG_MAX.
	so->unlock_lr[so->next_unlock_lr] = lr_saved;
	so->next_unlock_lr = (so->next_unlock_lr + 1) % SO_LCKDBG_MAX;

	if (refcount) {
		// Dropping a reference that isn't held means the counts are corrupt.
		if (so->so_usecount <= 0) {
			panic("%s: bad refcount=%d so=%p (%d, %d, %d) "
			    "lrh=%s", __func__, so->so_usecount, so,
			    SOCK_DOM(so), so->so_type,
			    SOCK_PROTO(so), solockhistory_nr(so));
			/* NOTREACHED */
		}

		so->so_usecount--;
		// Last reference gone: tear down while the mutex is still held.
		if (so->so_usecount == 0) {
			vsock_sofreelastref(so, 1);
		}
	}
	lck_mtx_unlock(mutex_held);

	return 0;
}
1589
1590 static struct protosw vsocksw[VSOCK_PROTO_MAX] = {
1591 {
1592 .pr_type = SOCK_STREAM,
1593 .pr_protocol = VSOCK_PROTO_STANDARD,
1594 .pr_flags = PR_CONNREQUIRED | PR_WANTRCVD,
1595 .pr_init = vsock_init,
1596 .pr_unlock = vsock_unlock,
1597 .pr_usrreqs = &vsock_usrreqs,
1598 },
1599 {
1600 .pr_type = SOCK_STREAM,
1601 .pr_protocol = VSOCK_PROTO_PRIVATE,
1602 .pr_flags = PR_CONNREQUIRED | PR_WANTRCVD,
1603 .pr_init = vsock_private_init,
1604 .pr_unlock = vsock_unlock,
1605 .pr_usrreqs = &vsock_usrreqs,
1606 }
1607 };
1608
1609 static const int vsock_proto_count = (sizeof(vsocksw) / sizeof(struct protosw));
1610
1611 /* VSock Domain */
1612
/* The vsock domain, recorded by vsock_dinit(); NULL (static zero-init) until then. */
static struct domain *vsock_domain;
1614
1615 static void
vsock_dinit(struct domain * _Nonnull dp)1616 vsock_dinit(struct domain *_Nonnull dp)
1617 {
1618 // The VSock domain is initialized with a singleton pattern.
1619 VERIFY(dp != NULL);
1620 VERIFY(!(dp->dom_flags & DOM_INITIALIZED));
1621 VERIFY(vsock_domain == NULL);
1622 vsock_domain = dp;
1623
1624 const uint32_t default_buffer_size = VSOCK_MAX_PACKET_SIZE * 8;
1625
1626 // Add protocols and initialize.
1627 for (int i = 0; i < vsock_proto_count; i++) {
1628 vsock_sendspace[i] = default_buffer_size;
1629 vsock_recvspace[i] = default_buffer_size;
1630
1631 net_add_proto((struct protosw *)&vsocksw[i], dp, 1);
1632 }
1633 }
1634
1635 struct domain vsockdomain_s = {
1636 .dom_family = PF_VSOCK,
1637 .dom_name = "vsock",
1638 .dom_init = vsock_dinit,
1639 .dom_maxrtkey = sizeof(struct sockaddr_vm),
1640 .dom_protohdrlen = sizeof(struct sockaddr_vm),
1641 };
1642