xref: /xnu-12377.41.6/tests/vsock_helpers.c (revision bbb1b6f9e71b8cdde6e5cd6f4841f207dee3d828) !
1 /*
2  * Copyright (c) 2024 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <vsock_helpers.h>
30 
/*
 * Open an AF_VSOCK stream socket using the given protocol.
 *
 * When socket() fails with ENODEV, T_SKIP is invoked with the message
 * "no vsock transport available". In every other case the value returned
 * by socket() (which may be negative) is handed back unchanged.
 */
static int
_vsock_new_socket(uint16_t protocol)
{
	int fd = socket(AF_VSOCK, SOCK_STREAM, protocol);
	/* errno is only consulted when socket() actually failed. */
	bool no_transport = (fd < 0) && (errno == ENODEV);

	if (no_transport) {
		T_SKIP("no vsock transport available");
	}
	return fd;
}
40 
41 int
vsock_new_socket(void)42 vsock_new_socket(void)
43 {
44 	int sock = _vsock_new_socket(VSOCK_PROTO_STANDARD);
45 	T_ASSERT_GT(sock, 0, "create new vsock socket");
46 	return sock;
47 }
48 
49 int
vsock_private_new_socket(void)50 vsock_private_new_socket(void)
51 {
52 	int sock = _vsock_new_socket(VSOCK_PROTO_PRIVATE);
53 	T_ASSERT_GT(sock, 0, "create new private vsock socket");
54 	return sock;
55 }
56 
57 uint32_t
vsock_get_local_cid(int socket)58 vsock_get_local_cid(int socket)
59 {
60 	uint32_t cid = 0;
61 	int result = ioctl(socket, IOCTL_VM_SOCKETS_GET_LOCAL_CID, &cid);
62 	T_ASSERT_POSIX_SUCCESS(result, "vsock ioctl cid successful");
63 	T_ASSERT_GT(cid, VMADDR_CID_HOST, "cid is set");
64 	T_ASSERT_NE(cid, VMADDR_CID_ANY, "cid is valid");
65 
66 	return cid;
67 }
68 
69 static int
_vsock_bind(uint32_t cid,uint32_t port,struct sockaddr_vm * addr,int * socket,uint16_t protocol)70 _vsock_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket, uint16_t protocol)
71 {
72 	int sock = _vsock_new_socket(protocol);
73 	if (sock < 0) {
74 		return sock;
75 	}
76 	T_ASSERT_GT(sock, 0, "create new vsock socket");
77 	*socket = sock;
78 
79 	bzero(addr, sizeof(*addr));
80 	addr->svm_port = port;
81 	addr->svm_cid = cid;
82 
83 	return bind(*socket, (struct sockaddr *) addr, sizeof(*addr));
84 }
85 
86 int
vsock_bind(uint32_t cid,uint32_t port,struct sockaddr_vm * addr,int * socket)87 vsock_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket)
88 {
89 	return _vsock_bind(cid, port, addr, socket, VSOCK_PROTO_STANDARD);
90 }
91 
92 int
vsock_private_bind(uint32_t cid,uint32_t port,struct sockaddr_vm * addr,int * socket)93 vsock_private_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket)
94 {
95 	return _vsock_bind(cid, port, addr, socket, VSOCK_PROTO_PRIVATE);
96 }
97 
/*
 * Bind a standard vsock socket to (cid, port) and put it into the
 * listening state with the given backlog.
 *
 * The bind step is asserted to succeed; the return value is the result of
 * listen(). The descriptor is available in `*socket`.
 */
int
vsock_listen(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int backlog, int *socket)
{
	int bind_rc = vsock_bind(cid, port, addr, socket);

	T_ASSERT_POSIX_SUCCESS(bind_rc, "vsock bind");

	return listen(*socket, backlog);
}
105 
/*
 * Bind a private-protocol vsock socket to (cid, port) and put it into the
 * listening state with the given backlog.
 *
 * The bind step is asserted to succeed; the return value is the result of
 * listen(). The descriptor is available in `*socket`.
 */
int
vsock_private_listen(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int backlog, int *socket)
{
	int bind_rc = vsock_private_bind(cid, port, addr, socket);

	T_ASSERT_POSIX_SUCCESS(bind_rc, "private vsock bind");

	return listen(*socket, backlog);
}
113 
114 static int
_vsock_connect(uint32_t cid,uint32_t port,int * socket,uint16_t protocol)115 _vsock_connect(uint32_t cid, uint32_t port, int *socket, uint16_t protocol)
116 {
117 	int sock = _vsock_new_socket(protocol);
118 	if (sock < 0) {
119 		return sock;
120 	}
121 	T_ASSERT_GT(sock, 0, "create new vsock socket");
122 	*socket = sock;
123 
124 	struct sockaddr_vm addr = (struct sockaddr_vm) {
125 		.svm_cid = cid,
126 		.svm_port = port,
127 	};
128 	return connect(*socket, (struct sockaddr *)&addr, sizeof(addr));
129 }
130 
131 int
vsock_connect(uint32_t cid,uint32_t port,int * socket)132 vsock_connect(uint32_t cid, uint32_t port, int *socket)
133 {
134 	return _vsock_connect(cid, port, socket, VSOCK_PROTO_STANDARD);
135 }
136 
137 int
vsock_private_connect(uint32_t cid,uint32_t port,int * socket)138 vsock_private_connect(uint32_t cid, uint32_t port, int *socket)
139 {
140 	return _vsock_connect(cid, port, socket, VSOCK_PROTO_PRIVATE);
141 }
142 
143 struct sockaddr_vm
vsock_getsockname(int socket)144 vsock_getsockname(int socket)
145 {
146 	struct sockaddr_vm addr;
147 	socklen_t length = sizeof(addr);
148 	int result = getsockname(socket, (struct sockaddr *)&addr, &length);
149 	T_ASSERT_POSIX_SUCCESS(result, "vsock getsockname");
150 	T_ASSERT_EQ_INT((int) sizeof(addr), length, "correct address length");
151 	T_ASSERT_GT(addr.svm_port, 0, "bound to non-zero local port");
152 	return addr;
153 }
154 
/*
 * Close `socket` and assert that close() succeeded.
 */
void
vsock_close(int socket)
{
	int rc = close(socket);

	T_ASSERT_POSIX_SUCCESS(rc, "vsock close");
}
161 
162 void
vsock_connect_peers(uint32_t cid,uint32_t port,int backlog,int * socketA,int * socketB)163 vsock_connect_peers(uint32_t cid, uint32_t port, int backlog, int *socketA, int *socketB)
164 {
165 	// Listen.
166 	struct sockaddr_vm addr;
167 	int listen_socket;
168 	int result = vsock_listen(cid, port, &addr, backlog, &listen_socket);
169 	T_ASSERT_POSIX_SUCCESS(result, "vsock listen");
170 
171 	const uint32_t connection_cid = vsock_get_local_cid(listen_socket);
172 
173 	// Connect.
174 	int connect_socket;
175 	result = vsock_connect(connection_cid, addr.svm_port, &connect_socket);
176 	T_ASSERT_POSIX_SUCCESS(result, "vsock connect");
177 
178 	// Accept.
179 	struct sockaddr_vm accepted_addr;
180 	socklen_t addrlen = sizeof(accepted_addr);
181 	int accepted_socket = accept(listen_socket, (struct sockaddr *)&accepted_addr, &addrlen);
182 	T_ASSERT_GT(accepted_socket, 0, "accepted socket");
183 	T_ASSERT_EQ_INT((int) sizeof(accepted_addr), addrlen, "correct address length");
184 	T_ASSERT_EQ_INT(connection_cid, accepted_addr.svm_cid, "same cid");
185 	T_ASSERT_NE_INT(VMADDR_CID_ANY, accepted_addr.svm_port, "some valid port");
186 	T_ASSERT_NE_INT(0, accepted_addr.svm_port, "some non-zero port");
187 
188 	*socketA = connect_socket;
189 	*socketB = accepted_socket;
190 }
191 
/*
 * Send the NUL-terminated string `msg` (without its terminator) on
 * `socket` with a single send() call.
 *
 * Asserts that `msg` is non-NULL and that the number of bytes reported by
 * send() equals strlen(msg).
 */
void
vsock_send(int socket, char *msg)
{
	T_ASSERT_NOTNULL(msg, "send message is not null");

	size_t msg_len = strlen(msg);
	ssize_t written = send(socket, msg, msg_len, 0);

	T_ASSERT_EQ_LONG(msg_len, (unsigned long)written, "sent all bytes");
}
199 
200 void
vsock_disable_sigpipe(int socket)201 vsock_disable_sigpipe(int socket)
202 {
203 	int on = 1;
204 	int result = setsockopt(socket, SOL_SOCKET, SO_NOSIGPIPE, &on, sizeof(on));
205 	T_ASSERT_POSIX_SUCCESS(result, "vsock disable SIGPIPE");
206 }
207 
208 bool
vsock_address_exists(struct xvsockpgen * buffer,struct sockaddr_vm addr)209 vsock_address_exists(struct xvsockpgen *buffer, struct sockaddr_vm addr)
210 {
211 	struct xvsockpgen *xvg = buffer;
212 	struct xvsockpgen *oxvg = buffer;
213 
214 	bool found = false;
215 	for (xvg = (struct xvsockpgen *)((char *)xvg + xvg->xvg_len);
216 	    xvg->xvg_len > sizeof(struct xvsockpgen);
217 	    xvg = (struct xvsockpgen *)((char *)xvg + xvg->xvg_len)) {
218 		struct xvsockpcb *xpcb = (struct xvsockpcb *)xvg;
219 
220 		/* Ignore PCBs which were freed during copyout. */
221 		if (xpcb->xvp_gencnt > oxvg->xvg_gen) {
222 			continue;
223 		}
224 
225 		if (xpcb->xvp_local_cid == addr.svm_cid && xpcb->xvp_remote_cid == VMADDR_CID_ANY &&
226 		    xpcb->xvp_local_port == addr.svm_port && xpcb->xvp_remote_port == VMADDR_PORT_ANY) {
227 			found = true;
228 			break;
229 		}
230 	}
231 
232 	T_ASSERT_NE(xvg, oxvg, "first and last xvsockpgen were returned");
233 
234 	return found;
235 }
236 
237 uint32_t
vsock_get_available_port(void)238 vsock_get_available_port(void)
239 {
240 	int socket;
241 	struct sockaddr_vm addr;
242 	int result = vsock_bind(VMADDR_CID_ANY, VMADDR_PORT_ANY, &addr, &socket);
243 	T_ASSERT_POSIX_SUCCESS(result, "vsock bind to any port");
244 
245 	const struct sockaddr_vm bound_addr = vsock_getsockname(socket);
246 	const uint32_t port = bound_addr.svm_port;
247 	T_ASSERT_NE_INT(port, VMADDR_PORT_ANY, "port is specified");
248 
249 	vsock_close(socket);
250 
251 	return port;
252 }
253 
254 int
vsock_bind_family(sa_family_t family)255 vsock_bind_family(sa_family_t family)
256 {
257 	int socket = vsock_new_socket();
258 	const uint32_t port = vsock_get_available_port();
259 
260 	struct sockaddr_vm addr = (struct sockaddr_vm) {
261 		.svm_family = family,
262 		.svm_cid = VMADDR_CID_ANY,
263 		.svm_port = port,
264 	};
265 
266 	return bind(socket, (struct sockaddr *) &addr, sizeof(addr));
267 }
268