1 /*
2 * Copyright (c) 2024 Apple Inc. All rights reserved.
3 *
4 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. The rights granted to you under the License
10 * may not be used to create, or enable the creation or redistribution of,
11 * unlawful or unlicensed copies of an Apple operating system, or to
12 * circumvent, violate, or enable the circumvention or violation of, any
13 * terms of an Apple operating system software license agreement.
14 *
15 * Please obtain a copy of the License at
16 * http://www.opensource.apple.com/apsl/ and read it before using this file.
17 *
18 * The Original Code and all software distributed under the License are
19 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23 * Please see the License for the specific language governing rights and
24 * limitations under the License.
25 *
26 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27 */
28
29 #include <vsock_helpers.h>
30
/*
 * Create an AF_VSOCK stream socket with the requested protocol.
 *
 * When socket(2) fails with ENODEV the current test is skipped, since that
 * errno is taken to mean no vsock transport is present. Any other failure
 * is handed back as a negative descriptor with errno left as socket(2) set it.
 */
static int
_vsock_new_socket(uint16_t protocol)
{
	const int fd = socket(AF_VSOCK, SOCK_STREAM, protocol);
	const bool transport_missing = (fd < 0) && (errno == ENODEV);

	if (transport_missing) {
		T_SKIP("no vsock transport available");
	}

	return fd;
}
40
41 int
vsock_new_socket(void)42 vsock_new_socket(void)
43 {
44 int sock = _vsock_new_socket(VSOCK_PROTO_STANDARD);
45 T_ASSERT_GT(sock, 0, "create new vsock socket");
46 return sock;
47 }
48
49 int
vsock_private_new_socket(void)50 vsock_private_new_socket(void)
51 {
52 int sock = _vsock_new_socket(VSOCK_PROTO_PRIVATE);
53 T_ASSERT_GT(sock, 0, "create new private vsock socket");
54 return sock;
55 }
56
57 uint32_t
vsock_get_local_cid(int socket)58 vsock_get_local_cid(int socket)
59 {
60 uint32_t cid = 0;
61 int result = ioctl(socket, IOCTL_VM_SOCKETS_GET_LOCAL_CID, &cid);
62 T_ASSERT_POSIX_SUCCESS(result, "vsock ioctl cid successful");
63 T_ASSERT_GT(cid, VMADDR_CID_HOST, "cid is set");
64 T_ASSERT_NE(cid, VMADDR_CID_ANY, "cid is valid");
65
66 return cid;
67 }
68
69 static int
_vsock_bind(uint32_t cid,uint32_t port,struct sockaddr_vm * addr,int * socket,uint16_t protocol)70 _vsock_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket, uint16_t protocol)
71 {
72 int sock = _vsock_new_socket(protocol);
73 if (sock < 0) {
74 return sock;
75 }
76 T_ASSERT_GT(sock, 0, "create new vsock socket");
77 *socket = sock;
78
79 bzero(addr, sizeof(*addr));
80 addr->svm_port = port;
81 addr->svm_cid = cid;
82
83 return bind(*socket, (struct sockaddr *) addr, sizeof(*addr));
84 }
85
86 int
vsock_bind(uint32_t cid,uint32_t port,struct sockaddr_vm * addr,int * socket)87 vsock_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket)
88 {
89 return _vsock_bind(cid, port, addr, socket, VSOCK_PROTO_STANDARD);
90 }
91
92 int
vsock_private_bind(uint32_t cid,uint32_t port,struct sockaddr_vm * addr,int * socket)93 vsock_private_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket)
94 {
95 return _vsock_bind(cid, port, addr, socket, VSOCK_PROTO_PRIVATE);
96 }
97
/*
 * Bind a standard-protocol vsock socket to (cid, port) and mark it as
 * listening with the given backlog.
 *
 * A failed bind fails the test. The listen(2) result is returned.
 */
int
vsock_listen(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int backlog, int *socket)
{
	const int bind_status = vsock_bind(cid, port, addr, socket);
	T_ASSERT_POSIX_SUCCESS(bind_status, "vsock bind");

	const int listening_fd = *socket;
	return listen(listening_fd, backlog);
}
105
/*
 * Bind a private-protocol vsock socket to (cid, port) and mark it as
 * listening with the given backlog.
 *
 * A failed bind fails the test. The listen(2) result is returned.
 */
int
vsock_private_listen(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int backlog, int *socket)
{
	const int bind_status = vsock_private_bind(cid, port, addr, socket);
	T_ASSERT_POSIX_SUCCESS(bind_status, "private vsock bind");

	const int listening_fd = *socket;
	return listen(listening_fd, backlog);
}
113
114 static int
_vsock_connect(uint32_t cid,uint32_t port,int * socket,uint16_t protocol)115 _vsock_connect(uint32_t cid, uint32_t port, int *socket, uint16_t protocol)
116 {
117 int sock = _vsock_new_socket(protocol);
118 if (sock < 0) {
119 return sock;
120 }
121 T_ASSERT_GT(sock, 0, "create new vsock socket");
122 *socket = sock;
123
124 struct sockaddr_vm addr = (struct sockaddr_vm) {
125 .svm_cid = cid,
126 .svm_port = port,
127 };
128 return connect(*socket, (struct sockaddr *)&addr, sizeof(addr));
129 }
130
131 int
vsock_connect(uint32_t cid,uint32_t port,int * socket)132 vsock_connect(uint32_t cid, uint32_t port, int *socket)
133 {
134 return _vsock_connect(cid, port, socket, VSOCK_PROTO_STANDARD);
135 }
136
137 int
vsock_private_connect(uint32_t cid,uint32_t port,int * socket)138 vsock_private_connect(uint32_t cid, uint32_t port, int *socket)
139 {
140 return _vsock_connect(cid, port, socket, VSOCK_PROTO_PRIVATE);
141 }
142
143 struct sockaddr_vm
vsock_getsockname(int socket)144 vsock_getsockname(int socket)
145 {
146 struct sockaddr_vm addr;
147 socklen_t length = sizeof(addr);
148 int result = getsockname(socket, (struct sockaddr *)&addr, &length);
149 T_ASSERT_POSIX_SUCCESS(result, "vsock getsockname");
150 T_ASSERT_EQ_INT((int) sizeof(addr), length, "correct address length");
151 T_ASSERT_GT(addr.svm_port, 0, "bound to non-zero local port");
152 return addr;
153 }
154
/*
 * Close a vsock socket descriptor, failing the test if close(2) fails.
 */
void
vsock_close(int socket)
{
	const int status = close(socket);
	T_ASSERT_POSIX_SUCCESS(status, "vsock close");
}
161
162 void
vsock_connect_peers(uint32_t cid,uint32_t port,int backlog,int * socketA,int * socketB)163 vsock_connect_peers(uint32_t cid, uint32_t port, int backlog, int *socketA, int *socketB)
164 {
165 // Listen.
166 struct sockaddr_vm addr;
167 int listen_socket;
168 int result = vsock_listen(cid, port, &addr, backlog, &listen_socket);
169 T_ASSERT_POSIX_SUCCESS(result, "vsock listen");
170
171 const uint32_t connection_cid = vsock_get_local_cid(listen_socket);
172
173 // Connect.
174 int connect_socket;
175 result = vsock_connect(connection_cid, addr.svm_port, &connect_socket);
176 T_ASSERT_POSIX_SUCCESS(result, "vsock connect");
177
178 // Accept.
179 struct sockaddr_vm accepted_addr;
180 socklen_t addrlen = sizeof(accepted_addr);
181 int accepted_socket = accept(listen_socket, (struct sockaddr *)&accepted_addr, &addrlen);
182 T_ASSERT_GT(accepted_socket, 0, "accepted socket");
183 T_ASSERT_EQ_INT((int) sizeof(accepted_addr), addrlen, "correct address length");
184 T_ASSERT_EQ_INT(connection_cid, accepted_addr.svm_cid, "same cid");
185 T_ASSERT_NE_INT(VMADDR_CID_ANY, accepted_addr.svm_port, "some valid port");
186 T_ASSERT_NE_INT(0, accepted_addr.svm_port, "some non-zero port");
187
188 *socketA = connect_socket;
189 *socketB = accepted_socket;
190 }
191
/*
 * Send the NUL-terminated string `msg`, not including the terminator, on
 * `socket` with a single send(2) call.
 *
 * The test fails if `msg` is NULL, if send(2) fails (errno is reported),
 * or if fewer than strlen(msg) bytes were sent.
 */
void
vsock_send(int socket, char *msg)
{
	T_ASSERT_NOTNULL(msg, "send message is not null");

	const size_t length = strlen(msg);
	const ssize_t sent_bytes = send(socket, msg, length, 0);

	// Check for an error before comparing lengths. Otherwise -1 is cast to
	// ULONG_MAX and shows up as a confusing length mismatch with no errno.
	T_ASSERT_POSIX_SUCCESS(sent_bytes, "vsock send");
	T_ASSERT_EQ_LONG(length, (unsigned long)sent_bytes, "sent all bytes");
}
199
200 void
vsock_disable_sigpipe(int socket)201 vsock_disable_sigpipe(int socket)
202 {
203 int on = 1;
204 int result = setsockopt(socket, SOL_SOCKET, SO_NOSIGPIPE, &on, sizeof(on));
205 T_ASSERT_POSIX_SUCCESS(result, "vsock disable SIGPIPE");
206 }
207
208 bool
vsock_address_exists(struct xvsockpgen * buffer,struct sockaddr_vm addr)209 vsock_address_exists(struct xvsockpgen *buffer, struct sockaddr_vm addr)
210 {
211 struct xvsockpgen *xvg = buffer;
212 struct xvsockpgen *oxvg = buffer;
213
214 bool found = false;
215 for (xvg = (struct xvsockpgen *)((char *)xvg + xvg->xvg_len);
216 xvg->xvg_len > sizeof(struct xvsockpgen);
217 xvg = (struct xvsockpgen *)((char *)xvg + xvg->xvg_len)) {
218 struct xvsockpcb *xpcb = (struct xvsockpcb *)xvg;
219
220 /* Ignore PCBs which were freed during copyout. */
221 if (xpcb->xvp_gencnt > oxvg->xvg_gen) {
222 continue;
223 }
224
225 if (xpcb->xvp_local_cid == addr.svm_cid && xpcb->xvp_remote_cid == VMADDR_CID_ANY &&
226 xpcb->xvp_local_port == addr.svm_port && xpcb->xvp_remote_port == VMADDR_PORT_ANY) {
227 found = true;
228 break;
229 }
230 }
231
232 T_ASSERT_NE(xvg, oxvg, "first and last xvsockpgen were returned");
233
234 return found;
235 }
236
237 uint32_t
vsock_get_available_port(void)238 vsock_get_available_port(void)
239 {
240 int socket;
241 struct sockaddr_vm addr;
242 int result = vsock_bind(VMADDR_CID_ANY, VMADDR_PORT_ANY, &addr, &socket);
243 T_ASSERT_POSIX_SUCCESS(result, "vsock bind to any port");
244
245 const struct sockaddr_vm bound_addr = vsock_getsockname(socket);
246 const uint32_t port = bound_addr.svm_port;
247 T_ASSERT_NE_INT(port, VMADDR_PORT_ANY, "port is specified");
248
249 vsock_close(socket);
250
251 return port;
252 }
253
254 int
vsock_bind_family(sa_family_t family)255 vsock_bind_family(sa_family_t family)
256 {
257 int socket = vsock_new_socket();
258 const uint32_t port = vsock_get_available_port();
259
260 struct sockaddr_vm addr = (struct sockaddr_vm) {
261 .svm_family = family,
262 .svm_cid = VMADDR_CID_ANY,
263 .svm_port = port,
264 };
265
266 return bind(socket, (struct sockaddr *) &addr, sizeof(addr));
267 }
268