xref: /xnu-8019.80.24/bsd/netinet/flow_divert.c (revision a325d9c4a84054e40bbe985afedcb50ab80993ea)
1 /*
2  * Copyright (c) 2012-2021 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <string.h>
30 #include <sys/types.h>
31 #include <sys/syslog.h>
32 #include <sys/queue.h>
33 #include <sys/malloc.h>
34 #include <sys/socket.h>
35 #include <sys/kpi_mbuf.h>
36 #include <sys/mbuf.h>
37 #include <sys/domain.h>
38 #include <sys/protosw.h>
39 #include <sys/socketvar.h>
40 #include <sys/kernel.h>
41 #include <sys/systm.h>
42 #include <sys/kern_control.h>
43 #include <sys/ubc.h>
44 #include <sys/codesign.h>
45 #include <libkern/tree.h>
46 #include <kern/locks.h>
47 #include <kern/debug.h>
48 #include <kern/task.h>
49 #include <mach/task_info.h>
50 #include <net/if_var.h>
51 #include <net/route.h>
52 #include <net/flowhash.h>
53 #include <net/ntstat.h>
54 #include <net/content_filter.h>
55 #include <net/necp.h>
56 #include <netinet/in.h>
57 #include <netinet/in_var.h>
58 #include <netinet/tcp.h>
59 #include <netinet/tcp_var.h>
60 #include <netinet/tcp_fsm.h>
61 #include <netinet/flow_divert.h>
62 #include <netinet/flow_divert_proto.h>
63 #include <netinet6/in6_pcb.h>
64 #include <netinet6/ip6protosw.h>
65 #include <dev/random/randomdev.h>
66 #include <libkern/crypto/sha1.h>
67 #include <libkern/crypto/crypto_internal.h>
68 #include <os/log.h>
69 #include <corecrypto/cc.h>
70 #if CONTENT_FILTER
71 #include <net/content_filter.h>
72 #endif /* CONTENT_FILTER */
73 
/*
 * Per-flow state bits. In this part of the file they are OR-ed into
 * fd_cb->flags (see flow_divert_update_closed_state()).
 */
#define FLOW_DIVERT_CONNECT_STARTED             0x00000001
#define FLOW_DIVERT_READ_CLOSED                 0x00000002
#define FLOW_DIVERT_WRITE_CLOSED                0x00000004
#define FLOW_DIVERT_TUNNEL_RD_CLOSED            0x00000008
#define FLOW_DIVERT_TUNNEL_WR_CLOSED            0x00000010
#define FLOW_DIVERT_HAS_HMAC                    0x00000040
#define FLOW_DIVERT_NOTIFY_ON_RECEIVED          0x00000080
#define FLOW_DIVERT_IMPLICIT_CONNECT            0x00000100
#define FLOW_DIVERT_DID_SET_LOCAL_ADDR          0x00000200
#define FLOW_DIVERT_HAS_TOKEN                   0x00000400
#define FLOW_DIVERT_SHOULD_SET_LOCAL_ADDR       0x00000800
#define FLOW_DIVERT_FLOW_IS_TRANSPARENT         0x00001000

/*
 * Logging helpers. Every message is prefixed with "(<pcb hash>): " and the
 * syslog-style level is mapped to an os_log type. FDLOG0 is the variant for
 * messages without format arguments.
 */
#define FDLOG(level, pcb, format, ...) \
	os_log_with_type(OS_LOG_DEFAULT, flow_divert_syslog_type_to_oslog_type(level), "(%u): " format "\n", (pcb)->hash, __VA_ARGS__)

#define FDLOG0(level, pcb, msg) \
	os_log_with_type(OS_LOG_DEFAULT, flow_divert_syslog_type_to_oslog_type(level), "(%u): " msg "\n", (pcb)->hash)

/*
 * Reference counting. Both macros tolerate a NULL pcb.
 *
 * FDRETAIN is wrapped in do { } while (0) so that it behaves as a single
 * statement; a bare `if` here would capture the `else` of an enclosing
 * `if (...) FDRETAIN(p); else ...`.
 *
 * FDRELEASE destroys the pcb when OSDecrementAtomic() returns 1.
 */
#define FDRETAIN(pcb)                                                   \
	do {                                                            \
	        if ((pcb) != NULL) {                                    \
	                OSIncrementAtomic(&(pcb)->ref_count);           \
	        }                                                       \
	} while (0)

#define FDRELEASE(pcb)                                                          \
	do {                                                                    \
	        if ((pcb) != NULL && 1 == OSDecrementAtomic(&(pcb)->ref_count)) { \
	                flow_divert_pcb_destroy(pcb);                           \
	        }                                                               \
	} while (0)

/* Per-pcb mutex helpers. */
#define FDLOCK(pcb)                             lck_mtx_lock(&(pcb)->mtx)
#define FDUNLOCK(pcb)                           lck_mtx_unlock(&(pcb)->mtx)

#define FD_CTL_SENDBUFF_SIZE                    (128 * 1024)
#define FD_CTL_RCVBUFF_SIZE                     (128 * 1024)

#define GROUP_BIT_CTL_ENQUEUE_BLOCKED           0

/* Control units accepted by flow_divert_pcb_insert() are 1 .. GROUP_COUNT_MAX - 1. */
#define GROUP_COUNT_MAX                         31
#define FLOW_DIVERT_MAX_NAME_SIZE               4096
#define FLOW_DIVERT_MAX_KEY_SIZE                1024
#define FLOW_DIVERT_MAX_TRIE_MEMORY             (1024 * 1024)
113 
/*
 * One node of the signing-identifier trie. `start` and `length` describe a
 * run inside the trie's byte array; `child_map` is an index into the trie's
 * child-map array, or NULL_TRIE_IDX when the node has no child map.
 */
struct flow_divert_trie_node {
	uint16_t start;
	uint16_t length;
	uint16_t child_map;
};

/* Each child map holds one slot per possible byte value. */
#define CHILD_MAP_SIZE          256
/* Sentinel meaning "no node / no child map / allocation failed". */
#define NULL_TRIE_IDX           0xffff

/* Node `i` of trie `t`. */
#define TRIE_NODE(t, i)         ((t)->nodes[(i)])
/* Child slot of node `i` for byte `b`; usable as an lvalue. */
#define TRIE_CHILD(t, i, b)     (*((t)->child_maps + (CHILD_MAP_SIZE * TRIE_NODE(t, i).child_map) + (b)))
/* Byte `i` of the trie's byte array. */
#define TRIE_BYTE(t, i)         ((t)->bytes[(i)])
125 
126 static struct flow_divert_pcb           nil_pcb;
127 
128 static LCK_ATTR_DECLARE(flow_divert_mtx_attr, 0, 0);
129 static LCK_GRP_DECLARE(flow_divert_mtx_grp, FLOW_DIVERT_CONTROL_NAME);
130 static LCK_RW_DECLARE_ATTR(g_flow_divert_group_lck, &flow_divert_mtx_grp,
131     &flow_divert_mtx_attr);
132 
133 static struct flow_divert_group         **g_flow_divert_groups  = NULL;
134 static uint32_t                         g_active_group_count    = 0;
135 
136 static  errno_t                         g_init_result           = 0;
137 
138 static  kern_ctl_ref                    g_flow_divert_kctl_ref  = NULL;
139 
140 static struct protosw                   g_flow_divert_in_protosw;
141 static struct pr_usrreqs                g_flow_divert_in_usrreqs;
142 static struct protosw                   g_flow_divert_in_udp_protosw;
143 static struct pr_usrreqs                g_flow_divert_in_udp_usrreqs;
144 static struct ip6protosw                g_flow_divert_in6_protosw;
145 static struct pr_usrreqs                g_flow_divert_in6_usrreqs;
146 static struct ip6protosw                g_flow_divert_in6_udp_protosw;
147 static struct pr_usrreqs                g_flow_divert_in6_udp_usrreqs;
148 
149 static struct protosw                   *g_tcp_protosw          = NULL;
150 static struct ip6protosw                *g_tcp6_protosw         = NULL;
151 static struct protosw                   *g_udp_protosw          = NULL;
152 static struct ip6protosw                *g_udp6_protosw         = NULL;
153 
154 ZONE_DECLARE(flow_divert_group_zone, "flow_divert_group",
155     sizeof(struct flow_divert_group), ZC_ZFREE_CLEARMEM);
156 ZONE_DECLARE(flow_divert_pcb_zone, "flow_divert_pcb",
157     sizeof(struct flow_divert_pcb), ZC_ZFREE_CLEARMEM);
158 
159 static errno_t
160 flow_divert_dup_addr(sa_family_t family, struct sockaddr *addr, struct sockaddr **dup);
161 
162 static boolean_t
163 flow_divert_is_sockaddr_valid(struct sockaddr *addr);
164 
165 static int
166 flow_divert_append_target_endpoint_tlv(mbuf_t connect_packet, struct sockaddr *toaddr);
167 
168 struct sockaddr *
169 flow_divert_get_buffered_target_address(mbuf_t buffer);
170 
171 static void
172 flow_divert_disconnect_socket(struct socket *so, bool is_connected);
173 
174 static inline uint8_t
flow_divert_syslog_type_to_oslog_type(int syslog_type)175 flow_divert_syslog_type_to_oslog_type(int syslog_type)
176 {
177 	switch (syslog_type) {
178 	case LOG_ERR: return OS_LOG_TYPE_ERROR;
179 	case LOG_INFO: return OS_LOG_TYPE_INFO;
180 	case LOG_DEBUG: return OS_LOG_TYPE_DEBUG;
181 	default: return OS_LOG_TYPE_DEFAULT;
182 	}
183 }
184 
185 static inline int
flow_divert_pcb_cmp(const struct flow_divert_pcb * pcb_a,const struct flow_divert_pcb * pcb_b)186 flow_divert_pcb_cmp(const struct flow_divert_pcb *pcb_a, const struct flow_divert_pcb *pcb_b)
187 {
188 	return memcmp(&pcb_a->hash, &pcb_b->hash, sizeof(pcb_a->hash));
189 }
190 
191 RB_PROTOTYPE(fd_pcb_tree, flow_divert_pcb, rb_link, flow_divert_pcb_cmp);
192 RB_GENERATE(fd_pcb_tree, flow_divert_pcb, rb_link, flow_divert_pcb_cmp);
193 
194 static const char *
flow_divert_packet_type2str(uint8_t packet_type)195 flow_divert_packet_type2str(uint8_t packet_type)
196 {
197 	switch (packet_type) {
198 	case FLOW_DIVERT_PKT_CONNECT:
199 		return "connect";
200 	case FLOW_DIVERT_PKT_CONNECT_RESULT:
201 		return "connect result";
202 	case FLOW_DIVERT_PKT_DATA:
203 		return "data";
204 	case FLOW_DIVERT_PKT_CLOSE:
205 		return "close";
206 	case FLOW_DIVERT_PKT_READ_NOTIFY:
207 		return "read notification";
208 	case FLOW_DIVERT_PKT_PROPERTIES_UPDATE:
209 		return "properties update";
210 	case FLOW_DIVERT_PKT_APP_MAP_CREATE:
211 		return "app map create";
212 	default:
213 		return "unknown";
214 	}
215 }
216 
217 static struct flow_divert_pcb *
flow_divert_pcb_lookup(uint32_t hash,struct flow_divert_group * group)218 flow_divert_pcb_lookup(uint32_t hash, struct flow_divert_group *group)
219 {
220 	struct flow_divert_pcb  key_item;
221 	struct flow_divert_pcb  *fd_cb          = NULL;
222 
223 	key_item.hash = hash;
224 
225 	lck_rw_lock_shared(&group->lck);
226 	fd_cb = RB_FIND(fd_pcb_tree, &group->pcb_tree, &key_item);
227 	FDRETAIN(fd_cb);
228 	lck_rw_done(&group->lck);
229 
230 	return fd_cb;
231 }
232 
233 static errno_t
flow_divert_pcb_insert(struct flow_divert_pcb * fd_cb,uint32_t ctl_unit)234 flow_divert_pcb_insert(struct flow_divert_pcb *fd_cb, uint32_t ctl_unit)
235 {
236 	errno_t                                                 error                                           = 0;
237 	struct                                          flow_divert_pcb *exist          = NULL;
238 	struct flow_divert_group        *group;
239 	static uint32_t                         g_nextkey                                       = 1;
240 	static uint32_t                         g_hash_seed                                     = 0;
241 	int                                                     try_count                                       = 0;
242 
243 	if (ctl_unit == 0 || ctl_unit >= GROUP_COUNT_MAX) {
244 		return EINVAL;
245 	}
246 
247 	socket_unlock(fd_cb->so, 0);
248 	lck_rw_lock_shared(&g_flow_divert_group_lck);
249 
250 	if (g_flow_divert_groups == NULL || g_active_group_count == 0) {
251 		FDLOG0(LOG_ERR, &nil_pcb, "No active groups, flow divert cannot be used for this socket");
252 		error = ENETUNREACH;
253 		goto done;
254 	}
255 
256 	group = g_flow_divert_groups[ctl_unit];
257 	if (group == NULL) {
258 		FDLOG(LOG_ERR, &nil_pcb, "Group for control unit %u is NULL, flow divert cannot be used for this socket", ctl_unit);
259 		error = ENETUNREACH;
260 		goto done;
261 	}
262 
263 	socket_lock(fd_cb->so, 0);
264 
265 	do {
266 		uint32_t        key[2];
267 		uint32_t        idx;
268 
269 		key[0] = g_nextkey++;
270 		key[1] = RandomULong();
271 
272 		if (g_hash_seed == 0) {
273 			g_hash_seed = RandomULong();
274 		}
275 
276 		fd_cb->hash = net_flowhash(key, sizeof(key), g_hash_seed);
277 
278 		for (idx = 1; idx < GROUP_COUNT_MAX; idx++) {
279 			struct flow_divert_group *curr_group = g_flow_divert_groups[idx];
280 			if (curr_group != NULL && curr_group != group) {
281 				lck_rw_lock_shared(&curr_group->lck);
282 				exist = RB_FIND(fd_pcb_tree, &curr_group->pcb_tree, fd_cb);
283 				lck_rw_done(&curr_group->lck);
284 				if (exist != NULL) {
285 					break;
286 				}
287 			}
288 		}
289 
290 		if (exist == NULL) {
291 			lck_rw_lock_exclusive(&group->lck);
292 			exist = RB_INSERT(fd_pcb_tree, &group->pcb_tree, fd_cb);
293 			lck_rw_done(&group->lck);
294 		}
295 	} while (exist != NULL && try_count++ < 3);
296 
297 	if (exist == NULL) {
298 		fd_cb->group = group;
299 		FDRETAIN(fd_cb);                /* The group now has a reference */
300 	} else {
301 		fd_cb->hash = 0;
302 		error = EEXIST;
303 	}
304 
305 	socket_unlock(fd_cb->so, 0);
306 
307 done:
308 	lck_rw_done(&g_flow_divert_group_lck);
309 	socket_lock(fd_cb->so, 0);
310 
311 	return error;
312 }
313 
314 static struct flow_divert_pcb *
flow_divert_pcb_create(socket_t so)315 flow_divert_pcb_create(socket_t so)
316 {
317 	struct flow_divert_pcb  *new_pcb = NULL;
318 
319 	new_pcb = zalloc_flags(flow_divert_pcb_zone, Z_WAITOK | Z_ZERO);
320 	lck_mtx_init(&new_pcb->mtx, &flow_divert_mtx_grp, &flow_divert_mtx_attr);
321 	new_pcb->so = so;
322 	new_pcb->log_level = nil_pcb.log_level;
323 
324 	FDRETAIN(new_pcb);      /* Represents the socket's reference */
325 
326 	return new_pcb;
327 }
328 
329 static void
flow_divert_pcb_destroy(struct flow_divert_pcb * fd_cb)330 flow_divert_pcb_destroy(struct flow_divert_pcb *fd_cb)
331 {
332 	FDLOG(LOG_INFO, fd_cb, "Destroying, app tx %u, tunnel tx %u, tunnel rx %u",
333 	    fd_cb->bytes_written_by_app, fd_cb->bytes_sent, fd_cb->bytes_received);
334 
335 	if (fd_cb->connect_token != NULL) {
336 		mbuf_freem(fd_cb->connect_token);
337 	}
338 	if (fd_cb->connect_packet != NULL) {
339 		mbuf_freem(fd_cb->connect_packet);
340 	}
341 	if (fd_cb->app_data != NULL) {
342 		kfree_data(fd_cb->app_data, fd_cb->app_data_length);
343 	}
344 	free_sockaddr(fd_cb->original_remote_endpoint);
345 	zfree(flow_divert_pcb_zone, fd_cb);
346 }
347 
348 static void
flow_divert_pcb_remove(struct flow_divert_pcb * fd_cb)349 flow_divert_pcb_remove(struct flow_divert_pcb *fd_cb)
350 {
351 	if (fd_cb->group != NULL) {
352 		struct flow_divert_group *group = fd_cb->group;
353 		lck_rw_lock_exclusive(&group->lck);
354 		FDLOG(LOG_INFO, fd_cb, "Removing from group %d, ref count = %d", group->ctl_unit, fd_cb->ref_count);
355 		RB_REMOVE(fd_pcb_tree, &group->pcb_tree, fd_cb);
356 		fd_cb->group = NULL;
357 		FDRELEASE(fd_cb);                               /* Release the group's reference */
358 		lck_rw_done(&group->lck);
359 	}
360 }
361 
362 static int
flow_divert_packet_init(struct flow_divert_pcb * fd_cb,uint8_t packet_type,mbuf_t * packet)363 flow_divert_packet_init(struct flow_divert_pcb *fd_cb, uint8_t packet_type, mbuf_t *packet)
364 {
365 	struct flow_divert_packet_header        hdr;
366 	int                                     error           = 0;
367 
368 	error = mbuf_gethdr(MBUF_DONTWAIT, MBUF_TYPE_HEADER, packet);
369 	if (error) {
370 		FDLOG(LOG_ERR, fd_cb, "failed to allocate the header mbuf: %d", error);
371 		return error;
372 	}
373 
374 	hdr.packet_type = packet_type;
375 	hdr.conn_id = htonl(fd_cb->hash);
376 
377 	/* Lay down the header */
378 	error = mbuf_copyback(*packet, 0, sizeof(hdr), &hdr, MBUF_DONTWAIT);
379 	if (error) {
380 		FDLOG(LOG_ERR, fd_cb, "mbuf_copyback(hdr) failed: %d", error);
381 		mbuf_freem(*packet);
382 		*packet = NULL;
383 		return error;
384 	}
385 
386 	return 0;
387 }
388 
389 static int
flow_divert_packet_append_tlv(mbuf_t packet,uint8_t type,uint32_t length,const void * value)390 flow_divert_packet_append_tlv(mbuf_t packet, uint8_t type, uint32_t length, const void *value)
391 {
392 	uint32_t        net_length      = htonl(length);
393 	int                     error           = 0;
394 
395 	error = mbuf_copyback(packet, mbuf_pkthdr_len(packet), sizeof(type), &type, MBUF_DONTWAIT);
396 	if (error) {
397 		FDLOG(LOG_ERR, &nil_pcb, "failed to append the type (%d)", type);
398 		return error;
399 	}
400 
401 	error = mbuf_copyback(packet, mbuf_pkthdr_len(packet), sizeof(net_length), &net_length, MBUF_DONTWAIT);
402 	if (error) {
403 		FDLOG(LOG_ERR, &nil_pcb, "failed to append the length (%u)", length);
404 		return error;
405 	}
406 
407 	error = mbuf_copyback(packet, mbuf_pkthdr_len(packet), length, value, MBUF_DONTWAIT);
408 	if (error) {
409 		FDLOG0(LOG_ERR, &nil_pcb, "failed to append the value");
410 		return error;
411 	}
412 
413 	return error;
414 }
415 
416 static int
flow_divert_packet_find_tlv(mbuf_t packet,int offset,uint8_t type,int * err,int next)417 flow_divert_packet_find_tlv(mbuf_t packet, int offset, uint8_t type, int *err, int next)
418 {
419 	size_t          cursor                  = offset;
420 	int                     error                   = 0;
421 	uint32_t        curr_length;
422 	uint8_t         curr_type;
423 
424 	*err = 0;
425 
426 	do {
427 		if (!next) {
428 			error = mbuf_copydata(packet, cursor, sizeof(curr_type), &curr_type);
429 			if (error) {
430 				*err = ENOENT;
431 				return -1;
432 			}
433 		} else {
434 			next = 0;
435 			curr_type = FLOW_DIVERT_TLV_NIL;
436 		}
437 
438 		if (curr_type != type) {
439 			cursor += sizeof(curr_type);
440 			error = mbuf_copydata(packet, cursor, sizeof(curr_length), &curr_length);
441 			if (error) {
442 				*err = error;
443 				return -1;
444 			}
445 
446 			cursor += (sizeof(curr_length) + ntohl(curr_length));
447 		}
448 	} while (curr_type != type);
449 
450 	return (int)cursor;
451 }
452 
453 static int
flow_divert_packet_get_tlv(mbuf_t packet,int offset,uint8_t type,size_t buff_len,void * buff,uint32_t * val_size)454 flow_divert_packet_get_tlv(mbuf_t packet, int offset, uint8_t type, size_t buff_len, void *buff, uint32_t *val_size)
455 {
456 	int                     error           = 0;
457 	uint32_t        length;
458 	int                     tlv_offset;
459 
460 	tlv_offset = flow_divert_packet_find_tlv(packet, offset, type, &error, 0);
461 	if (tlv_offset < 0) {
462 		return error;
463 	}
464 
465 	error = mbuf_copydata(packet, tlv_offset + sizeof(type), sizeof(length), &length);
466 	if (error) {
467 		return error;
468 	}
469 
470 	length = ntohl(length);
471 
472 	uint32_t data_offset = tlv_offset + sizeof(type) + sizeof(length);
473 
474 	if (length > (mbuf_pkthdr_len(packet) - data_offset)) {
475 		FDLOG(LOG_ERR, &nil_pcb, "Length of %u TLV (%u) is larger than remaining packet data (%lu)", type, length, (mbuf_pkthdr_len(packet) - data_offset));
476 		return EINVAL;
477 	}
478 
479 	if (val_size != NULL) {
480 		*val_size = length;
481 	}
482 
483 	if (buff != NULL && buff_len > 0) {
484 		memset(buff, 0, buff_len);
485 		size_t to_copy = (length < buff_len) ? length : buff_len;
486 		error = mbuf_copydata(packet, data_offset, to_copy, buff);
487 		if (error) {
488 			return error;
489 		}
490 	}
491 
492 	return 0;
493 }
494 
495 static int
flow_divert_packet_compute_hmac(mbuf_t packet,struct flow_divert_group * group,uint8_t * hmac)496 flow_divert_packet_compute_hmac(mbuf_t packet, struct flow_divert_group *group, uint8_t *hmac)
497 {
498 	mbuf_t  curr_mbuf       = packet;
499 
500 	if (g_crypto_funcs == NULL || group->token_key == NULL) {
501 		return ENOPROTOOPT;
502 	}
503 
504 	cchmac_di_decl(g_crypto_funcs->ccsha1_di, hmac_ctx);
505 	g_crypto_funcs->cchmac_init_fn(g_crypto_funcs->ccsha1_di, hmac_ctx, group->token_key_size, group->token_key);
506 
507 	while (curr_mbuf != NULL) {
508 		g_crypto_funcs->cchmac_update_fn(g_crypto_funcs->ccsha1_di, hmac_ctx, mbuf_len(curr_mbuf), mbuf_data(curr_mbuf));
509 		curr_mbuf = mbuf_next(curr_mbuf);
510 	}
511 
512 	g_crypto_funcs->cchmac_final_fn(g_crypto_funcs->ccsha1_di, hmac_ctx, hmac);
513 
514 	return 0;
515 }
516 
517 static int
flow_divert_packet_verify_hmac(mbuf_t packet,uint32_t ctl_unit)518 flow_divert_packet_verify_hmac(mbuf_t packet, uint32_t ctl_unit)
519 {
520 	int                                                     error = 0;
521 	struct flow_divert_group        *group = NULL;
522 	int                                                     hmac_offset;
523 	uint8_t                                         packet_hmac[SHA_DIGEST_LENGTH];
524 	uint8_t                                         computed_hmac[SHA_DIGEST_LENGTH];
525 	mbuf_t                                          tail;
526 
527 	lck_rw_lock_shared(&g_flow_divert_group_lck);
528 
529 	if (g_flow_divert_groups != NULL && g_active_group_count > 0) {
530 		group = g_flow_divert_groups[ctl_unit];
531 	}
532 
533 	if (group == NULL) {
534 		lck_rw_done(&g_flow_divert_group_lck);
535 		return ENOPROTOOPT;
536 	}
537 
538 	lck_rw_lock_shared(&group->lck);
539 
540 	if (group->token_key == NULL) {
541 		error = ENOPROTOOPT;
542 		goto done;
543 	}
544 
545 	hmac_offset = flow_divert_packet_find_tlv(packet, 0, FLOW_DIVERT_TLV_HMAC, &error, 0);
546 	if (hmac_offset < 0) {
547 		goto done;
548 	}
549 
550 	error = flow_divert_packet_get_tlv(packet, hmac_offset, FLOW_DIVERT_TLV_HMAC, sizeof(packet_hmac), packet_hmac, NULL);
551 	if (error) {
552 		goto done;
553 	}
554 
555 	/* Chop off the HMAC TLV */
556 	error = mbuf_split(packet, hmac_offset, MBUF_WAITOK, &tail);
557 	if (error) {
558 		goto done;
559 	}
560 
561 	mbuf_free(tail);
562 
563 	error = flow_divert_packet_compute_hmac(packet, group, computed_hmac);
564 	if (error) {
565 		goto done;
566 	}
567 
568 	if (cc_cmp_safe(sizeof(packet_hmac), packet_hmac, computed_hmac)) {
569 		FDLOG0(LOG_WARNING, &nil_pcb, "HMAC in token does not match computed HMAC");
570 		error = EINVAL;
571 		goto done;
572 	}
573 
574 done:
575 	lck_rw_done(&group->lck);
576 	lck_rw_done(&g_flow_divert_group_lck);
577 	return error;
578 }
579 
580 static void
flow_divert_add_data_statistics(struct flow_divert_pcb * fd_cb,size_t data_len,Boolean send)581 flow_divert_add_data_statistics(struct flow_divert_pcb *fd_cb, size_t data_len, Boolean send)
582 {
583 	struct inpcb *inp = NULL;
584 	struct ifnet *ifp = NULL;
585 	Boolean cell = FALSE;
586 	Boolean wifi = FALSE;
587 	Boolean wired = FALSE;
588 
589 	inp = sotoinpcb(fd_cb->so);
590 	if (inp == NULL) {
591 		return;
592 	}
593 
594 	if (inp->inp_vflag & INP_IPV4) {
595 		ifp = inp->inp_last_outifp;
596 	} else if (inp->inp_vflag & INP_IPV6) {
597 		ifp = inp->in6p_last_outifp;
598 	}
599 	if (ifp != NULL) {
600 		cell = IFNET_IS_CELLULAR(ifp);
601 		wifi = (!cell && IFNET_IS_WIFI(ifp));
602 		wired = (!wifi && IFNET_IS_WIRED(ifp));
603 	}
604 
605 	if (send) {
606 		INP_ADD_STAT(inp, cell, wifi, wired, txpackets, 1);
607 		INP_ADD_STAT(inp, cell, wifi, wired, txbytes, data_len);
608 	} else {
609 		INP_ADD_STAT(inp, cell, wifi, wired, rxpackets, 1);
610 		INP_ADD_STAT(inp, cell, wifi, wired, rxbytes, data_len);
611 	}
612 	inp_set_activity_bitmap(inp);
613 }
614 
615 static errno_t
flow_divert_check_no_cellular(struct flow_divert_pcb * fd_cb)616 flow_divert_check_no_cellular(struct flow_divert_pcb *fd_cb)
617 {
618 	struct inpcb *inp = sotoinpcb(fd_cb->so);
619 	if (INP_NO_CELLULAR(inp)) {
620 		struct ifnet *ifp = NULL;
621 		if (inp->inp_vflag & INP_IPV4) {
622 			ifp = inp->inp_last_outifp;
623 		} else if (inp->inp_vflag & INP_IPV6) {
624 			ifp = inp->in6p_last_outifp;
625 		}
626 		if (ifp != NULL && IFNET_IS_CELLULAR(ifp)) {
627 			FDLOG0(LOG_ERR, fd_cb, "Cellular is denied");
628 			return EHOSTUNREACH;
629 		}
630 	}
631 	return 0;
632 }
633 
634 static errno_t
flow_divert_check_no_expensive(struct flow_divert_pcb * fd_cb)635 flow_divert_check_no_expensive(struct flow_divert_pcb *fd_cb)
636 {
637 	struct inpcb *inp = sotoinpcb(fd_cb->so);
638 	if (INP_NO_EXPENSIVE(inp)) {
639 		struct ifnet *ifp = NULL;
640 		if (inp->inp_vflag & INP_IPV4) {
641 			ifp = inp->inp_last_outifp;
642 		} else if (inp->inp_vflag & INP_IPV6) {
643 			ifp = inp->in6p_last_outifp;
644 		}
645 		if (ifp != NULL && IFNET_IS_EXPENSIVE(ifp)) {
646 			FDLOG0(LOG_ERR, fd_cb, "Expensive is denied");
647 			return EHOSTUNREACH;
648 		}
649 	}
650 	return 0;
651 }
652 
653 static errno_t
flow_divert_check_no_constrained(struct flow_divert_pcb * fd_cb)654 flow_divert_check_no_constrained(struct flow_divert_pcb *fd_cb)
655 {
656 	struct inpcb *inp = sotoinpcb(fd_cb->so);
657 	if (INP_NO_CONSTRAINED(inp)) {
658 		struct ifnet *ifp = NULL;
659 		if (inp->inp_vflag & INP_IPV4) {
660 			ifp = inp->inp_last_outifp;
661 		} else if (inp->inp_vflag & INP_IPV6) {
662 			ifp = inp->in6p_last_outifp;
663 		}
664 		if (ifp != NULL && IFNET_IS_CONSTRAINED(ifp)) {
665 			FDLOG0(LOG_ERR, fd_cb, "Constrained is denied");
666 			return EHOSTUNREACH;
667 		}
668 	}
669 	return 0;
670 }
671 
672 static void
flow_divert_update_closed_state(struct flow_divert_pcb * fd_cb,int how,Boolean tunnel)673 flow_divert_update_closed_state(struct flow_divert_pcb *fd_cb, int how, Boolean tunnel)
674 {
675 	if (how != SHUT_RD) {
676 		fd_cb->flags |= FLOW_DIVERT_WRITE_CLOSED;
677 		if (tunnel || !(fd_cb->flags & FLOW_DIVERT_CONNECT_STARTED)) {
678 			fd_cb->flags |= FLOW_DIVERT_TUNNEL_WR_CLOSED;
679 			/* If the tunnel is not accepting writes any more, then flush the send buffer */
680 			sbflush(&fd_cb->so->so_snd);
681 		}
682 	}
683 	if (how != SHUT_WR) {
684 		fd_cb->flags |= FLOW_DIVERT_READ_CLOSED;
685 		if (tunnel || !(fd_cb->flags & FLOW_DIVERT_CONNECT_STARTED)) {
686 			fd_cb->flags |= FLOW_DIVERT_TUNNEL_RD_CLOSED;
687 		}
688 	}
689 }
690 
691 static uint16_t
trie_node_alloc(struct flow_divert_trie * trie)692 trie_node_alloc(struct flow_divert_trie *trie)
693 {
694 	if (trie->nodes_free_next < trie->nodes_count) {
695 		uint16_t node_idx = trie->nodes_free_next++;
696 		TRIE_NODE(trie, node_idx).child_map = NULL_TRIE_IDX;
697 		return node_idx;
698 	} else {
699 		return NULL_TRIE_IDX;
700 	}
701 }
702 
703 static uint16_t
trie_child_map_alloc(struct flow_divert_trie * trie)704 trie_child_map_alloc(struct flow_divert_trie *trie)
705 {
706 	if (trie->child_maps_free_next < trie->child_maps_count) {
707 		return trie->child_maps_free_next++;
708 	} else {
709 		return NULL_TRIE_IDX;
710 	}
711 }
712 
713 static uint16_t
trie_bytes_move(struct flow_divert_trie * trie,uint16_t bytes_idx,size_t bytes_size)714 trie_bytes_move(struct flow_divert_trie *trie, uint16_t bytes_idx, size_t bytes_size)
715 {
716 	uint16_t start = trie->bytes_free_next;
717 	if (start + bytes_size <= trie->bytes_count) {
718 		if (start != bytes_idx) {
719 			memmove(&TRIE_BYTE(trie, start), &TRIE_BYTE(trie, bytes_idx), bytes_size);
720 		}
721 		trie->bytes_free_next += bytes_size;
722 		return start;
723 	} else {
724 		return NULL_TRIE_IDX;
725 	}
726 }
727 
728 static uint16_t
flow_divert_trie_insert(struct flow_divert_trie * trie,uint16_t string_start,size_t string_len)729 flow_divert_trie_insert(struct flow_divert_trie *trie, uint16_t string_start, size_t string_len)
730 {
731 	uint16_t current = trie->root;
732 	uint16_t child = trie->root;
733 	uint16_t string_end = string_start + (uint16_t)string_len;
734 	uint16_t string_idx = string_start;
735 	uint16_t string_remainder = (uint16_t)string_len;
736 
737 	while (child != NULL_TRIE_IDX) {
738 		uint16_t parent = current;
739 		uint16_t node_idx;
740 		uint16_t current_end;
741 
742 		current = child;
743 		child = NULL_TRIE_IDX;
744 
745 		current_end = TRIE_NODE(trie, current).start + TRIE_NODE(trie, current).length;
746 
747 		for (node_idx = TRIE_NODE(trie, current).start;
748 		    node_idx < current_end &&
749 		    string_idx < string_end &&
750 		    TRIE_BYTE(trie, node_idx) == TRIE_BYTE(trie, string_idx);
751 		    node_idx++, string_idx++) {
752 			;
753 		}
754 
755 		string_remainder = string_end - string_idx;
756 
757 		if (node_idx < (TRIE_NODE(trie, current).start + TRIE_NODE(trie, current).length)) {
758 			/*
759 			 * We did not reach the end of the current node's string.
760 			 * We need to split the current node into two:
761 			 *   1. A new node that contains the prefix of the node that matches
762 			 *      the prefix of the string being inserted.
763 			 *   2. The current node modified to point to the remainder
764 			 *      of the current node's string.
765 			 */
766 			uint16_t prefix = trie_node_alloc(trie);
767 			if (prefix == NULL_TRIE_IDX) {
768 				FDLOG0(LOG_ERR, &nil_pcb, "Ran out of trie nodes while splitting an existing node");
769 				return NULL_TRIE_IDX;
770 			}
771 
772 			/*
773 			 * Prefix points to the portion of the current nodes's string that has matched
774 			 * the input string thus far.
775 			 */
776 			TRIE_NODE(trie, prefix).start = TRIE_NODE(trie, current).start;
777 			TRIE_NODE(trie, prefix).length = (node_idx - TRIE_NODE(trie, current).start);
778 
779 			/*
780 			 * Prefix has the current node as the child corresponding to the first byte
781 			 * after the split.
782 			 */
783 			TRIE_NODE(trie, prefix).child_map = trie_child_map_alloc(trie);
784 			if (TRIE_NODE(trie, prefix).child_map == NULL_TRIE_IDX) {
785 				FDLOG0(LOG_ERR, &nil_pcb, "Ran out of child maps while splitting an existing node");
786 				return NULL_TRIE_IDX;
787 			}
788 			TRIE_CHILD(trie, prefix, TRIE_BYTE(trie, node_idx)) = current;
789 
790 			/* Parent has the prefix as the child correspoding to the first byte in the prefix */
791 			TRIE_CHILD(trie, parent, TRIE_BYTE(trie, TRIE_NODE(trie, prefix).start)) = prefix;
792 
793 			/* Current node is adjusted to point to the remainder */
794 			TRIE_NODE(trie, current).start = node_idx;
795 			TRIE_NODE(trie, current).length -= TRIE_NODE(trie, prefix).length;
796 
797 			/* We want to insert the new leaf (if any) as a child of the prefix */
798 			current = prefix;
799 		}
800 
801 		if (string_remainder > 0) {
802 			/*
803 			 * We still have bytes in the string that have not been matched yet.
804 			 * If the current node has children, iterate to the child corresponding
805 			 * to the next byte in the string.
806 			 */
807 			if (TRIE_NODE(trie, current).child_map != NULL_TRIE_IDX) {
808 				child = TRIE_CHILD(trie, current, TRIE_BYTE(trie, string_idx));
809 			}
810 		}
811 	} /* while (child != NULL_TRIE_IDX) */
812 
813 	if (string_remainder > 0) {
814 		/* Add a new leaf containing the remainder of the string */
815 		uint16_t leaf = trie_node_alloc(trie);
816 		if (leaf == NULL_TRIE_IDX) {
817 			FDLOG0(LOG_ERR, &nil_pcb, "Ran out of trie nodes while inserting a new leaf");
818 			return NULL_TRIE_IDX;
819 		}
820 
821 		TRIE_NODE(trie, leaf).start = trie_bytes_move(trie, string_idx, string_remainder);
822 		if (TRIE_NODE(trie, leaf).start == NULL_TRIE_IDX) {
823 			FDLOG0(LOG_ERR, &nil_pcb, "Ran out of bytes while inserting a new leaf");
824 			return NULL_TRIE_IDX;
825 		}
826 		TRIE_NODE(trie, leaf).length = string_remainder;
827 
828 		/* Set the new leaf as the child of the current node */
829 		if (TRIE_NODE(trie, current).child_map == NULL_TRIE_IDX) {
830 			TRIE_NODE(trie, current).child_map = trie_child_map_alloc(trie);
831 			if (TRIE_NODE(trie, current).child_map == NULL_TRIE_IDX) {
832 				FDLOG0(LOG_ERR, &nil_pcb, "Ran out of child maps while inserting a new leaf");
833 				return NULL_TRIE_IDX;
834 			}
835 		}
836 		TRIE_CHILD(trie, current, TRIE_BYTE(trie, TRIE_NODE(trie, leaf).start)) = leaf;
837 		current = leaf;
838 	} /* else duplicate or this string is a prefix of one of the existing strings */
839 
840 	return current;
841 }
842 
843 #define APPLE_WEBCLIP_ID_PREFIX "com.apple.webapp"
844 static uint16_t
flow_divert_trie_search(struct flow_divert_trie * trie,const uint8_t * string_bytes)845 flow_divert_trie_search(struct flow_divert_trie *trie, const uint8_t *string_bytes)
846 {
847 	uint16_t current = trie->root;
848 	uint16_t string_idx = 0;
849 
850 	while (current != NULL_TRIE_IDX) {
851 		uint16_t next = NULL_TRIE_IDX;
852 		uint16_t node_end = TRIE_NODE(trie, current).start + TRIE_NODE(trie, current).length;
853 		uint16_t node_idx;
854 
855 		for (node_idx = TRIE_NODE(trie, current).start;
856 		    node_idx < node_end && string_bytes[string_idx] != '\0' && string_bytes[string_idx] == TRIE_BYTE(trie, node_idx);
857 		    node_idx++, string_idx++) {
858 			;
859 		}
860 
861 		if (node_idx == node_end) {
862 			if (string_bytes[string_idx] == '\0') {
863 				return current; /* Got an exact match */
864 			} else if (string_idx == strlen(APPLE_WEBCLIP_ID_PREFIX) &&
865 			    0 == strncmp((const char *)string_bytes, APPLE_WEBCLIP_ID_PREFIX, string_idx)) {
866 				return current; /* Got an apple webclip id prefix match */
867 			} else if (TRIE_NODE(trie, current).child_map != NULL_TRIE_IDX) {
868 				next = TRIE_CHILD(trie, current, string_bytes[string_idx]);
869 			}
870 		}
871 		current = next;
872 	}
873 
874 	return NULL_TRIE_IDX;
875 }
876 
877 struct uuid_search_info {
878 	uuid_t target_uuid;
879 	char *found_signing_id;
880 	boolean_t found_multiple_signing_ids;
881 	proc_t found_proc;
882 };
883 
884 static int
flow_divert_find_proc_by_uuid_callout(proc_t p,void * arg)885 flow_divert_find_proc_by_uuid_callout(proc_t p, void *arg)
886 {
887 	struct uuid_search_info *info = (struct uuid_search_info *)arg;
888 	int result = PROC_RETURNED_DONE; /* By default, we didn't find the process */
889 
890 	if (info->found_signing_id != NULL) {
891 		if (!info->found_multiple_signing_ids) {
892 			/* All processes that were found had the same signing identifier, so just claim this first one and be done. */
893 			info->found_proc = p;
894 			result = PROC_CLAIMED_DONE;
895 		} else {
896 			uuid_string_t uuid_str;
897 			uuid_unparse(info->target_uuid, uuid_str);
898 			FDLOG(LOG_WARNING, &nil_pcb, "Found multiple processes with UUID %s with different signing identifiers", uuid_str);
899 		}
900 		kfree_data(info->found_signing_id, strlen(info->found_signing_id) + 1);
901 		info->found_signing_id = NULL;
902 	}
903 
904 	if (result == PROC_RETURNED_DONE) {
905 		uuid_string_t uuid_str;
906 		uuid_unparse(info->target_uuid, uuid_str);
907 		FDLOG(LOG_WARNING, &nil_pcb, "Failed to find a process with UUID %s", uuid_str);
908 	}
909 
910 	return result;
911 }
912 
913 static int
flow_divert_find_proc_by_uuid_filter(proc_t p,void * arg)914 flow_divert_find_proc_by_uuid_filter(proc_t p, void *arg)
915 {
916 	struct uuid_search_info *info = (struct uuid_search_info *)arg;
917 	int include = 0;
918 
919 	if (info->found_multiple_signing_ids) {
920 		return include;
921 	}
922 
923 	include = (uuid_compare(proc_executableuuid_addr(p), info->target_uuid) == 0);
924 	if (include) {
925 		const char *signing_id = cs_identity_get(p);
926 		if (signing_id != NULL) {
927 			FDLOG(LOG_INFO, &nil_pcb, "Found process %d with signing identifier %s", proc_getpid(p), signing_id);
928 			size_t signing_id_size = strlen(signing_id) + 1;
929 			if (info->found_signing_id == NULL) {
930 				info->found_signing_id = kalloc_data(signing_id_size, Z_WAITOK);
931 				memcpy(info->found_signing_id, signing_id, signing_id_size);
932 			} else if (memcmp(signing_id, info->found_signing_id, signing_id_size)) {
933 				info->found_multiple_signing_ids = TRUE;
934 			}
935 		} else {
936 			info->found_multiple_signing_ids = TRUE;
937 		}
938 		include = !info->found_multiple_signing_ids;
939 	}
940 
941 	return include;
942 }
943 
944 static proc_t
flow_divert_find_proc_by_uuid(uuid_t uuid)945 flow_divert_find_proc_by_uuid(uuid_t uuid)
946 {
947 	struct uuid_search_info info;
948 
949 	if (LOG_INFO <= nil_pcb.log_level) {
950 		uuid_string_t uuid_str;
951 		uuid_unparse(uuid, uuid_str);
952 		FDLOG(LOG_INFO, &nil_pcb, "Looking for process with UUID %s", uuid_str);
953 	}
954 
955 	memset(&info, 0, sizeof(info));
956 	info.found_proc = PROC_NULL;
957 	uuid_copy(info.target_uuid, uuid);
958 
959 	proc_iterate(PROC_ALLPROCLIST, flow_divert_find_proc_by_uuid_callout, &info, flow_divert_find_proc_by_uuid_filter, &info);
960 
961 	return info.found_proc;
962 }
963 
964 static int
flow_divert_add_proc_info(struct flow_divert_pcb * fd_cb,proc_t proc,const char * signing_id,mbuf_t connect_packet,bool is_effective)965 flow_divert_add_proc_info(struct flow_divert_pcb *fd_cb, proc_t proc, const char *signing_id, mbuf_t connect_packet, bool is_effective)
966 {
967 	int error = 0;
968 	uint8_t *cdhash = NULL;
969 	audit_token_t audit_token = {};
970 	const char *proc_cs_id = signing_id;
971 
972 	proc_lock(proc);
973 
974 	if (proc_cs_id == NULL) {
975 		if (proc_getcsflags(proc) & (CS_VALID | CS_DEBUGGED)) {
976 			proc_cs_id = cs_identity_get(proc);
977 		} else {
978 			FDLOG0(LOG_ERR, fd_cb, "Signature of proc is invalid");
979 		}
980 	}
981 
982 	if (is_effective) {
983 		lck_rw_lock_shared(&fd_cb->group->lck);
984 		if (!(fd_cb->group->flags & FLOW_DIVERT_GROUP_FLAG_NO_APP_MAP)) {
985 			if (proc_cs_id != NULL) {
986 				uint16_t result = flow_divert_trie_search(&fd_cb->group->signing_id_trie, (const uint8_t *)proc_cs_id);
987 				if (result == NULL_TRIE_IDX) {
988 					FDLOG(LOG_WARNING, fd_cb, "%s did not match", proc_cs_id);
989 					error = EPERM;
990 				} else {
991 					FDLOG(LOG_INFO, fd_cb, "%s matched", proc_cs_id);
992 				}
993 			} else {
994 				error = EPERM;
995 			}
996 		}
997 		lck_rw_done(&fd_cb->group->lck);
998 	}
999 
1000 	if (error != 0) {
1001 		goto done;
1002 	}
1003 
1004 	/*
1005 	 * If signing_id is not NULL then it came from the flow divert token and will be added
1006 	 * as part of the token, so there is no need to add it here.
1007 	 */
1008 	if (signing_id == NULL && proc_cs_id != NULL) {
1009 		error = flow_divert_packet_append_tlv(connect_packet,
1010 		    (is_effective ? FLOW_DIVERT_TLV_SIGNING_ID : FLOW_DIVERT_TLV_APP_REAL_SIGNING_ID),
1011 		    (uint32_t)strlen(proc_cs_id),
1012 		    proc_cs_id);
1013 		if (error != 0) {
1014 			FDLOG(LOG_ERR, fd_cb, "failed to append the signing ID: %d", error);
1015 			goto done;
1016 		}
1017 	}
1018 
1019 	cdhash = cs_get_cdhash(proc);
1020 	if (cdhash != NULL) {
1021 		error = flow_divert_packet_append_tlv(connect_packet,
1022 		    (is_effective ? FLOW_DIVERT_TLV_CDHASH : FLOW_DIVERT_TLV_APP_REAL_CDHASH),
1023 		    SHA1_RESULTLEN,
1024 		    cdhash);
1025 		if (error) {
1026 			FDLOG(LOG_ERR, fd_cb, "failed to append the cdhash: %d", error);
1027 			goto done;
1028 		}
1029 	} else {
1030 		FDLOG0(LOG_ERR, fd_cb, "failed to get the cdhash");
1031 	}
1032 
1033 	task_t task = proc_task(proc);
1034 	if (task != TASK_NULL) {
1035 		mach_msg_type_number_t count = TASK_AUDIT_TOKEN_COUNT;
1036 		kern_return_t rc = task_info(task, TASK_AUDIT_TOKEN, (task_info_t)&audit_token, &count);
1037 		if (rc == KERN_SUCCESS) {
1038 			int append_error = flow_divert_packet_append_tlv(connect_packet,
1039 			    (is_effective ? FLOW_DIVERT_TLV_APP_AUDIT_TOKEN : FLOW_DIVERT_TLV_APP_REAL_AUDIT_TOKEN),
1040 			    sizeof(audit_token_t),
1041 			    &audit_token);
1042 			if (append_error) {
1043 				FDLOG(LOG_ERR, fd_cb, "failed to append app audit token: %d", append_error);
1044 			}
1045 		}
1046 	}
1047 
1048 done:
1049 	proc_unlock(proc);
1050 
1051 	return error;
1052 }
1053 
1054 static int
flow_divert_add_all_proc_info(struct flow_divert_pcb * fd_cb,struct socket * so,proc_t proc,const char * signing_id,mbuf_t connect_packet)1055 flow_divert_add_all_proc_info(struct flow_divert_pcb *fd_cb, struct socket *so, proc_t proc, const char *signing_id, mbuf_t connect_packet)
1056 {
1057 	int error = 0;
1058 	proc_t effective_proc = PROC_NULL;
1059 	proc_t responsible_proc = PROC_NULL;
1060 	proc_t real_proc = proc_find(so->last_pid);
1061 	bool release_real_proc = true;
1062 
1063 	proc_t src_proc = PROC_NULL;
1064 	proc_t real_src_proc = PROC_NULL;
1065 
1066 	if (real_proc == PROC_NULL) {
1067 		FDLOG(LOG_ERR, fd_cb, "failed to find the real proc record for %d", so->last_pid);
1068 		release_real_proc = false;
1069 		real_proc = proc;
1070 		if (real_proc == PROC_NULL) {
1071 			real_proc = current_proc();
1072 		}
1073 	}
1074 
1075 	if (so->so_flags & SOF_DELEGATED) {
1076 		if (proc_getpid(real_proc) != so->e_pid) {
1077 			effective_proc = proc_find(so->e_pid);
1078 		} else if (uuid_compare(proc_executableuuid_addr(real_proc), so->e_uuid)) {
1079 			effective_proc = flow_divert_find_proc_by_uuid(so->e_uuid);
1080 		}
1081 	}
1082 
1083 #if defined(XNU_TARGET_OS_OSX)
1084 	lck_rw_lock_shared(&fd_cb->group->lck);
1085 	if (!(fd_cb->group->flags & FLOW_DIVERT_GROUP_FLAG_NO_APP_MAP)) {
1086 		if (so->so_rpid > 0) {
1087 			responsible_proc = proc_find(so->so_rpid);
1088 		}
1089 	}
1090 	lck_rw_done(&fd_cb->group->lck);
1091 #endif
1092 
1093 	real_src_proc = real_proc;
1094 
1095 	if (responsible_proc != PROC_NULL) {
1096 		src_proc = responsible_proc;
1097 		if (effective_proc != NULL) {
1098 			real_src_proc = effective_proc;
1099 		}
1100 	} else if (effective_proc != PROC_NULL) {
1101 		src_proc = effective_proc;
1102 	} else {
1103 		src_proc = real_proc;
1104 	}
1105 
1106 	error = flow_divert_add_proc_info(fd_cb, src_proc, signing_id, connect_packet, true);
1107 	if (error != 0) {
1108 		goto done;
1109 	}
1110 
1111 	if (real_src_proc != NULL && real_src_proc != src_proc) {
1112 		error = flow_divert_add_proc_info(fd_cb, real_src_proc, NULL, connect_packet, false);
1113 		if (error != 0) {
1114 			goto done;
1115 		}
1116 	}
1117 
1118 done:
1119 	if (responsible_proc != PROC_NULL) {
1120 		proc_rele(responsible_proc);
1121 	}
1122 
1123 	if (effective_proc != PROC_NULL) {
1124 		proc_rele(effective_proc);
1125 	}
1126 
1127 	if (real_proc != PROC_NULL && release_real_proc) {
1128 		proc_rele(real_proc);
1129 	}
1130 
1131 	return error;
1132 }
1133 
1134 static int
flow_divert_send_packet(struct flow_divert_pcb * fd_cb,mbuf_t packet,Boolean enqueue)1135 flow_divert_send_packet(struct flow_divert_pcb *fd_cb, mbuf_t packet, Boolean enqueue)
1136 {
1137 	int             error;
1138 
1139 	if (fd_cb->group == NULL) {
1140 		FDLOG0(LOG_INFO, fd_cb, "no provider, cannot send packet");
1141 		flow_divert_update_closed_state(fd_cb, SHUT_RDWR, TRUE);
1142 		flow_divert_disconnect_socket(fd_cb->so, !(fd_cb->flags & FLOW_DIVERT_IMPLICIT_CONNECT));
1143 		if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
1144 			error = ECONNABORTED;
1145 		} else {
1146 			error = EHOSTUNREACH;
1147 		}
1148 		fd_cb->so->so_error = (uint16_t)error;
1149 		return error;
1150 	}
1151 
1152 	lck_rw_lock_shared(&fd_cb->group->lck);
1153 
1154 	if (MBUFQ_EMPTY(&fd_cb->group->send_queue)) {
1155 		error = ctl_enqueuembuf(g_flow_divert_kctl_ref, fd_cb->group->ctl_unit, packet, CTL_DATA_EOR);
1156 	} else {
1157 		error = ENOBUFS;
1158 	}
1159 
1160 	if (error == ENOBUFS) {
1161 		if (enqueue) {
1162 			if (!lck_rw_lock_shared_to_exclusive(&fd_cb->group->lck)) {
1163 				lck_rw_lock_exclusive(&fd_cb->group->lck);
1164 			}
1165 			MBUFQ_ENQUEUE(&fd_cb->group->send_queue, packet);
1166 			error = 0;
1167 		}
1168 		OSTestAndSet(GROUP_BIT_CTL_ENQUEUE_BLOCKED, &fd_cb->group->atomic_bits);
1169 	}
1170 
1171 	lck_rw_done(&fd_cb->group->lck);
1172 
1173 	return error;
1174 }
1175 
1176 static int
flow_divert_create_connect_packet(struct flow_divert_pcb * fd_cb,struct sockaddr * to,struct socket * so,proc_t p,mbuf_t * out_connect_packet)1177 flow_divert_create_connect_packet(struct flow_divert_pcb *fd_cb, struct sockaddr *to, struct socket *so, proc_t p, mbuf_t *out_connect_packet)
1178 {
1179 	int                     error                   = 0;
1180 	int                     flow_type               = 0;
1181 	char                    *signing_id = NULL;
1182 	uint32_t                sid_size = 0;
1183 	mbuf_t                  connect_packet = NULL;
1184 	cfil_sock_id_t          cfil_sock_id            = CFIL_SOCK_ID_NONE;
1185 	const void              *cfil_id                = NULL;
1186 	size_t                  cfil_id_size            = 0;
1187 	struct inpcb            *inp = sotoinpcb(so);
1188 	struct ifnet *ifp = NULL;
1189 	uint32_t flags = 0;
1190 
1191 	error = flow_divert_packet_init(fd_cb, FLOW_DIVERT_PKT_CONNECT, &connect_packet);
1192 	if (error) {
1193 		goto done;
1194 	}
1195 
1196 	if (fd_cb->connect_token != NULL && (fd_cb->flags & FLOW_DIVERT_HAS_HMAC)) {
1197 		int find_error = flow_divert_packet_get_tlv(fd_cb->connect_token, 0, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
1198 		if (find_error == 0 && sid_size > 0) {
1199 			signing_id = kalloc_data(sid_size + 1, Z_WAITOK | Z_ZERO);
1200 			if (signing_id != NULL) {
1201 				flow_divert_packet_get_tlv(fd_cb->connect_token, 0, FLOW_DIVERT_TLV_SIGNING_ID, sid_size, signing_id, NULL);
1202 				FDLOG(LOG_INFO, fd_cb, "Got %s from token", signing_id);
1203 			}
1204 		}
1205 	}
1206 
1207 	socket_unlock(so, 0);
1208 
1209 	error = flow_divert_add_all_proc_info(fd_cb, so, p, signing_id, connect_packet);
1210 
1211 	socket_lock(so, 0);
1212 
1213 	if (signing_id != NULL) {
1214 		kfree_data(signing_id, sid_size + 1);
1215 	}
1216 
1217 	if (error) {
1218 		FDLOG(LOG_ERR, fd_cb, "Failed to add source proc info: %d", error);
1219 		goto done;
1220 	}
1221 
1222 	error = flow_divert_packet_append_tlv(connect_packet,
1223 	    FLOW_DIVERT_TLV_TRAFFIC_CLASS,
1224 	    sizeof(fd_cb->so->so_traffic_class),
1225 	    &fd_cb->so->so_traffic_class);
1226 	if (error) {
1227 		goto done;
1228 	}
1229 
1230 	if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
1231 		flow_type = FLOW_DIVERT_FLOW_TYPE_TCP;
1232 	} else if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {
1233 		flow_type = FLOW_DIVERT_FLOW_TYPE_UDP;
1234 	} else {
1235 		error = EINVAL;
1236 		goto done;
1237 	}
1238 	error = flow_divert_packet_append_tlv(connect_packet,
1239 	    FLOW_DIVERT_TLV_FLOW_TYPE,
1240 	    sizeof(flow_type),
1241 	    &flow_type);
1242 
1243 	if (error) {
1244 		goto done;
1245 	}
1246 
1247 	if (fd_cb->connect_token != NULL) {
1248 		unsigned int token_len = m_length(fd_cb->connect_token);
1249 		mbuf_concatenate(connect_packet, fd_cb->connect_token);
1250 		mbuf_pkthdr_adjustlen(connect_packet, token_len);
1251 		fd_cb->connect_token = NULL;
1252 	} else {
1253 		error = flow_divert_append_target_endpoint_tlv(connect_packet, to);
1254 		if (error) {
1255 			goto done;
1256 		}
1257 
1258 		if (inp->inp_necp_attributes.inp_domain != NULL) {
1259 			size_t domain_length = strlen(inp->inp_necp_attributes.inp_domain);
1260 			if (domain_length > 0 && domain_length <= FLOW_DIVERT_MAX_NAME_SIZE) {
1261 				error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_TARGET_HOSTNAME, (uint32_t)domain_length, inp->inp_necp_attributes.inp_domain);
1262 			}
1263 		}
1264 	}
1265 
1266 	if (fd_cb->local_endpoint.sa.sa_family == AF_INET || fd_cb->local_endpoint.sa.sa_family == AF_INET6) {
1267 		error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_LOCAL_ADDR, fd_cb->local_endpoint.sa.sa_len, &(fd_cb->local_endpoint.sa));
1268 		if (error) {
1269 			goto done;
1270 		}
1271 	}
1272 
1273 	if (inp->inp_vflag & INP_IPV4) {
1274 		ifp = inp->inp_last_outifp;
1275 	} else if (inp->inp_vflag & INP_IPV6) {
1276 		ifp = inp->in6p_last_outifp;
1277 	}
1278 	if (ifp != NULL) {
1279 		uint32_t flow_if_index = ifp->if_index;
1280 		error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_OUT_IF_INDEX,
1281 		    sizeof(flow_if_index), &flow_if_index);
1282 		if (error) {
1283 			goto done;
1284 		}
1285 	}
1286 
1287 	if (so->so_flags1 & SOF1_DATA_IDEMPOTENT) {
1288 		flags |= FLOW_DIVERT_TOKEN_FLAG_TFO;
1289 	}
1290 
1291 	if ((inp->inp_flags & INP_BOUND_IF) ||
1292 	    ((inp->inp_vflag & INP_IPV6) && !IN6_IS_ADDR_UNSPECIFIED(&inp->in6p_laddr)) ||
1293 	    ((inp->inp_vflag & INP_IPV4) && inp->inp_laddr.s_addr != INADDR_ANY)) {
1294 		flags |= FLOW_DIVERT_TOKEN_FLAG_BOUND;
1295 	}
1296 
1297 	if (flags != 0) {
1298 		error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_FLAGS, sizeof(flags), &flags);
1299 		if (error) {
1300 			goto done;
1301 		}
1302 	}
1303 
1304 	if (SOCK_TYPE(so) == SOCK_DGRAM) {
1305 		cfil_sock_id = cfil_sock_id_from_datagram_socket(so, NULL, to);
1306 	} else {
1307 		cfil_sock_id = cfil_sock_id_from_socket(so);
1308 	}
1309 
1310 	if (cfil_sock_id != CFIL_SOCK_ID_NONE) {
1311 		cfil_id = &cfil_sock_id;
1312 		cfil_id_size = sizeof(cfil_sock_id);
1313 	} else if (so->so_flags1 & SOF1_CONTENT_FILTER_SKIP) {
1314 		cfil_id = &inp->necp_client_uuid;
1315 		cfil_id_size = sizeof(inp->necp_client_uuid);
1316 	}
1317 
1318 	if (cfil_id != NULL && cfil_id_size > 0 && cfil_id_size <= sizeof(uuid_t)) {
1319 		error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_CFIL_ID, (uint32_t)cfil_id_size, cfil_id);
1320 		if (error) {
1321 			goto done;
1322 		}
1323 	}
1324 
1325 done:
1326 	if (!error) {
1327 		*out_connect_packet = connect_packet;
1328 	} else if (connect_packet != NULL) {
1329 		mbuf_freem(connect_packet);
1330 	}
1331 
1332 	return error;
1333 }
1334 
1335 static int
flow_divert_send_connect_packet(struct flow_divert_pcb * fd_cb)1336 flow_divert_send_connect_packet(struct flow_divert_pcb *fd_cb)
1337 {
1338 	int error = 0;
1339 	mbuf_t connect_packet = fd_cb->connect_packet;
1340 	mbuf_t saved_connect_packet = NULL;
1341 
1342 	if (connect_packet != NULL) {
1343 		error = mbuf_copym(connect_packet, 0, mbuf_pkthdr_len(connect_packet), MBUF_DONTWAIT, &saved_connect_packet);
1344 		if (error) {
1345 			FDLOG0(LOG_ERR, fd_cb, "Failed to copy the connect packet");
1346 			goto done;
1347 		}
1348 
1349 		error = flow_divert_send_packet(fd_cb, connect_packet, TRUE);
1350 		if (error) {
1351 			goto done;
1352 		}
1353 
1354 		fd_cb->connect_packet = saved_connect_packet;
1355 		saved_connect_packet = NULL;
1356 	} else {
1357 		error = ENOENT;
1358 	}
1359 done:
1360 	if (saved_connect_packet != NULL) {
1361 		mbuf_freem(saved_connect_packet);
1362 	}
1363 
1364 	return error;
1365 }
1366 
1367 static int
flow_divert_send_connect_result(struct flow_divert_pcb * fd_cb)1368 flow_divert_send_connect_result(struct flow_divert_pcb *fd_cb)
1369 {
1370 	int             error                   = 0;
1371 	mbuf_t  packet                  = NULL;
1372 	int             rbuff_space             = 0;
1373 
1374 	error = flow_divert_packet_init(fd_cb, FLOW_DIVERT_PKT_CONNECT_RESULT, &packet);
1375 	if (error) {
1376 		FDLOG(LOG_ERR, fd_cb, "failed to create a connect result packet: %d", error);
1377 		goto done;
1378 	}
1379 
1380 	rbuff_space = fd_cb->so->so_rcv.sb_hiwat;
1381 	if (rbuff_space < 0) {
1382 		rbuff_space = 0;
1383 	}
1384 	rbuff_space = htonl(rbuff_space);
1385 	error = flow_divert_packet_append_tlv(packet,
1386 	    FLOW_DIVERT_TLV_SPACE_AVAILABLE,
1387 	    sizeof(rbuff_space),
1388 	    &rbuff_space);
1389 	if (error) {
1390 		goto done;
1391 	}
1392 
1393 	if (fd_cb->local_endpoint.sa.sa_family == AF_INET || fd_cb->local_endpoint.sa.sa_family == AF_INET6) {
1394 		error = flow_divert_packet_append_tlv(packet, FLOW_DIVERT_TLV_LOCAL_ADDR, fd_cb->local_endpoint.sa.sa_len, &(fd_cb->local_endpoint.sa));
1395 		if (error) {
1396 			goto done;
1397 		}
1398 	}
1399 
1400 	error = flow_divert_send_packet(fd_cb, packet, TRUE);
1401 	if (error) {
1402 		goto done;
1403 	}
1404 
1405 done:
1406 	if (error && packet != NULL) {
1407 		mbuf_freem(packet);
1408 	}
1409 
1410 	return error;
1411 }
1412 
1413 static int
flow_divert_send_close(struct flow_divert_pcb * fd_cb,int how)1414 flow_divert_send_close(struct flow_divert_pcb *fd_cb, int how)
1415 {
1416 	int             error   = 0;
1417 	mbuf_t  packet  = NULL;
1418 	uint32_t        zero    = 0;
1419 
1420 	error = flow_divert_packet_init(fd_cb, FLOW_DIVERT_PKT_CLOSE, &packet);
1421 	if (error) {
1422 		FDLOG(LOG_ERR, fd_cb, "failed to create a close packet: %d", error);
1423 		goto done;
1424 	}
1425 
1426 	error = flow_divert_packet_append_tlv(packet, FLOW_DIVERT_TLV_ERROR_CODE, sizeof(zero), &zero);
1427 	if (error) {
1428 		FDLOG(LOG_ERR, fd_cb, "failed to add the error code TLV: %d", error);
1429 		goto done;
1430 	}
1431 
1432 	how = htonl(how);
1433 	error = flow_divert_packet_append_tlv(packet, FLOW_DIVERT_TLV_HOW, sizeof(how), &how);
1434 	if (error) {
1435 		FDLOG(LOG_ERR, fd_cb, "failed to add the how flag: %d", error);
1436 		goto done;
1437 	}
1438 
1439 	error = flow_divert_send_packet(fd_cb, packet, TRUE);
1440 	if (error) {
1441 		goto done;
1442 	}
1443 
1444 done:
1445 	if (error && packet != NULL) {
1446 		mbuf_free(packet);
1447 	}
1448 
1449 	return error;
1450 }
1451 
1452 static int
flow_divert_tunnel_how_closed(struct flow_divert_pcb * fd_cb)1453 flow_divert_tunnel_how_closed(struct flow_divert_pcb *fd_cb)
1454 {
1455 	if ((fd_cb->flags & (FLOW_DIVERT_TUNNEL_RD_CLOSED | FLOW_DIVERT_TUNNEL_WR_CLOSED)) ==
1456 	    (FLOW_DIVERT_TUNNEL_RD_CLOSED | FLOW_DIVERT_TUNNEL_WR_CLOSED)) {
1457 		return SHUT_RDWR;
1458 	} else if (fd_cb->flags & FLOW_DIVERT_TUNNEL_RD_CLOSED) {
1459 		return SHUT_RD;
1460 	} else if (fd_cb->flags & FLOW_DIVERT_TUNNEL_WR_CLOSED) {
1461 		return SHUT_WR;
1462 	}
1463 
1464 	return -1;
1465 }
1466 
1467 /*
1468  * Determine what close messages if any need to be sent to the tunnel. Returns TRUE if the tunnel is closed for both reads and
1469  * writes. Returns FALSE otherwise.
1470  */
1471 static void
flow_divert_send_close_if_needed(struct flow_divert_pcb * fd_cb)1472 flow_divert_send_close_if_needed(struct flow_divert_pcb *fd_cb)
1473 {
1474 	int             how             = -1;
1475 
1476 	/* Do not send any close messages if there is still data in the send buffer */
1477 	if (fd_cb->so->so_snd.sb_cc == 0) {
1478 		if ((fd_cb->flags & (FLOW_DIVERT_READ_CLOSED | FLOW_DIVERT_TUNNEL_RD_CLOSED)) == FLOW_DIVERT_READ_CLOSED) {
1479 			/* Socket closed reads, but tunnel did not. Tell tunnel to close reads */
1480 			how = SHUT_RD;
1481 		}
1482 		if ((fd_cb->flags & (FLOW_DIVERT_WRITE_CLOSED | FLOW_DIVERT_TUNNEL_WR_CLOSED)) == FLOW_DIVERT_WRITE_CLOSED) {
1483 			/* Socket closed writes, but tunnel did not. Tell tunnel to close writes */
1484 			if (how == SHUT_RD) {
1485 				how = SHUT_RDWR;
1486 			} else {
1487 				how = SHUT_WR;
1488 			}
1489 		}
1490 	}
1491 
1492 	if (how != -1) {
1493 		FDLOG(LOG_INFO, fd_cb, "sending close, how = %d", how);
1494 		if (flow_divert_send_close(fd_cb, how) != ENOBUFS) {
1495 			/* Successfully sent the close packet. Record the ways in which the tunnel has been closed */
1496 			if (how != SHUT_RD) {
1497 				fd_cb->flags |= FLOW_DIVERT_TUNNEL_WR_CLOSED;
1498 			}
1499 			if (how != SHUT_WR) {
1500 				fd_cb->flags |= FLOW_DIVERT_TUNNEL_RD_CLOSED;
1501 			}
1502 		}
1503 	}
1504 
1505 	if (flow_divert_tunnel_how_closed(fd_cb) == SHUT_RDWR) {
1506 		flow_divert_disconnect_socket(fd_cb->so, !(fd_cb->flags & FLOW_DIVERT_IMPLICIT_CONNECT));
1507 	}
1508 }
1509 
1510 static errno_t
flow_divert_send_data_packet(struct flow_divert_pcb * fd_cb,mbuf_t data,size_t data_len,struct sockaddr * toaddr,Boolean force)1511 flow_divert_send_data_packet(struct flow_divert_pcb *fd_cb, mbuf_t data, size_t data_len, struct sockaddr *toaddr, Boolean force)
1512 {
1513 	mbuf_t  packet = NULL;
1514 	mbuf_t  last = NULL;
1515 	int             error   = 0;
1516 
1517 	error = flow_divert_packet_init(fd_cb, FLOW_DIVERT_PKT_DATA, &packet);
1518 	if (error || packet == NULL) {
1519 		FDLOG(LOG_ERR, fd_cb, "flow_divert_packet_init failed: %d", error);
1520 		goto done;
1521 	}
1522 
1523 	if (toaddr != NULL) {
1524 		error = flow_divert_append_target_endpoint_tlv(packet, toaddr);
1525 		if (error) {
1526 			FDLOG(LOG_ERR, fd_cb, "flow_divert_append_target_endpoint_tlv() failed: %d", error);
1527 			goto done;
1528 		}
1529 	}
1530 
1531 	if (data_len > 0 && data_len <= INT_MAX && data != NULL) {
1532 		last = m_last(packet);
1533 		mbuf_setnext(last, data);
1534 		mbuf_pkthdr_adjustlen(packet, (int)data_len);
1535 	} else {
1536 		data_len = 0;
1537 	}
1538 	error = flow_divert_send_packet(fd_cb, packet, force);
1539 	if (error == 0 && data_len > 0) {
1540 		fd_cb->bytes_sent += data_len;
1541 		flow_divert_add_data_statistics(fd_cb, data_len, TRUE);
1542 	}
1543 
1544 done:
1545 	if (error) {
1546 		if (last != NULL) {
1547 			mbuf_setnext(last, NULL);
1548 		}
1549 		if (packet != NULL) {
1550 			mbuf_freem(packet);
1551 		}
1552 	}
1553 
1554 	return error;
1555 }
1556 
1557 static void
flow_divert_send_buffered_data(struct flow_divert_pcb * fd_cb,Boolean force)1558 flow_divert_send_buffered_data(struct flow_divert_pcb *fd_cb, Boolean force)
1559 {
1560 	size_t  to_send;
1561 	size_t  sent    = 0;
1562 	int             error   = 0;
1563 	mbuf_t  buffer;
1564 
1565 	to_send = fd_cb->so->so_snd.sb_cc;
1566 	buffer = fd_cb->so->so_snd.sb_mb;
1567 
1568 	if (buffer == NULL && to_send > 0) {
1569 		FDLOG(LOG_ERR, fd_cb, "Send buffer is NULL, but size is supposed to be %lu", to_send);
1570 		return;
1571 	}
1572 
1573 	/* Ignore the send window if force is enabled */
1574 	if (!force && (to_send > fd_cb->send_window)) {
1575 		to_send = fd_cb->send_window;
1576 	}
1577 
1578 	if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
1579 		while (sent < to_send) {
1580 			mbuf_t  data;
1581 			size_t  data_len;
1582 
1583 			data_len = to_send - sent;
1584 			if (data_len > FLOW_DIVERT_CHUNK_SIZE) {
1585 				data_len = FLOW_DIVERT_CHUNK_SIZE;
1586 			}
1587 
1588 			error = mbuf_copym(buffer, sent, data_len, MBUF_DONTWAIT, &data);
1589 			if (error) {
1590 				FDLOG(LOG_ERR, fd_cb, "mbuf_copym failed: %d", error);
1591 				break;
1592 			}
1593 
1594 			error = flow_divert_send_data_packet(fd_cb, data, data_len, NULL, force);
1595 			if (error) {
1596 				if (data != NULL) {
1597 					mbuf_freem(data);
1598 				}
1599 				break;
1600 			}
1601 
1602 			sent += data_len;
1603 		}
1604 		sbdrop(&fd_cb->so->so_snd, (int)sent);
1605 		sowwakeup(fd_cb->so);
1606 	} else if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {
1607 		mbuf_t data;
1608 		mbuf_t m;
1609 		size_t data_len;
1610 
1611 		while (buffer) {
1612 			struct sockaddr *toaddr = flow_divert_get_buffered_target_address(buffer);
1613 
1614 			m = buffer;
1615 			if (toaddr != NULL) {
1616 				/* look for data in the chain */
1617 				do {
1618 					m = m->m_next;
1619 					if (m != NULL && m->m_type == MT_DATA) {
1620 						break;
1621 					}
1622 				} while (m);
1623 				if (m == NULL) {
1624 					/* unexpected */
1625 					FDLOG0(LOG_ERR, fd_cb, "failed to find type MT_DATA in the mbuf chain.");
1626 					goto move_on;
1627 				}
1628 			}
1629 			data_len = mbuf_pkthdr_len(m);
1630 			if (data_len > 0) {
1631 				FDLOG(LOG_DEBUG, fd_cb, "mbuf_copym() data_len = %lu", data_len);
1632 				error = mbuf_copym(m, 0, data_len, MBUF_DONTWAIT, &data);
1633 				if (error) {
1634 					FDLOG(LOG_ERR, fd_cb, "mbuf_copym failed: %d", error);
1635 					break;
1636 				}
1637 			} else {
1638 				data = NULL;
1639 			}
1640 			error = flow_divert_send_data_packet(fd_cb, data, data_len, toaddr, force);
1641 			if (error) {
1642 				if (data != NULL) {
1643 					mbuf_freem(data);
1644 				}
1645 				break;
1646 			}
1647 			sent += data_len;
1648 move_on:
1649 			buffer = buffer->m_nextpkt;
1650 			(void) sbdroprecord(&(fd_cb->so->so_snd));
1651 		}
1652 	}
1653 
1654 	if (sent > 0) {
1655 		FDLOG(LOG_DEBUG, fd_cb, "sent %lu bytes of buffered data", sent);
1656 		if (fd_cb->send_window >= sent) {
1657 			fd_cb->send_window -= sent;
1658 		} else {
1659 			fd_cb->send_window = 0;
1660 		}
1661 	}
1662 }
1663 
1664 static int
flow_divert_send_app_data(struct flow_divert_pcb * fd_cb,mbuf_t data,struct sockaddr * toaddr)1665 flow_divert_send_app_data(struct flow_divert_pcb *fd_cb, mbuf_t data, struct sockaddr *toaddr)
1666 {
1667 	size_t  to_send         = mbuf_pkthdr_len(data);
1668 	int     error           = 0;
1669 
1670 	if (to_send > fd_cb->send_window) {
1671 		to_send = fd_cb->send_window;
1672 	}
1673 
1674 	if (fd_cb->so->so_snd.sb_cc > 0) {
1675 		to_send = 0;    /* If the send buffer is non-empty, then we can't send anything */
1676 	}
1677 
1678 	if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
1679 		size_t  sent            = 0;
1680 		mbuf_t  remaining_data  = data;
1681 		mbuf_t  pkt_data        = NULL;
1682 		while (sent < to_send && remaining_data != NULL) {
1683 			size_t  pkt_data_len;
1684 
1685 			pkt_data = remaining_data;
1686 
1687 			if ((to_send - sent) > FLOW_DIVERT_CHUNK_SIZE) {
1688 				pkt_data_len = FLOW_DIVERT_CHUNK_SIZE;
1689 			} else {
1690 				pkt_data_len = to_send - sent;
1691 			}
1692 
1693 			if (pkt_data_len < mbuf_pkthdr_len(pkt_data)) {
1694 				error = mbuf_split(pkt_data, pkt_data_len, MBUF_DONTWAIT, &remaining_data);
1695 				if (error) {
1696 					FDLOG(LOG_ERR, fd_cb, "mbuf_split failed: %d", error);
1697 					pkt_data = NULL;
1698 					break;
1699 				}
1700 			} else {
1701 				remaining_data = NULL;
1702 			}
1703 
1704 			error = flow_divert_send_data_packet(fd_cb, pkt_data, pkt_data_len, NULL, FALSE);
1705 
1706 			if (error) {
1707 				break;
1708 			}
1709 
1710 			pkt_data = NULL;
1711 			sent += pkt_data_len;
1712 		}
1713 
1714 		fd_cb->send_window -= sent;
1715 
1716 		error = 0;
1717 
1718 		if (pkt_data != NULL) {
1719 			if (sbspace(&fd_cb->so->so_snd) > 0) {
1720 				if (!sbappendstream(&fd_cb->so->so_snd, pkt_data)) {
1721 					FDLOG(LOG_ERR, fd_cb, "sbappendstream failed with pkt_data, send buffer size = %u, send_window = %u\n",
1722 					    fd_cb->so->so_snd.sb_cc, fd_cb->send_window);
1723 				}
1724 			} else {
1725 				mbuf_freem(pkt_data);
1726 				error = ENOBUFS;
1727 			}
1728 		}
1729 
1730 		if (remaining_data != NULL) {
1731 			if (sbspace(&fd_cb->so->so_snd) > 0) {
1732 				if (!sbappendstream(&fd_cb->so->so_snd, remaining_data)) {
1733 					FDLOG(LOG_ERR, fd_cb, "sbappendstream failed with remaining_data, send buffer size = %u, send_window = %u\n",
1734 					    fd_cb->so->so_snd.sb_cc, fd_cb->send_window);
1735 				}
1736 			} else {
1737 				mbuf_freem(remaining_data);
1738 				error = ENOBUFS;
1739 			}
1740 		}
1741 	} else if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {
1742 		if (to_send || mbuf_pkthdr_len(data) == 0) {
1743 			error = flow_divert_send_data_packet(fd_cb, data, to_send, toaddr, FALSE);
1744 			if (error) {
1745 				FDLOG(LOG_ERR, fd_cb, "flow_divert_send_data_packet failed. send data size = %lu", to_send);
1746 				if (data != NULL) {
1747 					mbuf_freem(data);
1748 				}
1749 			} else {
1750 				fd_cb->send_window -= to_send;
1751 			}
1752 		} else {
1753 			/* buffer it */
1754 			if (sbspace(&fd_cb->so->so_snd) >= (int)mbuf_pkthdr_len(data)) {
1755 				if (toaddr != NULL) {
1756 					if (!sbappendaddr(&fd_cb->so->so_snd, toaddr, data, NULL, &error)) {
1757 						FDLOG(LOG_ERR, fd_cb,
1758 						    "sbappendaddr failed. send buffer size = %u, send_window = %u, error = %d\n",
1759 						    fd_cb->so->so_snd.sb_cc, fd_cb->send_window, error);
1760 					}
1761 					error = 0;
1762 				} else {
1763 					if (!sbappendrecord(&fd_cb->so->so_snd, data)) {
1764 						FDLOG(LOG_ERR, fd_cb,
1765 						    "sbappendrecord failed. send buffer size = %u, send_window = %u, error = %d\n",
1766 						    fd_cb->so->so_snd.sb_cc, fd_cb->send_window, error);
1767 					}
1768 				}
1769 			} else {
1770 				if (data != NULL) {
1771 					mbuf_freem(data);
1772 				}
1773 				error = ENOBUFS;
1774 			}
1775 		}
1776 	}
1777 
1778 	return error;
1779 }
1780 
1781 static int
flow_divert_send_read_notification(struct flow_divert_pcb * fd_cb)1782 flow_divert_send_read_notification(struct flow_divert_pcb *fd_cb)
1783 {
1784 	int error = 0;
1785 	mbuf_t packet = NULL;
1786 
1787 	error = flow_divert_packet_init(fd_cb, FLOW_DIVERT_PKT_READ_NOTIFY, &packet);
1788 	if (error) {
1789 		FDLOG(LOG_ERR, fd_cb, "failed to create a read notification packet: %d", error);
1790 		goto done;
1791 	}
1792 
1793 	error = flow_divert_send_packet(fd_cb, packet, TRUE);
1794 	if (error) {
1795 		goto done;
1796 	}
1797 
1798 done:
1799 	if (error && packet != NULL) {
1800 		mbuf_free(packet);
1801 	}
1802 
1803 	return error;
1804 }
1805 
1806 static int
flow_divert_send_traffic_class_update(struct flow_divert_pcb * fd_cb,int traffic_class)1807 flow_divert_send_traffic_class_update(struct flow_divert_pcb *fd_cb, int traffic_class)
1808 {
1809 	int             error           = 0;
1810 	mbuf_t  packet          = NULL;
1811 
1812 	error = flow_divert_packet_init(fd_cb, FLOW_DIVERT_PKT_PROPERTIES_UPDATE, &packet);
1813 	if (error) {
1814 		FDLOG(LOG_ERR, fd_cb, "failed to create a properties update packet: %d", error);
1815 		goto done;
1816 	}
1817 
1818 	error = flow_divert_packet_append_tlv(packet, FLOW_DIVERT_TLV_TRAFFIC_CLASS, sizeof(traffic_class), &traffic_class);
1819 	if (error) {
1820 		FDLOG(LOG_ERR, fd_cb, "failed to add the traffic class: %d", error);
1821 		goto done;
1822 	}
1823 
1824 	error = flow_divert_send_packet(fd_cb, packet, TRUE);
1825 	if (error) {
1826 		goto done;
1827 	}
1828 
1829 done:
1830 	if (error && packet != NULL) {
1831 		mbuf_free(packet);
1832 	}
1833 
1834 	return error;
1835 }
1836 
1837 static void
flow_divert_set_local_endpoint(struct flow_divert_pcb * fd_cb,struct sockaddr * local_endpoint)1838 flow_divert_set_local_endpoint(struct flow_divert_pcb *fd_cb, struct sockaddr *local_endpoint)
1839 {
1840 	struct inpcb *inp = sotoinpcb(fd_cb->so);
1841 
1842 	if (local_endpoint->sa_family == AF_INET6) {
1843 		if (IN6_IS_ADDR_UNSPECIFIED(&inp->in6p_laddr) && (fd_cb->flags & FLOW_DIVERT_SHOULD_SET_LOCAL_ADDR)) {
1844 			fd_cb->flags |= FLOW_DIVERT_DID_SET_LOCAL_ADDR;
1845 			inp->in6p_laddr = (satosin6(local_endpoint))->sin6_addr;
1846 			inp->inp_lifscope = (satosin6(local_endpoint))->sin6_scope_id;
1847 			in6_verify_ifscope(&inp->in6p_laddr, inp->inp_lifscope);
1848 		}
1849 		if (inp->inp_lport == 0) {
1850 			inp->inp_lport = (satosin6(local_endpoint))->sin6_port;
1851 		}
1852 	} else if (local_endpoint->sa_family == AF_INET) {
1853 		if (inp->inp_laddr.s_addr == INADDR_ANY && (fd_cb->flags & FLOW_DIVERT_SHOULD_SET_LOCAL_ADDR)) {
1854 			fd_cb->flags |= FLOW_DIVERT_DID_SET_LOCAL_ADDR;
1855 			inp->inp_laddr = (satosin(local_endpoint))->sin_addr;
1856 		}
1857 		if (inp->inp_lport == 0) {
1858 			inp->inp_lport = (satosin(local_endpoint))->sin_port;
1859 		}
1860 	}
1861 }
1862 
1863 static void
flow_divert_set_remote_endpoint(struct flow_divert_pcb * fd_cb,struct sockaddr * remote_endpoint)1864 flow_divert_set_remote_endpoint(struct flow_divert_pcb *fd_cb, struct sockaddr *remote_endpoint)
1865 {
1866 	struct inpcb *inp = sotoinpcb(fd_cb->so);
1867 
1868 	if (remote_endpoint->sa_family == AF_INET6) {
1869 		if (IN6_IS_ADDR_UNSPECIFIED(&inp->in6p_faddr)) {
1870 			inp->in6p_faddr = (satosin6(remote_endpoint))->sin6_addr;
1871 			inp->inp_fifscope = (satosin6(remote_endpoint))->sin6_scope_id;
1872 			in6_verify_ifscope(&inp->in6p_faddr, inp->inp_fifscope);
1873 		}
1874 		if (inp->inp_fport == 0) {
1875 			inp->inp_fport = (satosin6(remote_endpoint))->sin6_port;
1876 		}
1877 	} else if (remote_endpoint->sa_family == AF_INET) {
1878 		if (inp->inp_faddr.s_addr == INADDR_ANY) {
1879 			inp->inp_faddr = (satosin(remote_endpoint))->sin_addr;
1880 		}
1881 		if (inp->inp_fport == 0) {
1882 			inp->inp_fport = (satosin(remote_endpoint))->sin_port;
1883 		}
1884 	}
1885 }
1886 
/*
 * Pick the next kernel control unit to try.
 *
 * If aggregate_unit is non-NULL and non-zero, the lowest set bit among bit
 * positions 0 .. GROUP_COUNT_MAX - 2 is cleared from *aggregate_unit,
 * *is_aggregate is set to true, and that bit position plus one is returned.
 * In every other case *aggregate_unit is left untouched, *is_aggregate is
 * false, and ctl_unit is returned.
 *
 * is_aggregate must not be NULL; it is always written.
 */
static uint32_t
flow_divert_derive_kernel_control_unit(uint32_t ctl_unit, uint32_t *aggregate_unit, bool *is_aggregate)
{
	*is_aggregate = false;

	if (aggregate_unit == NULL || *aggregate_unit == 0) {
		return ctl_unit;
	}

	for (uint32_t bit = 0; bit < (GROUP_COUNT_MAX - 1); bit++) {
		const uint32_t mask = ((uint32_t)1 << bit);

		if ((*aggregate_unit & mask) != 0) {
			/* Consume this bit so the next call moves on to the following one */
			*aggregate_unit &= ~mask;
			*is_aggregate = true;
			return bit + 1;
		}
	}

	/* Only bits outside the searched range were set */
	return ctl_unit;
}
1909 
1910 static int
flow_divert_try_next(struct flow_divert_pcb * fd_cb)1911 flow_divert_try_next(struct flow_divert_pcb *fd_cb)
1912 {
1913 	uint32_t current_ctl_unit = 0;
1914 	uint32_t next_ctl_unit = 0;
1915 	struct flow_divert_group *current_group = NULL;
1916 	struct flow_divert_group *next_group = NULL;
1917 	int error = 0;
1918 	bool is_aggregate = false;
1919 
1920 	next_ctl_unit = flow_divert_derive_kernel_control_unit(fd_cb->policy_control_unit, &(fd_cb->aggregate_unit), &is_aggregate);
1921 	current_ctl_unit = fd_cb->control_group_unit;
1922 
1923 	if (current_ctl_unit == next_ctl_unit) {
1924 		FDLOG0(LOG_NOTICE, fd_cb, "Next control unit is the same as the current control unit, disabling flow divert");
1925 		error = EALREADY;
1926 		goto done;
1927 	}
1928 
1929 	if (next_ctl_unit == 0 || next_ctl_unit >= GROUP_COUNT_MAX) {
1930 		FDLOG0(LOG_NOTICE, fd_cb, "No more valid control units, disabling flow divert");
1931 		error = ENOENT;
1932 		goto done;
1933 	}
1934 
1935 	if (g_flow_divert_groups == NULL || g_active_group_count == 0) {
1936 		FDLOG0(LOG_NOTICE, fd_cb, "No active groups, disabling flow divert");
1937 		error = ENOENT;
1938 		goto done;
1939 	}
1940 
1941 	next_group = g_flow_divert_groups[next_ctl_unit];
1942 	if (next_group == NULL) {
1943 		FDLOG(LOG_NOTICE, fd_cb, "Group for control unit %u does not exist", next_ctl_unit);
1944 		error = ENOENT;
1945 		goto done;
1946 	}
1947 
1948 	current_group = fd_cb->group;
1949 
1950 	lck_rw_lock_exclusive(&(current_group->lck));
1951 	lck_rw_lock_exclusive(&(next_group->lck));
1952 
1953 	FDLOG(LOG_NOTICE, fd_cb, "Moving from %u to %u", current_ctl_unit, next_ctl_unit);
1954 
1955 	RB_REMOVE(fd_pcb_tree, &(current_group->pcb_tree), fd_cb);
1956 	if (RB_INSERT(fd_pcb_tree, &(next_group->pcb_tree), fd_cb) != NULL) {
1957 		panic("group with unit %u already contains a connection with hash %u", next_ctl_unit, fd_cb->hash);
1958 	}
1959 
1960 	fd_cb->group = next_group;
1961 	fd_cb->control_group_unit = next_ctl_unit;
1962 	if (is_aggregate) {
1963 		fd_cb->flags |= FLOW_DIVERT_FLOW_IS_TRANSPARENT;
1964 	} else {
1965 		fd_cb->flags &= ~FLOW_DIVERT_FLOW_IS_TRANSPARENT;
1966 	}
1967 
1968 	lck_rw_done(&(next_group->lck));
1969 	lck_rw_done(&(current_group->lck));
1970 
1971 	error = flow_divert_send_connect_packet(fd_cb);
1972 	if (error) {
1973 		FDLOG(LOG_NOTICE, fd_cb, "Failed to send the connect packet to %u, disabling flow divert", next_ctl_unit);
1974 		error = ENOENT;
1975 		goto done;
1976 	}
1977 
1978 done:
1979 	return error;
1980 }
1981 
1982 static void
flow_divert_disable(struct flow_divert_pcb * fd_cb)1983 flow_divert_disable(struct flow_divert_pcb *fd_cb)
1984 {
1985 	struct socket *so = NULL;
1986 	mbuf_t  buffer;
1987 	int error = 0;
1988 	proc_t last_proc = NULL;
1989 	struct sockaddr *remote_endpoint = fd_cb->original_remote_endpoint;
1990 	bool do_connect = !(fd_cb->flags & FLOW_DIVERT_IMPLICIT_CONNECT);
1991 	bool unset_connecting = false;
1992 	struct inpcb *inp = NULL;
1993 
1994 	so = fd_cb->so;
1995 	if (so == NULL) {
1996 		goto done;
1997 	}
1998 
1999 	FDLOG0(LOG_NOTICE, fd_cb, "Skipped all flow divert services, disabling flow divert");
2000 
2001 	unset_connecting = (SOCK_TYPE(so) == SOCK_STREAM || (fd_cb->flags & FLOW_DIVERT_HAS_TOKEN));
2002 
2003 	/* Restore the IP state */
2004 	inp = sotoinpcb(so);
2005 	inp->inp_vflag = fd_cb->original_vflag;
2006 	inp->inp_faddr.s_addr = INADDR_ANY;
2007 	inp->inp_fport = 0;
2008 	memset(&(inp->in6p_faddr), 0, sizeof(inp->in6p_faddr));
2009 	inp->inp_fifscope = IFSCOPE_NONE;
2010 	inp->in6p_fport = 0;
2011 	/* If flow divert set the local address, clear it out */
2012 	if (fd_cb->flags & FLOW_DIVERT_DID_SET_LOCAL_ADDR) {
2013 		inp->inp_laddr.s_addr = INADDR_ANY;
2014 		memset(&(inp->in6p_laddr), 0, sizeof(inp->in6p_laddr));
2015 		inp->inp_lifscope = IFSCOPE_NONE;
2016 	}
2017 	inp->inp_last_outifp = fd_cb->original_last_outifp;
2018 	inp->in6p_last_outifp = fd_cb->original_last_outifp6;
2019 
2020 	/* Dis-associate the socket */
2021 	so->so_flags &= ~SOF_FLOW_DIVERT;
2022 	so->so_flags1 |= SOF1_FLOW_DIVERT_SKIP;
2023 	so->so_fd_pcb = NULL;
2024 	fd_cb->so = NULL;
2025 
2026 	/* Remove from the group */
2027 	flow_divert_pcb_remove(fd_cb);
2028 
2029 	FDRELEASE(fd_cb); /* Release the socket's reference */
2030 
2031 	/* Revert back to the original protocol */
2032 	so->so_proto = pffindproto(SOCK_DOM(so), SOCK_PROTO(so), SOCK_TYPE(so));
2033 
2034 	last_proc = proc_find(so->last_pid);
2035 
2036 	if (do_connect) {
2037 		if (unset_connecting) {
2038 			so->so_state &= ~SS_ISCONNECTING; /* Get out of the connecting state to avoid confusing NECP */
2039 		}
2040 		/* Connect using the original protocol */
2041 		error = (*so->so_proto->pr_usrreqs->pru_connect)(so, remote_endpoint, (last_proc != NULL ? last_proc : current_proc()));
2042 		if (error) {
2043 			FDLOG(LOG_ERR, fd_cb, "Failed to connect using the socket's original protocol: %d", error);
2044 			goto done;
2045 		}
2046 	}
2047 
2048 	buffer = so->so_snd.sb_mb;
2049 	if (buffer == NULL) {
2050 		/* No buffered data, done */
2051 		goto done;
2052 	}
2053 
2054 	/* Send any buffered data using the original protocol */
2055 	if (SOCK_TYPE(so) == SOCK_STREAM) {
2056 		mbuf_t data_to_send = NULL;
2057 		size_t data_len = so->so_snd.sb_cc;
2058 
2059 		error = mbuf_copym(buffer, 0, data_len, MBUF_DONTWAIT, &data_to_send);
2060 		if (error) {
2061 			FDLOG0(LOG_ERR, fd_cb, "Failed to copy the mbuf chain in the socket's send buffer");
2062 			goto done;
2063 		}
2064 
2065 		sbflush(&so->so_snd);
2066 
2067 		if (data_to_send->m_flags & M_PKTHDR) {
2068 			mbuf_pkthdr_setlen(data_to_send, data_len);
2069 		}
2070 
2071 		error = (*so->so_proto->pr_usrreqs->pru_send)(so,
2072 		    0,
2073 		    data_to_send,
2074 		    NULL,
2075 		    NULL,
2076 		    (last_proc != NULL ? last_proc : current_proc()));
2077 
2078 		if (error && error != EWOULDBLOCK) {
2079 			FDLOG(LOG_ERR, fd_cb, "Failed to send queued data using the socket's original protocol: %d", error);
2080 		} else {
2081 			error = 0;
2082 		}
2083 	} else if (SOCK_TYPE(so) == SOCK_DGRAM) {
2084 		struct sockbuf *sb = &so->so_snd;
2085 		MBUFQ_HEAD(send_queue_head) send_queue;
2086 		MBUFQ_INIT(&send_queue);
2087 
2088 		/* Flush the send buffer, moving all records to a temporary queue */
2089 		while (sb->sb_mb != NULL) {
2090 			mbuf_t record = sb->sb_mb;
2091 			mbuf_t m = record;
2092 			sb->sb_mb = sb->sb_mb->m_nextpkt;
2093 			while (m != NULL) {
2094 				sbfree(sb, m);
2095 				m = m->m_next;
2096 			}
2097 			record->m_nextpkt = NULL;
2098 			MBUFQ_ENQUEUE(&send_queue, record);
2099 		}
2100 		SB_EMPTY_FIXUP(sb);
2101 
2102 		while (!MBUFQ_EMPTY(&send_queue)) {
2103 			mbuf_t next_record = MBUFQ_FIRST(&send_queue);
2104 			mbuf_t addr = NULL;
2105 			mbuf_t control = NULL;
2106 			mbuf_t last_control = NULL;
2107 			mbuf_t data = NULL;
2108 			mbuf_t m = next_record;
2109 			struct sockaddr *to_endpoint = NULL;
2110 
2111 			MBUFQ_DEQUEUE(&send_queue, next_record);
2112 
2113 			while (m != NULL) {
2114 				if (m->m_type == MT_SONAME) {
2115 					addr = m;
2116 				} else if (m->m_type == MT_CONTROL) {
2117 					if (control == NULL) {
2118 						control = m;
2119 					}
2120 					last_control = m;
2121 				} else if (m->m_type == MT_DATA) {
2122 					data = m;
2123 					break;
2124 				}
2125 				m = m->m_next;
2126 			}
2127 
2128 			if (addr != NULL && !do_connect) {
2129 				to_endpoint = flow_divert_get_buffered_target_address(addr);
2130 				if (to_endpoint == NULL) {
2131 					FDLOG0(LOG_NOTICE, fd_cb, "Failed to get the remote address from the buffer");
2132 				}
2133 			}
2134 
2135 			if (data == NULL) {
2136 				FDLOG0(LOG_ERR, fd_cb, "Buffered record does not contain any data");
2137 				mbuf_freem(next_record);
2138 				continue;
2139 			}
2140 
2141 			if (!(data->m_flags & M_PKTHDR)) {
2142 				FDLOG0(LOG_ERR, fd_cb, "Buffered data does not have a packet header");
2143 				mbuf_freem(next_record);
2144 				continue;
2145 			}
2146 
2147 			if (addr != NULL) {
2148 				addr->m_next = NULL;
2149 			}
2150 
2151 			if (last_control != NULL) {
2152 				last_control->m_next = NULL;
2153 			}
2154 
2155 			error = (*so->so_proto->pr_usrreqs->pru_send)(so,
2156 			    0,
2157 			    data,
2158 			    to_endpoint,
2159 			    control,
2160 			    (last_proc != NULL ? last_proc : current_proc()));
2161 
2162 			if (addr != NULL) {
2163 				mbuf_freem(addr);
2164 			}
2165 
2166 			if (error) {
2167 				FDLOG(LOG_ERR, fd_cb, "Failed to send queued data using the socket's original protocol: %d", error);
2168 			}
2169 		}
2170 	}
2171 done:
2172 	if (last_proc != NULL) {
2173 		proc_rele(last_proc);
2174 	}
2175 
2176 	if (error && so != NULL) {
2177 		so->so_error = (uint16_t)error;
2178 		flow_divert_disconnect_socket(so, do_connect);
2179 	}
2180 }
2181 
2182 static void
flow_divert_scope(struct flow_divert_pcb * fd_cb,int out_if_index,bool derive_new_address)2183 flow_divert_scope(struct flow_divert_pcb *fd_cb, int out_if_index, bool derive_new_address)
2184 {
2185 	struct socket *so = NULL;
2186 	struct inpcb *inp = NULL;
2187 	struct ifnet *current_ifp = NULL;
2188 	struct ifnet *new_ifp = NULL;
2189 	int error = 0;
2190 
2191 	so = fd_cb->so;
2192 	if (so == NULL) {
2193 		return;
2194 	}
2195 
2196 	inp = sotoinpcb(so);
2197 
2198 	if (out_if_index <= 0) {
2199 		return;
2200 	}
2201 
2202 	if (inp->inp_vflag & INP_IPV6) {
2203 		current_ifp = inp->in6p_last_outifp;
2204 	} else {
2205 		current_ifp = inp->inp_last_outifp;
2206 	}
2207 
2208 	if (current_ifp != NULL) {
2209 		if (current_ifp->if_index == out_if_index) {
2210 			/* No change */
2211 			return;
2212 		}
2213 
2214 		/* Scope the socket to the given interface */
2215 		error = inp_bindif(inp, out_if_index, &new_ifp);
2216 		if (error != 0) {
2217 			FDLOG(LOG_ERR, fd_cb, "failed to scope to %d because inp_bindif returned %d", out_if_index, error);
2218 			return;
2219 		}
2220 
2221 		if (derive_new_address && fd_cb->original_remote_endpoint != NULL) {
2222 			/* Get the appropriate address for the given interface */
2223 			if (inp->inp_vflag & INP_IPV6) {
2224 				inp->in6p_laddr = sa6_any.sin6_addr;
2225 				error = in6_pcbladdr(inp, fd_cb->original_remote_endpoint, &(fd_cb->local_endpoint.sin6.sin6_addr), NULL);
2226 			} else {
2227 				inp->inp_laddr.s_addr = INADDR_ANY;
2228 				error = in_pcbladdr(inp, fd_cb->original_remote_endpoint, &(fd_cb->local_endpoint.sin.sin_addr), IFSCOPE_NONE, NULL, 0);
2229 			}
2230 
2231 			if (error != 0) {
2232 				FDLOG(LOG_WARNING, fd_cb, "failed to derive a new local address from %d because in_pcbladdr returned %d", out_if_index, error);
2233 			}
2234 		}
2235 	} else {
2236 		ifnet_head_lock_shared();
2237 		if (out_if_index <= if_index) {
2238 			new_ifp = ifindex2ifnet[out_if_index];
2239 		}
2240 		ifnet_head_done();
2241 	}
2242 
2243 	/* Update the "last interface" of the socket */
2244 	if (new_ifp != NULL) {
2245 		if (inp->inp_vflag & INP_IPV6) {
2246 			inp->in6p_last_outifp = new_ifp;
2247 		} else {
2248 			inp->inp_last_outifp = new_ifp;
2249 		}
2250 
2251 #if SKYWALK
2252 		if (NETNS_TOKEN_VALID(&inp->inp_netns_token)) {
2253 			netns_set_ifnet(&inp->inp_netns_token, new_ifp);
2254 		}
2255 #endif /* SKYWALK */
2256 	}
2257 }
2258 
2259 static void
flow_divert_handle_connect_result(struct flow_divert_pcb * fd_cb,mbuf_t packet,int offset)2260 flow_divert_handle_connect_result(struct flow_divert_pcb *fd_cb, mbuf_t packet, int offset)
2261 {
2262 	uint32_t                                        connect_error = 0;
2263 	uint32_t                                        ctl_unit                        = 0;
2264 	int                                                     error                           = 0;
2265 	struct flow_divert_group        *grp                            = NULL;
2266 	union sockaddr_in_4_6 local_endpoint = {};
2267 	union sockaddr_in_4_6 remote_endpoint = {};
2268 	int                                                     out_if_index            = 0;
2269 	uint32_t                                        send_window;
2270 	uint32_t                                        app_data_length         = 0;
2271 
2272 	memset(&local_endpoint, 0, sizeof(local_endpoint));
2273 	memset(&remote_endpoint, 0, sizeof(remote_endpoint));
2274 
2275 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_ERROR_CODE, sizeof(connect_error), &connect_error, NULL);
2276 	if (error) {
2277 		FDLOG(LOG_ERR, fd_cb, "failed to get the connect result: %d", error);
2278 		return;
2279 	}
2280 
2281 	connect_error = ntohl(connect_error);
2282 	FDLOG(LOG_INFO, fd_cb, "received connect result %u", connect_error);
2283 
2284 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_SPACE_AVAILABLE, sizeof(send_window), &send_window, NULL);
2285 	if (error) {
2286 		FDLOG(LOG_ERR, fd_cb, "failed to get the send window: %d", error);
2287 		return;
2288 	}
2289 
2290 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_CTL_UNIT, sizeof(ctl_unit), &ctl_unit, NULL);
2291 	if (error) {
2292 		FDLOG0(LOG_INFO, fd_cb, "No control unit provided in the connect result");
2293 	}
2294 
2295 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_LOCAL_ADDR, sizeof(local_endpoint), &(local_endpoint.sa), NULL);
2296 	if (error) {
2297 		FDLOG0(LOG_INFO, fd_cb, "No local address provided");
2298 	}
2299 
2300 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_REMOTE_ADDR, sizeof(remote_endpoint), &(remote_endpoint.sa), NULL);
2301 	if (error) {
2302 		FDLOG0(LOG_INFO, fd_cb, "No remote address provided");
2303 	}
2304 
2305 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_OUT_IF_INDEX, sizeof(out_if_index), &out_if_index, NULL);
2306 	if (error) {
2307 		FDLOG0(LOG_INFO, fd_cb, "No output if index provided");
2308 	}
2309 
2310 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_APP_DATA, 0, NULL, &app_data_length);
2311 	if (error) {
2312 		FDLOG0(LOG_INFO, fd_cb, "No application data provided in connect result");
2313 	}
2314 
2315 	error = 0;
2316 	ctl_unit                = ntohl(ctl_unit);
2317 
2318 	lck_rw_lock_shared(&g_flow_divert_group_lck);
2319 
2320 	if (connect_error == 0 && ctl_unit > 0) {
2321 		if (ctl_unit >= GROUP_COUNT_MAX) {
2322 			FDLOG(LOG_ERR, fd_cb, "Connect result contains an invalid control unit: %u", ctl_unit);
2323 			error = EINVAL;
2324 		} else if (g_flow_divert_groups == NULL || g_active_group_count == 0) {
2325 			FDLOG0(LOG_ERR, fd_cb, "No active groups, dropping connection");
2326 			error = EINVAL;
2327 		} else {
2328 			grp = g_flow_divert_groups[ctl_unit];
2329 			if (grp == NULL) {
2330 				error = ECONNRESET;
2331 			}
2332 		}
2333 	}
2334 
2335 	FDLOCK(fd_cb);
2336 	if (fd_cb->so != NULL) {
2337 		struct inpcb                            *inp = NULL;
2338 		struct flow_divert_group        *old_group;
2339 		struct socket *so = fd_cb->so;
2340 		bool local_address_is_valid = false;
2341 
2342 		socket_lock(so, 0);
2343 
2344 		if (!(so->so_flags & SOF_FLOW_DIVERT)) {
2345 			FDLOG0(LOG_NOTICE, fd_cb, "socket is not attached any more, ignoring connect result");
2346 			goto done;
2347 		}
2348 
2349 		if (SOCK_TYPE(so) == SOCK_STREAM && !(so->so_state & SS_ISCONNECTING)) {
2350 			FDLOG0(LOG_ERR, fd_cb, "TCP socket is not in the connecting state, ignoring connect result");
2351 			goto done;
2352 		}
2353 
2354 		inp = sotoinpcb(so);
2355 
2356 		if (connect_error || error) {
2357 			goto set_socket_state;
2358 		}
2359 
2360 		if (flow_divert_is_sockaddr_valid(&(local_endpoint.sa))) {
2361 			if (local_endpoint.sa.sa_family == AF_INET) {
2362 				local_endpoint.sa.sa_len = sizeof(struct sockaddr_in);
2363 				if ((inp->inp_vflag & INP_IPV4) && local_endpoint.sin.sin_addr.s_addr != INADDR_ANY) {
2364 					local_address_is_valid = true;
2365 					fd_cb->local_endpoint = local_endpoint;
2366 					inp->inp_laddr.s_addr = INADDR_ANY;
2367 				} else {
2368 					fd_cb->local_endpoint.sin.sin_port = local_endpoint.sin.sin_port;
2369 				}
2370 			} else if (local_endpoint.sa.sa_family == AF_INET6) {
2371 				local_endpoint.sa.sa_len = sizeof(struct sockaddr_in6);
2372 				if ((inp->inp_vflag & INP_IPV6) && !IN6_IS_ADDR_UNSPECIFIED(&local_endpoint.sin6.sin6_addr)) {
2373 					local_address_is_valid = true;
2374 					fd_cb->local_endpoint = local_endpoint;
2375 					inp->in6p_laddr = sa6_any.sin6_addr;
2376 				} else {
2377 					fd_cb->local_endpoint.sin6.sin6_port = local_endpoint.sin6.sin6_port;
2378 				}
2379 			}
2380 		}
2381 
2382 		flow_divert_scope(fd_cb, out_if_index, !local_address_is_valid);
2383 		flow_divert_set_local_endpoint(fd_cb, &(fd_cb->local_endpoint.sa));
2384 
2385 		if (flow_divert_is_sockaddr_valid(&(remote_endpoint.sa)) && SOCK_TYPE(so) == SOCK_STREAM) {
2386 			if (remote_endpoint.sa.sa_family == AF_INET) {
2387 				remote_endpoint.sa.sa_len = sizeof(struct sockaddr_in);
2388 			} else if (remote_endpoint.sa.sa_family == AF_INET6) {
2389 				remote_endpoint.sa.sa_len = sizeof(struct sockaddr_in6);
2390 			}
2391 			flow_divert_set_remote_endpoint(fd_cb, &(remote_endpoint.sa));
2392 		}
2393 
2394 		if (app_data_length > 0) {
2395 			uint8_t *app_data = NULL;
2396 			app_data = kalloc_data(app_data_length, Z_WAITOK);
2397 			if (app_data != NULL) {
2398 				error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_APP_DATA, app_data_length, app_data, NULL);
2399 				if (error == 0) {
2400 					FDLOG(LOG_INFO, fd_cb, "Got %u bytes of app data from the connect result", app_data_length);
2401 					if (fd_cb->app_data != NULL) {
2402 						kfree_data(fd_cb->app_data, fd_cb->app_data_length);
2403 					}
2404 					fd_cb->app_data = app_data;
2405 					fd_cb->app_data_length = app_data_length;
2406 				} else {
2407 					FDLOG(LOG_ERR, fd_cb, "Failed to copy %u bytes of application data from the connect result packet", app_data_length);
2408 					kfree_data(app_data, app_data_length);
2409 				}
2410 			} else {
2411 				FDLOG(LOG_ERR, fd_cb, "Failed to allocate a buffer of size %u to hold the application data from the connect result", app_data_length);
2412 			}
2413 		}
2414 
2415 		if (error) {
2416 			goto set_socket_state;
2417 		}
2418 
2419 		if (fd_cb->group == NULL) {
2420 			error = EINVAL;
2421 			goto set_socket_state;
2422 		}
2423 
2424 		if (grp != NULL) {
2425 			old_group = fd_cb->group;
2426 
2427 			lck_rw_lock_exclusive(&old_group->lck);
2428 			lck_rw_lock_exclusive(&grp->lck);
2429 
2430 			RB_REMOVE(fd_pcb_tree, &old_group->pcb_tree, fd_cb);
2431 			if (RB_INSERT(fd_pcb_tree, &grp->pcb_tree, fd_cb) != NULL) {
2432 				panic("group with unit %u already contains a connection with hash %u", grp->ctl_unit, fd_cb->hash);
2433 			}
2434 
2435 			fd_cb->group = grp;
2436 
2437 			lck_rw_done(&grp->lck);
2438 			lck_rw_done(&old_group->lck);
2439 		}
2440 
2441 		fd_cb->send_window = ntohl(send_window);
2442 
2443 set_socket_state:
2444 		if (!connect_error && !error) {
2445 			FDLOG0(LOG_INFO, fd_cb, "sending connect result");
2446 			error = flow_divert_send_connect_result(fd_cb);
2447 		}
2448 
2449 		if (connect_error || error) {
2450 			if (connect_error && fd_cb->control_group_unit != fd_cb->policy_control_unit) {
2451 				error = flow_divert_try_next(fd_cb);
2452 				if (error) {
2453 					flow_divert_disable(fd_cb);
2454 				}
2455 				goto done;
2456 			}
2457 
2458 			if (!connect_error) {
2459 				flow_divert_update_closed_state(fd_cb, SHUT_RDWR, FALSE);
2460 				so->so_error = (uint16_t)error;
2461 				flow_divert_send_close_if_needed(fd_cb);
2462 			} else {
2463 				flow_divert_update_closed_state(fd_cb, SHUT_RDWR, TRUE);
2464 				so->so_error = (uint16_t)connect_error;
2465 			}
2466 			flow_divert_disconnect_socket(so, !(fd_cb->flags & FLOW_DIVERT_IMPLICIT_CONNECT));
2467 		} else {
2468 #if NECP
2469 			/* Update NECP client with connected five-tuple */
2470 			if (!uuid_is_null(inp->necp_client_uuid)) {
2471 				socket_unlock(so, 0);
2472 				necp_client_assign_from_socket(so->last_pid, inp->necp_client_uuid, inp);
2473 				socket_lock(so, 0);
2474 			}
2475 #endif /* NECP */
2476 
2477 			flow_divert_send_buffered_data(fd_cb, FALSE);
2478 			soisconnected(so);
2479 		}
2480 
2481 		/* We don't need the connect packet any more */
2482 		if (fd_cb->connect_packet != NULL) {
2483 			mbuf_freem(fd_cb->connect_packet);
2484 			fd_cb->connect_packet = NULL;
2485 		}
2486 
2487 		/* We don't need the original remote endpoint any more */
2488 		free_sockaddr(fd_cb->original_remote_endpoint);
2489 done:
2490 		socket_unlock(so, 0);
2491 	}
2492 	FDUNLOCK(fd_cb);
2493 
2494 	lck_rw_done(&g_flow_divert_group_lck);
2495 }
2496 
2497 static void
flow_divert_handle_close(struct flow_divert_pcb * fd_cb,mbuf_t packet,int offset)2498 flow_divert_handle_close(struct flow_divert_pcb *fd_cb, mbuf_t packet, int offset)
2499 {
2500 	uint32_t        close_error                     = 0;
2501 	int                     error                   = 0;
2502 	int                     how                     = 0;
2503 
2504 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_ERROR_CODE, sizeof(close_error), &close_error, NULL);
2505 	if (error) {
2506 		FDLOG(LOG_ERR, fd_cb, "failed to get the close error: %d", error);
2507 		return;
2508 	}
2509 
2510 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_HOW, sizeof(how), &how, NULL);
2511 	if (error) {
2512 		FDLOG(LOG_ERR, fd_cb, "failed to get the close how flag: %d", error);
2513 		return;
2514 	}
2515 
2516 	how = ntohl(how);
2517 
2518 	FDLOG(LOG_INFO, fd_cb, "close received, how = %d", how);
2519 
2520 	FDLOCK(fd_cb);
2521 	if (fd_cb->so != NULL) {
2522 		bool is_connected = (SOCK_TYPE(fd_cb->so) == SOCK_STREAM || !(fd_cb->flags & FLOW_DIVERT_IMPLICIT_CONNECT));
2523 		socket_lock(fd_cb->so, 0);
2524 
2525 		if (!(fd_cb->so->so_flags & SOF_FLOW_DIVERT)) {
2526 			FDLOG0(LOG_NOTICE, fd_cb, "socket is not attached any more, ignoring close from provider");
2527 			goto done;
2528 		}
2529 
2530 		fd_cb->so->so_error = (uint16_t)ntohl(close_error);
2531 
2532 		flow_divert_update_closed_state(fd_cb, how, TRUE);
2533 
2534 		/* Only do this for stream flows because "shutdown by peer" doesn't make sense for datagram flows */
2535 		how = flow_divert_tunnel_how_closed(fd_cb);
2536 		if (how == SHUT_RDWR) {
2537 			flow_divert_disconnect_socket(fd_cb->so, is_connected);
2538 		} else if (how == SHUT_RD && is_connected) {
2539 			socantrcvmore(fd_cb->so);
2540 		} else if (how == SHUT_WR && is_connected) {
2541 			socantsendmore(fd_cb->so);
2542 		}
2543 done:
2544 		socket_unlock(fd_cb->so, 0);
2545 	}
2546 	FDUNLOCK(fd_cb);
2547 }
2548 
2549 static mbuf_t
flow_divert_create_control_mbuf(struct flow_divert_pcb * fd_cb)2550 flow_divert_create_control_mbuf(struct flow_divert_pcb *fd_cb)
2551 {
2552 	struct inpcb *inp = sotoinpcb(fd_cb->so);
2553 	bool need_recvdstaddr = false;
2554 	/* Socket flow tracking needs to see the local address */
2555 	need_recvdstaddr = SOFLOW_ENABLED(inp->inp_socket);
2556 	if ((inp->inp_vflag & INP_IPV4) &&
2557 	    fd_cb->local_endpoint.sa.sa_family == AF_INET &&
2558 	    ((inp->inp_flags & INP_RECVDSTADDR) || need_recvdstaddr)) {
2559 		return sbcreatecontrol((caddr_t)&(fd_cb->local_endpoint.sin.sin_addr), sizeof(struct in_addr), IP_RECVDSTADDR, IPPROTO_IP);
2560 	} else if ((inp->inp_vflag & INP_IPV6) &&
2561 	    fd_cb->local_endpoint.sa.sa_family == AF_INET6 &&
2562 	    ((inp->inp_flags & IN6P_PKTINFO) || need_recvdstaddr)) {
2563 		struct in6_pktinfo pi6;
2564 		memset(&pi6, 0, sizeof(pi6));
2565 		pi6.ipi6_addr = fd_cb->local_endpoint.sin6.sin6_addr;
2566 
2567 		return sbcreatecontrol((caddr_t)&pi6, sizeof(pi6), IPV6_PKTINFO, IPPROTO_IPV6);
2568 	}
2569 	return NULL;
2570 }
2571 
2572 static int
flow_divert_handle_data(struct flow_divert_pcb * fd_cb,mbuf_t packet,size_t offset)2573 flow_divert_handle_data(struct flow_divert_pcb *fd_cb, mbuf_t packet, size_t offset)
2574 {
2575 	int error = 0;
2576 
2577 	FDLOCK(fd_cb);
2578 	if (fd_cb->so != NULL) {
2579 		mbuf_t  data            = NULL;
2580 		size_t  data_size;
2581 		struct sockaddr_storage remote_address;
2582 		boolean_t got_remote_sa = FALSE;
2583 		boolean_t appended = FALSE;
2584 		boolean_t append_success = FALSE;
2585 
2586 		socket_lock(fd_cb->so, 0);
2587 
2588 		if (!(fd_cb->so->so_flags & SOF_FLOW_DIVERT)) {
2589 			FDLOG0(LOG_NOTICE, fd_cb, "socket is not attached any more, ignoring inbound data");
2590 			goto done;
2591 		}
2592 
2593 		if (sbspace(&fd_cb->so->so_rcv) == 0) {
2594 			error = ENOBUFS;
2595 			fd_cb->flags |= FLOW_DIVERT_NOTIFY_ON_RECEIVED;
2596 			FDLOG0(LOG_INFO, fd_cb, "Receive buffer is full, will send read notification when app reads some data");
2597 			goto done;
2598 		}
2599 
2600 		if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {
2601 			uint32_t val_size = 0;
2602 
2603 			/* check if we got remote address with data */
2604 			memset(&remote_address, 0, sizeof(remote_address));
2605 			error = flow_divert_packet_get_tlv(packet, (int)offset, FLOW_DIVERT_TLV_REMOTE_ADDR, sizeof(remote_address), &remote_address, &val_size);
2606 			if (error || val_size > sizeof(remote_address)) {
2607 				FDLOG0(LOG_INFO, fd_cb, "No remote address provided");
2608 				error = 0;
2609 			} else {
2610 				if (remote_address.ss_len > sizeof(remote_address)) {
2611 					remote_address.ss_len = sizeof(remote_address);
2612 				}
2613 				/* validate the address */
2614 				if (flow_divert_is_sockaddr_valid((struct sockaddr *)&remote_address)) {
2615 					got_remote_sa = TRUE;
2616 				} else {
2617 					FDLOG0(LOG_INFO, fd_cb, "Remote address is invalid");
2618 				}
2619 				offset += (sizeof(uint8_t) + sizeof(uint32_t) + val_size);
2620 			}
2621 		}
2622 
2623 		data_size = (mbuf_pkthdr_len(packet) - offset);
2624 
2625 		if (fd_cb->so->so_state & SS_CANTRCVMORE) {
2626 			FDLOG(LOG_NOTICE, fd_cb, "app cannot receive any more data, dropping %lu bytes of data", data_size);
2627 			goto done;
2628 		}
2629 
2630 		if (SOCK_TYPE(fd_cb->so) != SOCK_STREAM && SOCK_TYPE(fd_cb->so) != SOCK_DGRAM) {
2631 			FDLOG(LOG_ERR, fd_cb, "socket has an unsupported type: %d", SOCK_TYPE(fd_cb->so));
2632 			goto done;
2633 		}
2634 
2635 		FDLOG(LOG_DEBUG, fd_cb, "received %lu bytes of data", data_size);
2636 
2637 		error = mbuf_split(packet, offset, MBUF_DONTWAIT, &data);
2638 		if (error || data == NULL) {
2639 			FDLOG(LOG_ERR, fd_cb, "mbuf_split failed: %d", error);
2640 			goto done;
2641 		}
2642 
2643 		if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
2644 			appended = (sbappendstream(&fd_cb->so->so_rcv, data) != 0);
2645 			append_success = TRUE;
2646 		} else {
2647 			struct sockaddr *append_sa = NULL;
2648 			mbuf_t mctl;
2649 
2650 			if (got_remote_sa == TRUE) {
2651 				error = flow_divert_dup_addr(remote_address.ss_family, (struct sockaddr *)&remote_address, &append_sa);
2652 			} else {
2653 				if (fd_cb->so->so_proto->pr_domain->dom_family == AF_INET6) {
2654 					error = in6_mapped_peeraddr(fd_cb->so, &append_sa);
2655 				} else {
2656 					error = in_getpeeraddr(fd_cb->so, &append_sa);
2657 				}
2658 			}
2659 			if (error) {
2660 				FDLOG0(LOG_ERR, fd_cb, "failed to dup the socket address.");
2661 			}
2662 
2663 			mctl = flow_divert_create_control_mbuf(fd_cb);
2664 			int append_error = 0;
2665 			appended = sbappendaddr(&fd_cb->so->so_rcv, append_sa, data, mctl, &append_error);
2666 			if (appended || append_error == 0) {
2667 				append_success = TRUE;
2668 			} else {
2669 				FDLOG(LOG_ERR, fd_cb, "failed to append %lu bytes of data: %d", data_size, append_error);
2670 			}
2671 
2672 			free_sockaddr(append_sa);
2673 		}
2674 
2675 		if (append_success) {
2676 			fd_cb->bytes_received += data_size;
2677 			flow_divert_add_data_statistics(fd_cb, data_size, FALSE);
2678 		}
2679 
2680 		if (appended) {
2681 			sorwakeup(fd_cb->so);
2682 		}
2683 done:
2684 		socket_unlock(fd_cb->so, 0);
2685 	}
2686 	FDUNLOCK(fd_cb);
2687 
2688 	return error;
2689 }
2690 
2691 static void
flow_divert_handle_read_notification(struct flow_divert_pcb * fd_cb,mbuf_t packet,int offset)2692 flow_divert_handle_read_notification(struct flow_divert_pcb *fd_cb, mbuf_t packet, int offset)
2693 {
2694 	uint32_t        read_count              = 0;
2695 	int             error                   = 0;
2696 
2697 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_READ_COUNT, sizeof(read_count), &read_count, NULL);
2698 	if (error) {
2699 		FDLOG(LOG_ERR, fd_cb, "failed to get the read count: %d", error);
2700 		return;
2701 	}
2702 
2703 	FDLOG(LOG_DEBUG, fd_cb, "received a read notification for %u bytes", ntohl(read_count));
2704 
2705 	FDLOCK(fd_cb);
2706 	if (fd_cb->so != NULL) {
2707 		socket_lock(fd_cb->so, 0);
2708 
2709 		if (!(fd_cb->so->so_flags & SOF_FLOW_DIVERT)) {
2710 			FDLOG0(LOG_NOTICE, fd_cb, "socket is not attached any more, ignoring read notification");
2711 			goto done;
2712 		}
2713 
2714 		fd_cb->send_window += ntohl(read_count);
2715 		flow_divert_send_buffered_data(fd_cb, FALSE);
2716 done:
2717 		socket_unlock(fd_cb->so, 0);
2718 	}
2719 	FDUNLOCK(fd_cb);
2720 }
2721 
2722 static void
flow_divert_handle_group_init(struct flow_divert_group * group,mbuf_t packet,int offset)2723 flow_divert_handle_group_init(struct flow_divert_group *group, mbuf_t packet, int offset)
2724 {
2725 	int error         = 0;
2726 	uint32_t key_size = 0;
2727 	int log_level     = 0;
2728 	uint32_t flags    = 0;
2729 
2730 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_TOKEN_KEY, 0, NULL, &key_size);
2731 	if (error) {
2732 		FDLOG(LOG_ERR, &nil_pcb, "failed to get the key size: %d", error);
2733 		return;
2734 	}
2735 
2736 	if (key_size == 0 || key_size > FLOW_DIVERT_MAX_KEY_SIZE) {
2737 		FDLOG(LOG_ERR, &nil_pcb, "Invalid key size: %u", key_size);
2738 		return;
2739 	}
2740 
2741 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_LOG_LEVEL, sizeof(log_level), &log_level, NULL);
2742 	if (!error) {
2743 		nil_pcb.log_level = (uint8_t)log_level;
2744 	}
2745 
2746 	lck_rw_lock_exclusive(&group->lck);
2747 
2748 	if (group->token_key != NULL) {
2749 		kfree_data(group->token_key, group->token_key_size);
2750 		group->token_key = NULL;
2751 	}
2752 
2753 	group->token_key = kalloc_data(key_size, Z_WAITOK);
2754 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_TOKEN_KEY, key_size, group->token_key, NULL);
2755 	if (error) {
2756 		FDLOG(LOG_ERR, &nil_pcb, "failed to get the token key: %d", error);
2757 		kfree_data(group->token_key, key_size);
2758 		group->token_key = NULL;
2759 		lck_rw_done(&group->lck);
2760 		return;
2761 	}
2762 
2763 	group->token_key_size = key_size;
2764 
2765 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_FLAGS, sizeof(flags), &flags, NULL);
2766 	if (!error) {
2767 		group->flags = flags;
2768 	}
2769 
2770 	lck_rw_done(&group->lck);
2771 }
2772 
2773 static void
flow_divert_handle_properties_update(struct flow_divert_pcb * fd_cb,mbuf_t packet,int offset)2774 flow_divert_handle_properties_update(struct flow_divert_pcb *fd_cb, mbuf_t packet, int offset)
2775 {
2776 	int                                                     error                           = 0;
2777 	int                                                     out_if_index            = 0;
2778 	uint32_t                                        app_data_length         = 0;
2779 
2780 	FDLOG0(LOG_INFO, fd_cb, "received a properties update");
2781 
2782 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_OUT_IF_INDEX, sizeof(out_if_index), &out_if_index, NULL);
2783 	if (error) {
2784 		FDLOG0(LOG_INFO, fd_cb, "No output if index provided in properties update");
2785 	}
2786 
2787 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_APP_DATA, 0, NULL, &app_data_length);
2788 	if (error) {
2789 		FDLOG0(LOG_INFO, fd_cb, "No application data provided in properties update");
2790 	}
2791 
2792 	FDLOCK(fd_cb);
2793 	if (fd_cb->so != NULL) {
2794 		socket_lock(fd_cb->so, 0);
2795 
2796 		if (!(fd_cb->so->so_flags & SOF_FLOW_DIVERT)) {
2797 			FDLOG0(LOG_NOTICE, fd_cb, "socket is not attached any more, ignoring properties update");
2798 			goto done;
2799 		}
2800 
2801 		if (out_if_index > 0) {
2802 			flow_divert_scope(fd_cb, out_if_index, true);
2803 			flow_divert_set_local_endpoint(fd_cb, &(fd_cb->local_endpoint.sa));
2804 		}
2805 
2806 		if (app_data_length > 0) {
2807 			uint8_t *app_data = NULL;
2808 			app_data = kalloc_data(app_data_length, Z_WAITOK);
2809 			if (app_data != NULL) {
2810 				error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_APP_DATA, app_data_length, app_data, NULL);
2811 				if (error == 0) {
2812 					if (fd_cb->app_data != NULL) {
2813 						kfree_data(fd_cb->app_data, fd_cb->app_data_length);
2814 					}
2815 					fd_cb->app_data = app_data;
2816 					fd_cb->app_data_length = app_data_length;
2817 				} else {
2818 					FDLOG(LOG_ERR, fd_cb, "Failed to copy %u bytes of application data from the properties update packet", app_data_length);
2819 					kfree_data(app_data, app_data_length);
2820 				}
2821 			} else {
2822 				FDLOG(LOG_ERR, fd_cb, "Failed to allocate a buffer of size %u to hold the application data from the properties update", app_data_length);
2823 			}
2824 		}
2825 done:
2826 		socket_unlock(fd_cb->so, 0);
2827 	}
2828 	FDUNLOCK(fd_cb);
2829 }
2830 
2831 static void
flow_divert_handle_app_map_create(struct flow_divert_group * group,mbuf_t packet,int offset)2832 flow_divert_handle_app_map_create(struct flow_divert_group *group, mbuf_t packet, int offset)
2833 {
2834 	size_t bytes_mem_size;
2835 	size_t child_maps_mem_size;
2836 	size_t nodes_mem_size;
2837 	size_t trie_memory_size = 0;
2838 	int cursor;
2839 	int error = 0;
2840 	struct flow_divert_trie new_trie;
2841 	int insert_error = 0;
2842 	int prefix_count = -1;
2843 	int signing_id_count = 0;
2844 	size_t bytes_count = 0;
2845 	size_t nodes_count = 0;
2846 	size_t maps_count = 0;
2847 
2848 	lck_rw_lock_exclusive(&group->lck);
2849 
2850 	/* Re-set the current trie */
2851 	if (group->signing_id_trie.memory != NULL) {
2852 		FREE(group->signing_id_trie.memory, M_TEMP);
2853 	}
2854 	memset(&group->signing_id_trie, 0, sizeof(group->signing_id_trie));
2855 	group->signing_id_trie.root = NULL_TRIE_IDX;
2856 
2857 	memset(&new_trie, 0, sizeof(new_trie));
2858 
2859 	/* Get the number of shared prefixes in the new set of signing ID strings */
2860 	error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_PREFIX_COUNT, sizeof(prefix_count), &prefix_count, NULL);
2861 
2862 	if (prefix_count < 0 || error) {
2863 		FDLOG(LOG_ERR, &nil_pcb, "Invalid prefix count (%d) or an error occurred while reading the prefix count: %d", prefix_count, error);
2864 		lck_rw_done(&group->lck);
2865 		return;
2866 	}
2867 
2868 	/* Compute the number of signing IDs and the total amount of bytes needed to store them */
2869 	for (cursor = flow_divert_packet_find_tlv(packet, offset, FLOW_DIVERT_TLV_SIGNING_ID, &error, 0);
2870 	    cursor >= 0;
2871 	    cursor = flow_divert_packet_find_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, &error, 1)) {
2872 		uint32_t sid_size = 0;
2873 		error = flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
2874 		if (error || sid_size == 0) {
2875 			FDLOG(LOG_ERR, &nil_pcb, "Failed to get the length of the signing identifier at offset %d: %d", cursor, error);
2876 			signing_id_count = 0;
2877 			break;
2878 		}
2879 		if (os_add_overflow(bytes_count, sid_size, &bytes_count)) {
2880 			FDLOG0(LOG_ERR, &nil_pcb, "Overflow while incrementing number of bytes");
2881 			signing_id_count = 0;
2882 			break;
2883 		}
2884 		signing_id_count++;
2885 	}
2886 
2887 	if (signing_id_count == 0) {
2888 		lck_rw_done(&group->lck);
2889 		FDLOG0(LOG_NOTICE, &nil_pcb, "No signing identifiers");
2890 		return;
2891 	}
2892 
2893 	if (os_add3_overflow(prefix_count, signing_id_count, 1, &nodes_count)) { /* + 1 for the root node */
2894 		lck_rw_done(&group->lck);
2895 		FDLOG0(LOG_ERR, &nil_pcb, "Overflow while computing the number of nodes");
2896 		return;
2897 	}
2898 
2899 	if (os_add_overflow(prefix_count, 1, &maps_count)) { /* + 1 for the root node */
2900 		lck_rw_done(&group->lck);
2901 		FDLOG0(LOG_ERR, &nil_pcb, "Overflow while computing the number of maps");
2902 		return;
2903 	}
2904 
2905 	if (bytes_count > UINT16_MAX || nodes_count > UINT16_MAX || maps_count > UINT16_MAX) {
2906 		lck_rw_done(&group->lck);
2907 		FDLOG(LOG_NOTICE, &nil_pcb, "Invalid bytes count (%lu), nodes count (%lu) or maps count (%lu)", bytes_count, nodes_count, maps_count);
2908 		return;
2909 	}
2910 
2911 	FDLOG(LOG_INFO, &nil_pcb, "Nodes count = %lu, child maps count = %lu, bytes_count = %lu",
2912 	    nodes_count, maps_count, bytes_count);
2913 
2914 	if (os_mul_overflow(sizeof(*new_trie.nodes), (size_t)nodes_count, &nodes_mem_size) ||
2915 	    os_mul3_overflow(sizeof(*new_trie.child_maps), CHILD_MAP_SIZE, (size_t)maps_count, &child_maps_mem_size) ||
2916 	    os_mul_overflow(sizeof(*new_trie.bytes), (size_t)bytes_count, &bytes_mem_size) ||
2917 	    os_add3_overflow(nodes_mem_size, child_maps_mem_size, bytes_mem_size, &trie_memory_size)) {
2918 		FDLOG0(LOG_ERR, &nil_pcb, "Overflow while computing trie memory sizes");
2919 		lck_rw_done(&group->lck);
2920 		return;
2921 	}
2922 
2923 	if (trie_memory_size > FLOW_DIVERT_MAX_TRIE_MEMORY) {
2924 		FDLOG(LOG_ERR, &nil_pcb, "Trie memory size (%lu) is too big (maximum is %u)", trie_memory_size, FLOW_DIVERT_MAX_TRIE_MEMORY);
2925 		lck_rw_done(&group->lck);
2926 		return;
2927 	}
2928 
2929 	MALLOC(new_trie.memory, void *, trie_memory_size, M_TEMP, M_WAITOK);
2930 	if (new_trie.memory == NULL) {
2931 		FDLOG(LOG_ERR, &nil_pcb, "Failed to allocate %lu bytes of memory for the signing ID trie",
2932 		    nodes_mem_size + child_maps_mem_size + bytes_mem_size);
2933 		lck_rw_done(&group->lck);
2934 		return;
2935 	}
2936 
2937 	new_trie.bytes_count = (uint16_t)bytes_count;
2938 	new_trie.nodes_count = (uint16_t)nodes_count;
2939 	new_trie.child_maps_count = (uint16_t)maps_count;
2940 
2941 	/* Initialize the free lists */
2942 	new_trie.nodes = (struct flow_divert_trie_node *)new_trie.memory;
2943 	new_trie.nodes_free_next = 0;
2944 	memset(new_trie.nodes, 0, nodes_mem_size);
2945 
2946 	new_trie.child_maps = (uint16_t *)(void *)((uint8_t *)new_trie.memory + nodes_mem_size);
2947 	new_trie.child_maps_free_next = 0;
2948 	memset(new_trie.child_maps, 0xff, child_maps_mem_size);
2949 
2950 	new_trie.bytes = (uint8_t *)(void *)((uint8_t *)new_trie.memory + nodes_mem_size + child_maps_mem_size);
2951 	new_trie.bytes_free_next = 0;
2952 	memset(new_trie.bytes, 0, bytes_mem_size);
2953 
2954 	/* The root is an empty node */
2955 	new_trie.root = trie_node_alloc(&new_trie);
2956 
2957 	/* Add each signing ID to the trie */
2958 	for (cursor = flow_divert_packet_find_tlv(packet, offset, FLOW_DIVERT_TLV_SIGNING_ID, &error, 0);
2959 	    cursor >= 0;
2960 	    cursor = flow_divert_packet_find_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, &error, 1)) {
2961 		uint32_t sid_size = 0;
2962 		error = flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
2963 		if (error || sid_size == 0) {
2964 			FDLOG(LOG_ERR, &nil_pcb, "Failed to get the length of the signing identifier at offset %d while building: %d", cursor, error);
2965 			insert_error = EINVAL;
2966 			break;
2967 		}
2968 		if (sid_size <= UINT16_MAX && new_trie.bytes_free_next + (uint16_t)sid_size <= new_trie.bytes_count) {
2969 			uint16_t new_node_idx;
2970 			error = flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, sid_size, &TRIE_BYTE(&new_trie, new_trie.bytes_free_next), NULL);
2971 			if (error) {
2972 				FDLOG(LOG_ERR, &nil_pcb, "Failed to read the signing identifier at offset %d: %d", cursor, error);
2973 				insert_error = EINVAL;
2974 				break;
2975 			}
2976 			new_node_idx = flow_divert_trie_insert(&new_trie, new_trie.bytes_free_next, sid_size);
2977 			if (new_node_idx == NULL_TRIE_IDX) {
2978 				insert_error = EINVAL;
2979 				break;
2980 			}
2981 		} else {
2982 			FDLOG0(LOG_ERR, &nil_pcb, "No place to put signing ID for insertion");
2983 			insert_error = ENOBUFS;
2984 			break;
2985 		}
2986 	}
2987 
2988 	if (!insert_error) {
2989 		group->signing_id_trie = new_trie;
2990 	} else {
2991 		FREE(new_trie.memory, M_TEMP);
2992 	}
2993 
2994 	lck_rw_done(&group->lck);
2995 }
2996 
2997 static int
flow_divert_input(mbuf_t packet,struct flow_divert_group * group)2998 flow_divert_input(mbuf_t packet, struct flow_divert_group *group)
2999 {
3000 	struct flow_divert_packet_header        hdr;
3001 	int                                                                     error           = 0;
3002 	struct flow_divert_pcb                          *fd_cb;
3003 
3004 	if (mbuf_pkthdr_len(packet) < sizeof(hdr)) {
3005 		FDLOG(LOG_ERR, &nil_pcb, "got a bad packet, length (%lu) < sizeof hdr (%lu)", mbuf_pkthdr_len(packet), sizeof(hdr));
3006 		error = EINVAL;
3007 		goto done;
3008 	}
3009 
3010 	if (mbuf_pkthdr_len(packet) > FD_CTL_RCVBUFF_SIZE) {
3011 		FDLOG(LOG_ERR, &nil_pcb, "got a bad packet, length (%lu) > %d", mbuf_pkthdr_len(packet), FD_CTL_RCVBUFF_SIZE);
3012 		error = EINVAL;
3013 		goto done;
3014 	}
3015 
3016 	error = mbuf_copydata(packet, 0, sizeof(hdr), &hdr);
3017 	if (error) {
3018 		FDLOG(LOG_ERR, &nil_pcb, "mbuf_copydata failed for the header: %d", error);
3019 		error = ENOBUFS;
3020 		goto done;
3021 	}
3022 
3023 	hdr.conn_id = ntohl(hdr.conn_id);
3024 
3025 	if (hdr.conn_id == 0) {
3026 		switch (hdr.packet_type) {
3027 		case FLOW_DIVERT_PKT_GROUP_INIT:
3028 			flow_divert_handle_group_init(group, packet, sizeof(hdr));
3029 			break;
3030 		case FLOW_DIVERT_PKT_APP_MAP_CREATE:
3031 			flow_divert_handle_app_map_create(group, packet, sizeof(hdr));
3032 			break;
3033 		default:
3034 			FDLOG(LOG_WARNING, &nil_pcb, "got an unknown message type: %d", hdr.packet_type);
3035 			break;
3036 		}
3037 		goto done;
3038 	}
3039 
3040 	fd_cb = flow_divert_pcb_lookup(hdr.conn_id, group);             /* This retains the PCB */
3041 	if (fd_cb == NULL) {
3042 		if (hdr.packet_type != FLOW_DIVERT_PKT_CLOSE && hdr.packet_type != FLOW_DIVERT_PKT_READ_NOTIFY) {
3043 			FDLOG(LOG_NOTICE, &nil_pcb, "got a %s message from group %d for an unknown pcb: %u", flow_divert_packet_type2str(hdr.packet_type), group->ctl_unit, hdr.conn_id);
3044 		}
3045 		goto done;
3046 	}
3047 
3048 	switch (hdr.packet_type) {
3049 	case FLOW_DIVERT_PKT_CONNECT_RESULT:
3050 		flow_divert_handle_connect_result(fd_cb, packet, sizeof(hdr));
3051 		break;
3052 	case FLOW_DIVERT_PKT_CLOSE:
3053 		flow_divert_handle_close(fd_cb, packet, sizeof(hdr));
3054 		break;
3055 	case FLOW_DIVERT_PKT_DATA:
3056 		error = flow_divert_handle_data(fd_cb, packet, sizeof(hdr));
3057 		break;
3058 	case FLOW_DIVERT_PKT_READ_NOTIFY:
3059 		flow_divert_handle_read_notification(fd_cb, packet, sizeof(hdr));
3060 		break;
3061 	case FLOW_DIVERT_PKT_PROPERTIES_UPDATE:
3062 		flow_divert_handle_properties_update(fd_cb, packet, sizeof(hdr));
3063 		break;
3064 	default:
3065 		FDLOG(LOG_WARNING, fd_cb, "got an unknown message type: %d", hdr.packet_type);
3066 		break;
3067 	}
3068 
3069 	FDRELEASE(fd_cb);
3070 
3071 done:
3072 	mbuf_freem(packet);
3073 	return error;
3074 }
3075 
3076 static void
flow_divert_close_all(struct flow_divert_group * group)3077 flow_divert_close_all(struct flow_divert_group *group)
3078 {
3079 	struct flow_divert_pcb                  *fd_cb;
3080 	SLIST_HEAD(, flow_divert_pcb)   tmp_list;
3081 
3082 	SLIST_INIT(&tmp_list);
3083 
3084 	lck_rw_lock_exclusive(&group->lck);
3085 
3086 	MBUFQ_DRAIN(&group->send_queue);
3087 
3088 	RB_FOREACH(fd_cb, fd_pcb_tree, &group->pcb_tree) {
3089 		FDRETAIN(fd_cb);
3090 		SLIST_INSERT_HEAD(&tmp_list, fd_cb, tmp_list_entry);
3091 	}
3092 
3093 	lck_rw_done(&group->lck);
3094 
3095 	while (!SLIST_EMPTY(&tmp_list)) {
3096 		fd_cb = SLIST_FIRST(&tmp_list);
3097 		FDLOCK(fd_cb);
3098 		SLIST_REMOVE_HEAD(&tmp_list, tmp_list_entry);
3099 		if (fd_cb->so != NULL) {
3100 			socket_lock(fd_cb->so, 0);
3101 			flow_divert_pcb_remove(fd_cb);
3102 			flow_divert_update_closed_state(fd_cb, SHUT_RDWR, TRUE);
3103 			fd_cb->so->so_error = ECONNABORTED;
3104 			flow_divert_disconnect_socket(fd_cb->so, !(fd_cb->flags & FLOW_DIVERT_IMPLICIT_CONNECT));
3105 			socket_unlock(fd_cb->so, 0);
3106 		}
3107 		FDUNLOCK(fd_cb);
3108 		FDRELEASE(fd_cb);
3109 	}
3110 }
3111 
3112 void
flow_divert_detach(struct socket * so)3113 flow_divert_detach(struct socket *so)
3114 {
3115 	struct flow_divert_pcb  *fd_cb          = so->so_fd_pcb;
3116 
3117 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
3118 
3119 	so->so_flags &= ~SOF_FLOW_DIVERT;
3120 	so->so_fd_pcb = NULL;
3121 
3122 	FDLOG(LOG_INFO, fd_cb, "Detaching, ref count = %d", fd_cb->ref_count);
3123 
3124 	if (fd_cb->group != NULL) {
3125 		/* Last-ditch effort to send any buffered data */
3126 		flow_divert_send_buffered_data(fd_cb, TRUE);
3127 
3128 		flow_divert_update_closed_state(fd_cb, SHUT_RDWR, FALSE);
3129 		flow_divert_send_close_if_needed(fd_cb);
3130 		/* Remove from the group */
3131 		flow_divert_pcb_remove(fd_cb);
3132 	}
3133 
3134 	socket_unlock(so, 0);
3135 	FDLOCK(fd_cb);
3136 	fd_cb->so = NULL;
3137 	FDUNLOCK(fd_cb);
3138 	socket_lock(so, 0);
3139 
3140 	FDRELEASE(fd_cb);       /* Release the socket's reference */
3141 }
3142 
3143 static int
flow_divert_close(struct socket * so)3144 flow_divert_close(struct socket *so)
3145 {
3146 	struct flow_divert_pcb  *fd_cb          = so->so_fd_pcb;
3147 
3148 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
3149 
3150 	FDLOG0(LOG_INFO, fd_cb, "Closing");
3151 
3152 	if (SOCK_TYPE(so) == SOCK_STREAM) {
3153 		soisdisconnecting(so);
3154 		sbflush(&so->so_rcv);
3155 	}
3156 
3157 	flow_divert_send_buffered_data(fd_cb, TRUE);
3158 	flow_divert_update_closed_state(fd_cb, SHUT_RDWR, FALSE);
3159 	flow_divert_send_close_if_needed(fd_cb);
3160 
3161 	/* Remove from the group */
3162 	flow_divert_pcb_remove(fd_cb);
3163 
3164 	return 0;
3165 }
3166 
3167 static int
flow_divert_disconnectx(struct socket * so,sae_associd_t aid,sae_connid_t cid __unused)3168 flow_divert_disconnectx(struct socket *so, sae_associd_t aid,
3169     sae_connid_t cid __unused)
3170 {
3171 	if (aid != SAE_ASSOCID_ANY && aid != SAE_ASSOCID_ALL) {
3172 		return EINVAL;
3173 	}
3174 
3175 	return flow_divert_close(so);
3176 }
3177 
3178 static int
flow_divert_shutdown(struct socket * so)3179 flow_divert_shutdown(struct socket *so)
3180 {
3181 	struct flow_divert_pcb  *fd_cb          = so->so_fd_pcb;
3182 
3183 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
3184 
3185 	FDLOG0(LOG_INFO, fd_cb, "Can't send more");
3186 
3187 	socantsendmore(so);
3188 
3189 	flow_divert_update_closed_state(fd_cb, SHUT_WR, FALSE);
3190 	flow_divert_send_close_if_needed(fd_cb);
3191 
3192 	return 0;
3193 }
3194 
3195 static int
flow_divert_rcvd(struct socket * so,int flags __unused)3196 flow_divert_rcvd(struct socket *so, int flags __unused)
3197 {
3198 	struct flow_divert_pcb  *fd_cb = so->so_fd_pcb;
3199 	int space = sbspace(&so->so_rcv);
3200 
3201 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
3202 
3203 	FDLOG(LOG_DEBUG, fd_cb, "app read bytes, space = %d", space);
3204 	if ((fd_cb->flags & FLOW_DIVERT_NOTIFY_ON_RECEIVED) &&
3205 	    (space > 0) &&
3206 	    flow_divert_send_read_notification(fd_cb) == 0) {
3207 		FDLOG0(LOG_INFO, fd_cb, "Sent a read notification");
3208 		fd_cb->flags &= ~FLOW_DIVERT_NOTIFY_ON_RECEIVED;
3209 	}
3210 
3211 	return 0;
3212 }
3213 
3214 static int
flow_divert_append_target_endpoint_tlv(mbuf_t connect_packet,struct sockaddr * toaddr)3215 flow_divert_append_target_endpoint_tlv(mbuf_t connect_packet, struct sockaddr *toaddr)
3216 {
3217 	int error = 0;
3218 	int port  = 0;
3219 
3220 	if (!flow_divert_is_sockaddr_valid(toaddr)) {
3221 		FDLOG(LOG_ERR, &nil_pcb, "Invalid target address, family = %u, length = %u", toaddr->sa_family, toaddr->sa_len);
3222 		error = EINVAL;
3223 		goto done;
3224 	}
3225 
3226 	error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_TARGET_ADDRESS, toaddr->sa_len, toaddr);
3227 	if (error) {
3228 		goto done;
3229 	}
3230 
3231 	if (toaddr->sa_family == AF_INET) {
3232 		port = ntohs((satosin(toaddr))->sin_port);
3233 	} else {
3234 		port = ntohs((satosin6(toaddr))->sin6_port);
3235 	}
3236 
3237 	error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_TARGET_PORT, sizeof(port), &port);
3238 	if (error) {
3239 		goto done;
3240 	}
3241 
3242 done:
3243 	return error;
3244 }
3245 
3246 struct sockaddr *
flow_divert_get_buffered_target_address(mbuf_t buffer)3247 flow_divert_get_buffered_target_address(mbuf_t buffer)
3248 {
3249 	if (buffer != NULL && buffer->m_type == MT_SONAME) {
3250 		struct sockaddr *toaddr = mtod(buffer, struct sockaddr *);
3251 		if (toaddr != NULL && flow_divert_is_sockaddr_valid(toaddr)) {
3252 			return toaddr;
3253 		}
3254 	}
3255 	return NULL;
3256 }
3257 
3258 static boolean_t
flow_divert_is_sockaddr_valid(struct sockaddr * addr)3259 flow_divert_is_sockaddr_valid(struct sockaddr *addr)
3260 {
3261 	switch (addr->sa_family) {
3262 	case AF_INET:
3263 		if (addr->sa_len < sizeof(struct sockaddr_in)) {
3264 			return FALSE;
3265 		}
3266 		break;
3267 	case AF_INET6:
3268 		if (addr->sa_len < sizeof(struct sockaddr_in6)) {
3269 			return FALSE;
3270 		}
3271 		break;
3272 	default:
3273 		return FALSE;
3274 	}
3275 	return TRUE;
3276 }
3277 
3278 static errno_t
flow_divert_dup_addr(sa_family_t family,struct sockaddr * addr,struct sockaddr ** dup)3279 flow_divert_dup_addr(sa_family_t family, struct sockaddr *addr,
3280     struct sockaddr **dup)
3281 {
3282 	int                                             error           = 0;
3283 	struct sockaddr                 *result;
3284 	struct sockaddr_storage ss;
3285 
3286 	if (addr != NULL) {
3287 		result = addr;
3288 	} else {
3289 		memset(&ss, 0, sizeof(ss));
3290 		ss.ss_family = family;
3291 		if (ss.ss_family == AF_INET) {
3292 			ss.ss_len = sizeof(struct sockaddr_in);
3293 		} else if (ss.ss_family == AF_INET6) {
3294 			ss.ss_len = sizeof(struct sockaddr_in6);
3295 		} else {
3296 			error = EINVAL;
3297 		}
3298 		result = (struct sockaddr *)&ss;
3299 	}
3300 
3301 	if (!error) {
3302 		*dup = dup_sockaddr(result, 1);
3303 		if (*dup == NULL) {
3304 			error = ENOBUFS;
3305 		}
3306 	}
3307 
3308 	return error;
3309 }
3310 
3311 static void
flow_divert_disconnect_socket(struct socket * so,bool is_connected)3312 flow_divert_disconnect_socket(struct socket *so, bool is_connected)
3313 {
3314 	if (SOCK_TYPE(so) == SOCK_STREAM || is_connected) {
3315 		soisdisconnected(so);
3316 	}
3317 	if (SOCK_TYPE(so) == SOCK_DGRAM) {
3318 		struct inpcb *inp = sotoinpcb(so);
3319 		if (inp != NULL && !(so->so_flags & SOF_PCBCLEARING)) {
3320 			/*
3321 			 * Let NetworkStatistics know this PCB is going away
3322 			 * before we detach it.
3323 			 */
3324 			if (nstat_collect && (SOCK_PROTO(so) == IPPROTO_TCP || SOCK_PROTO(so) == IPPROTO_UDP)) {
3325 				nstat_pcb_detach(inp);
3326 			}
3327 
3328 			if (SOCK_DOM(so) == PF_INET6) {
3329 				ROUTE_RELEASE(&inp->in6p_route);
3330 			} else {
3331 				ROUTE_RELEASE(&inp->inp_route);
3332 			}
3333 			inp->inp_state = INPCB_STATE_DEAD;
3334 			/* makes sure we're not called twice from so_close */
3335 			so->so_flags |= SOF_PCBCLEARING;
3336 			inpcb_gc_sched(inp->inp_pcbinfo, INPCB_TIMER_FAST);
3337 		}
3338 	}
3339 }
3340 
3341 static errno_t
flow_divert_ctloutput(struct socket * so,struct sockopt * sopt)3342 flow_divert_ctloutput(struct socket *so, struct sockopt *sopt)
3343 {
3344 	struct flow_divert_pcb  *fd_cb  = so->so_fd_pcb;
3345 
3346 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
3347 
3348 	if (sopt->sopt_name == SO_TRAFFIC_CLASS) {
3349 		if (sopt->sopt_dir == SOPT_SET && fd_cb->flags & FLOW_DIVERT_CONNECT_STARTED) {
3350 			flow_divert_send_traffic_class_update(fd_cb, so->so_traffic_class);
3351 		}
3352 	}
3353 
3354 	if (SOCK_DOM(so) == PF_INET) {
3355 		return g_tcp_protosw->pr_ctloutput(so, sopt);
3356 	} else if (SOCK_DOM(so) == PF_INET6) {
3357 		return g_tcp6_protosw->pr_ctloutput(so, sopt);
3358 	}
3359 	return 0;
3360 }
3361 
3362 static errno_t
flow_divert_connect_out_internal(struct socket * so,struct sockaddr * to,proc_t p,bool implicit)3363 flow_divert_connect_out_internal(struct socket *so, struct sockaddr *to, proc_t p, bool implicit)
3364 {
3365 	struct flow_divert_pcb  *fd_cb  = so->so_fd_pcb;
3366 	int                                             error   = 0;
3367 	struct inpcb                    *inp    = sotoinpcb(so);
3368 	struct sockaddr_in              *sinp;
3369 	mbuf_t                                  connect_packet = NULL;
3370 	int                                             do_send = 1;
3371 
3372 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
3373 
3374 	if (fd_cb->group == NULL) {
3375 		error = ENETUNREACH;
3376 		goto done;
3377 	}
3378 
3379 	if (inp == NULL) {
3380 		error = EINVAL;
3381 		goto done;
3382 	} else if (inp->inp_state == INPCB_STATE_DEAD) {
3383 		if (so->so_error) {
3384 			error = so->so_error;
3385 			so->so_error = 0;
3386 		} else {
3387 			error = EINVAL;
3388 		}
3389 		goto done;
3390 	}
3391 
3392 	if (fd_cb->flags & FLOW_DIVERT_CONNECT_STARTED) {
3393 		error = EALREADY;
3394 		goto done;
3395 	}
3396 
3397 	FDLOG0(LOG_INFO, fd_cb, "Connecting");
3398 
3399 	if (fd_cb->connect_packet == NULL) {
3400 		struct sockaddr_in sin = {};
3401 		struct ifnet *ifp = NULL;
3402 
3403 		if (to == NULL) {
3404 			FDLOG0(LOG_ERR, fd_cb, "No destination address available when creating connect packet");
3405 			error = EINVAL;
3406 			goto done;
3407 		}
3408 
3409 		fd_cb->original_remote_endpoint = dup_sockaddr(to, 0);
3410 		if (fd_cb->original_remote_endpoint == NULL) {
3411 			FDLOG0(LOG_ERR, fd_cb, "Failed to dup the remote endpoint");
3412 			error = ENOMEM;
3413 			goto done;
3414 		}
3415 		fd_cb->original_vflag = inp->inp_vflag;
3416 		fd_cb->original_last_outifp = inp->inp_last_outifp;
3417 		fd_cb->original_last_outifp6 = inp->in6p_last_outifp;
3418 
3419 		sinp = (struct sockaddr_in *)(void *)to;
3420 		if (sinp->sin_family == AF_INET && IN_MULTICAST(ntohl(sinp->sin_addr.s_addr))) {
3421 			error = EAFNOSUPPORT;
3422 			goto done;
3423 		}
3424 
3425 		if (to->sa_family == AF_INET6 && !(inp->inp_flags & IN6P_IPV6_V6ONLY)) {
3426 			struct sockaddr_in6 sin6 = {};
3427 			sin6.sin6_family = AF_INET6;
3428 			sin6.sin6_len = sizeof(struct sockaddr_in6);
3429 			sin6.sin6_port = satosin6(to)->sin6_port;
3430 			sin6.sin6_addr = satosin6(to)->sin6_addr;
3431 			if (IN6_IS_ADDR_V4MAPPED(&(sin6.sin6_addr))) {
3432 				in6_sin6_2_sin(&sin, &sin6);
3433 				to = (struct sockaddr *)&sin;
3434 			}
3435 		}
3436 
3437 		if (to->sa_family == AF_INET6) {
3438 			inp->inp_vflag &= ~INP_IPV4;
3439 			inp->inp_vflag |= INP_IPV6;
3440 			fd_cb->local_endpoint.sin6.sin6_len = sizeof(struct sockaddr_in6);
3441 			fd_cb->local_endpoint.sin6.sin6_family = AF_INET6;
3442 			fd_cb->local_endpoint.sin6.sin6_port = inp->inp_lport;
3443 			error = in6_pcbladdr(inp, to, &(fd_cb->local_endpoint.sin6.sin6_addr), &ifp);
3444 			if (error) {
3445 				FDLOG(LOG_WARNING, fd_cb, "failed to get a local IPv6 address: %d", error);
3446 				if (!(fd_cb->flags & FLOW_DIVERT_FLOW_IS_TRANSPARENT) || IN6_IS_ADDR_UNSPECIFIED(&(satosin6(to)->sin6_addr))) {
3447 					error = 0;
3448 				} else {
3449 					goto done;
3450 				}
3451 			}
3452 			if (ifp != NULL) {
3453 				inp->in6p_last_outifp = ifp;
3454 				ifnet_release(ifp);
3455 			}
3456 		} else if (to->sa_family == AF_INET) {
3457 			inp->inp_vflag |= INP_IPV4;
3458 			inp->inp_vflag &= ~INP_IPV6;
3459 			fd_cb->local_endpoint.sin.sin_len = sizeof(struct sockaddr_in);
3460 			fd_cb->local_endpoint.sin.sin_family = AF_INET;
3461 			fd_cb->local_endpoint.sin.sin_port = inp->inp_lport;
3462 			error = in_pcbladdr(inp, to, &(fd_cb->local_endpoint.sin.sin_addr), IFSCOPE_NONE, &ifp, 0);
3463 			if (error) {
3464 				FDLOG(LOG_WARNING, fd_cb, "failed to get a local IPv4 address: %d", error);
3465 				if (!(fd_cb->flags & FLOW_DIVERT_FLOW_IS_TRANSPARENT) || satosin(to)->sin_addr.s_addr == INADDR_ANY) {
3466 					error = 0;
3467 				} else {
3468 					goto done;
3469 				}
3470 			}
3471 			if (ifp != NULL) {
3472 				inp->inp_last_outifp = ifp;
3473 				ifnet_release(ifp);
3474 			}
3475 		} else {
3476 			FDLOG(LOG_WARNING, fd_cb, "target address has an unsupported family: %d", to->sa_family);
3477 		}
3478 
3479 		error = flow_divert_check_no_cellular(fd_cb) ||
3480 		    flow_divert_check_no_expensive(fd_cb) ||
3481 		    flow_divert_check_no_constrained(fd_cb);
3482 		if (error) {
3483 			goto done;
3484 		}
3485 
3486 		if (SOCK_TYPE(so) == SOCK_STREAM || /* TCP or */
3487 		    !implicit || /* connect() was called or */
3488 		    ((inp->inp_vflag & INP_IPV6) && !IN6_IS_ADDR_UNSPECIFIED(&inp->in6p_laddr)) || /* local address is not un-specified */
3489 		    ((inp->inp_vflag & INP_IPV4) && inp->inp_laddr.s_addr != INADDR_ANY)) {
3490 			fd_cb->flags |= FLOW_DIVERT_SHOULD_SET_LOCAL_ADDR;
3491 		}
3492 
3493 		error = flow_divert_create_connect_packet(fd_cb, to, so, p, &connect_packet);
3494 		if (error) {
3495 			goto done;
3496 		}
3497 
3498 		if (!implicit || SOCK_TYPE(so) == SOCK_STREAM) {
3499 			flow_divert_set_remote_endpoint(fd_cb, to);
3500 			flow_divert_set_local_endpoint(fd_cb, &(fd_cb->local_endpoint.sa));
3501 		}
3502 
3503 		if (implicit) {
3504 			fd_cb->flags |= FLOW_DIVERT_IMPLICIT_CONNECT;
3505 		}
3506 
3507 		if (so->so_flags1 & SOF1_PRECONNECT_DATA) {
3508 			FDLOG0(LOG_INFO, fd_cb, "Delaying sending the connect packet until send or receive");
3509 			do_send = 0;
3510 		}
3511 
3512 		fd_cb->connect_packet = connect_packet;
3513 		connect_packet = NULL;
3514 	} else {
3515 		FDLOG0(LOG_INFO, fd_cb, "Sending saved connect packet");
3516 	}
3517 
3518 	if (do_send) {
3519 		error = flow_divert_send_connect_packet(fd_cb);
3520 		if (error) {
3521 			goto done;
3522 		}
3523 
3524 		fd_cb->flags |= FLOW_DIVERT_CONNECT_STARTED;
3525 	}
3526 
3527 	if (SOCK_TYPE(so) == SOCK_DGRAM && !(fd_cb->flags & FLOW_DIVERT_HAS_TOKEN)) {
3528 		soisconnected(so);
3529 	} else {
3530 		soisconnecting(so);
3531 	}
3532 
3533 done:
3534 	return error;
3535 }
3536 
3537 errno_t
flow_divert_connect_out(struct socket * so,struct sockaddr * to,proc_t p)3538 flow_divert_connect_out(struct socket *so, struct sockaddr *to, proc_t p)
3539 {
3540 #if CONTENT_FILTER
3541 	if (SOCK_TYPE(so) == SOCK_STREAM && !(so->so_flags & SOF_CONTENT_FILTER)) {
3542 		int error = cfil_sock_attach(so, NULL, to, CFS_CONNECTION_DIR_OUT);
3543 		if (error != 0) {
3544 			struct flow_divert_pcb  *fd_cb  = so->so_fd_pcb;
3545 			FDLOG(LOG_ERR, fd_cb, "Failed to attach cfil: %d", error);
3546 			return error;
3547 		}
3548 	}
3549 #endif /* CONTENT_FILTER */
3550 
3551 	return flow_divert_connect_out_internal(so, to, p, false);
3552 }
3553 
3554 static int
flow_divert_connectx_out_common(struct socket * so,struct sockaddr * dst,struct proc * p,sae_connid_t * pcid,struct uio * auio,user_ssize_t * bytes_written)3555 flow_divert_connectx_out_common(struct socket *so, struct sockaddr *dst,
3556     struct proc *p, sae_connid_t *pcid, struct uio *auio, user_ssize_t *bytes_written)
3557 {
3558 	struct inpcb *inp = sotoinpcb(so);
3559 	int error;
3560 
3561 	if (inp == NULL) {
3562 		return EINVAL;
3563 	}
3564 
3565 	VERIFY(dst != NULL);
3566 
3567 #if CONTENT_FILTER && NECP
3568 	struct flow_divert_pcb *fd_cb = so->so_fd_pcb;
3569 	if (fd_cb != NULL && (fd_cb->flags & FLOW_DIVERT_HAS_TOKEN) &&
3570 	    SOCK_TYPE(so) == SOCK_STREAM && !(so->so_flags & SOF_CONTENT_FILTER)) {
3571 		inp_update_necp_policy(sotoinpcb(so), NULL, dst, 0);
3572 	}
3573 #endif /* CONTENT_FILTER */
3574 
3575 	error = flow_divert_connect_out(so, dst, p);
3576 
3577 	if (error != 0) {
3578 		return error;
3579 	}
3580 
3581 	/* if there is data, send it */
3582 	if (auio != NULL) {
3583 		user_ssize_t datalen = 0;
3584 
3585 		socket_unlock(so, 0);
3586 
3587 		VERIFY(bytes_written != NULL);
3588 
3589 		datalen = uio_resid(auio);
3590 		error = so->so_proto->pr_usrreqs->pru_sosend(so, NULL, (uio_t)auio, NULL, NULL, 0);
3591 		socket_lock(so, 0);
3592 
3593 		if (error == 0 || error == EWOULDBLOCK) {
3594 			*bytes_written = datalen - uio_resid(auio);
3595 		}
3596 
3597 		/*
3598 		 * sosend returns EWOULDBLOCK if it's a non-blocking
3599 		 * socket or a timeout occured (this allows to return
3600 		 * the amount of queued data through sendit()).
3601 		 *
3602 		 * However, connectx() returns EINPROGRESS in case of a
3603 		 * blocking socket. So we change the return value here.
3604 		 */
3605 		if (error == EWOULDBLOCK) {
3606 			error = EINPROGRESS;
3607 		}
3608 	}
3609 
3610 	if (error == 0 && pcid != NULL) {
3611 		*pcid = 1;      /* there is only 1 connection for a TCP */
3612 	}
3613 
3614 	return error;
3615 }
3616 
3617 static int
flow_divert_connectx_out(struct socket * so,struct sockaddr * src __unused,struct sockaddr * dst,struct proc * p,uint32_t ifscope __unused,sae_associd_t aid __unused,sae_connid_t * pcid,uint32_t flags __unused,void * arg __unused,uint32_t arglen __unused,struct uio * uio,user_ssize_t * bytes_written)3618 flow_divert_connectx_out(struct socket *so, struct sockaddr *src __unused,
3619     struct sockaddr *dst, struct proc *p, uint32_t ifscope __unused,
3620     sae_associd_t aid __unused, sae_connid_t *pcid, uint32_t flags __unused, void *arg __unused,
3621     uint32_t arglen __unused, struct uio *uio, user_ssize_t *bytes_written)
3622 {
3623 	return flow_divert_connectx_out_common(so, dst, p, pcid, uio, bytes_written);
3624 }
3625 
3626 static int
flow_divert_connectx6_out(struct socket * so,struct sockaddr * src __unused,struct sockaddr * dst,struct proc * p,uint32_t ifscope __unused,sae_associd_t aid __unused,sae_connid_t * pcid,uint32_t flags __unused,void * arg __unused,uint32_t arglen __unused,struct uio * uio,user_ssize_t * bytes_written)3627 flow_divert_connectx6_out(struct socket *so, struct sockaddr *src __unused,
3628     struct sockaddr *dst, struct proc *p, uint32_t ifscope __unused,
3629     sae_associd_t aid __unused, sae_connid_t *pcid, uint32_t flags __unused, void *arg __unused,
3630     uint32_t arglen __unused, struct uio *uio, user_ssize_t *bytes_written)
3631 {
3632 	return flow_divert_connectx_out_common(so, dst, p, pcid, uio, bytes_written);
3633 }
3634 
3635 static errno_t
flow_divert_data_out(struct socket * so,int flags,mbuf_t data,struct sockaddr * to,mbuf_t control,struct proc * p)3636 flow_divert_data_out(struct socket *so, int flags, mbuf_t data, struct sockaddr *to, mbuf_t control, struct proc *p)
3637 {
3638 	struct flow_divert_pcb  *fd_cb  = so->so_fd_pcb;
3639 	int                                             error   = 0;
3640 	struct inpcb *inp;
3641 #if CONTENT_FILTER
3642 	struct m_tag *cfil_tag = NULL;
3643 #endif
3644 
3645 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
3646 
3647 	inp = sotoinpcb(so);
3648 	if (inp == NULL || inp->inp_state == INPCB_STATE_DEAD) {
3649 		error = ECONNRESET;
3650 		goto done;
3651 	}
3652 
3653 	if (control && mbuf_len(control) > 0) {
3654 		error = EINVAL;
3655 		goto done;
3656 	}
3657 
3658 	if (flags & MSG_OOB) {
3659 		error = EINVAL;
3660 		goto done; /* We don't support OOB data */
3661 	}
3662 
3663 	if ((fd_cb->flags & FLOW_DIVERT_TUNNEL_WR_CLOSED) && SOCK_TYPE(so) == SOCK_DGRAM) {
3664 		/* The provider considers this datagram flow to be closed, so no data can be sent */
3665 		FDLOG0(LOG_INFO, fd_cb, "provider is no longer accepting writes, cannot send data");
3666 		error = EHOSTUNREACH;
3667 		goto done;
3668 	}
3669 
3670 #if CONTENT_FILTER
3671 	/*
3672 	 * If the socket is subject to a UDP Content Filter and no remote address is passed in,
3673 	 * retrieve the CFIL saved remote address from the mbuf and use it.
3674 	 */
3675 	if (to == NULL && CFIL_DGRAM_FILTERED(so)) {
3676 		struct sockaddr *cfil_faddr = NULL;
3677 		cfil_tag = cfil_dgram_get_socket_state(data, NULL, NULL, &cfil_faddr, NULL);
3678 		if (cfil_tag) {
3679 			to = (struct sockaddr *)(void *)cfil_faddr;
3680 		}
3681 		FDLOG(LOG_INFO, fd_cb, "Using remote address from CFIL saved state: %p", to);
3682 	}
3683 #endif
3684 
3685 	/* Implicit connect */
3686 	if (!(fd_cb->flags & FLOW_DIVERT_CONNECT_STARTED)) {
3687 		FDLOG0(LOG_INFO, fd_cb, "implicit connect");
3688 
3689 		error = flow_divert_connect_out_internal(so, to, p, true);
3690 		if (error) {
3691 			goto done;
3692 		}
3693 	} else {
3694 		error = flow_divert_check_no_cellular(fd_cb) ||
3695 		    flow_divert_check_no_expensive(fd_cb) ||
3696 		    flow_divert_check_no_constrained(fd_cb);
3697 		if (error) {
3698 			goto done;
3699 		}
3700 	}
3701 
3702 	FDLOG(LOG_DEBUG, fd_cb, "app wrote %lu bytes", mbuf_pkthdr_len(data));
3703 
3704 	fd_cb->bytes_written_by_app += mbuf_pkthdr_len(data);
3705 	error = flow_divert_send_app_data(fd_cb, data, to);
3706 
3707 	data = NULL;
3708 
3709 	if (error) {
3710 		goto done;
3711 	}
3712 
3713 	if (flags & PRUS_EOF) {
3714 		flow_divert_shutdown(so);
3715 	}
3716 
3717 done:
3718 	if (data) {
3719 		mbuf_freem(data);
3720 	}
3721 	if (control) {
3722 		mbuf_free(control);
3723 	}
3724 #if CONTENT_FILTER
3725 	if (cfil_tag) {
3726 		m_tag_free(cfil_tag);
3727 	}
3728 #endif
3729 
3730 	return error;
3731 }
3732 
3733 static int
flow_divert_preconnect(struct socket * so)3734 flow_divert_preconnect(struct socket *so)
3735 {
3736 	int error = 0;
3737 	struct flow_divert_pcb  *fd_cb  = so->so_fd_pcb;
3738 
3739 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
3740 
3741 	if (!(fd_cb->flags & FLOW_DIVERT_CONNECT_STARTED)) {
3742 		FDLOG0(LOG_INFO, fd_cb, "Pre-connect read: sending saved connect packet");
3743 		error = flow_divert_send_connect_packet(so->so_fd_pcb);
3744 		if (error) {
3745 			return error;
3746 		}
3747 
3748 		fd_cb->flags |= FLOW_DIVERT_CONNECT_STARTED;
3749 	}
3750 
3751 	soclearfastopen(so);
3752 
3753 	return error;
3754 }
3755 
3756 static void
flow_divert_set_protosw(struct socket * so)3757 flow_divert_set_protosw(struct socket *so)
3758 {
3759 	if (SOCK_DOM(so) == PF_INET) {
3760 		so->so_proto = &g_flow_divert_in_protosw;
3761 	} else {
3762 		so->so_proto = (struct protosw *)&g_flow_divert_in6_protosw;
3763 	}
3764 }
3765 
3766 static void
flow_divert_set_udp_protosw(struct socket * so)3767 flow_divert_set_udp_protosw(struct socket *so)
3768 {
3769 	if (SOCK_DOM(so) == PF_INET) {
3770 		so->so_proto = &g_flow_divert_in_udp_protosw;
3771 	} else {
3772 		so->so_proto = (struct protosw *)&g_flow_divert_in6_udp_protosw;
3773 	}
3774 }
3775 
3776 errno_t
flow_divert_implicit_data_out(struct socket * so,int flags,mbuf_t data,struct sockaddr * to,mbuf_t control,struct proc * p)3777 flow_divert_implicit_data_out(struct socket *so, int flags, mbuf_t data, struct sockaddr *to, mbuf_t control, struct proc *p)
3778 {
3779 	struct flow_divert_pcb  *fd_cb  = so->so_fd_pcb;
3780 	struct inpcb *inp;
3781 	int error = 0;
3782 
3783 	inp = sotoinpcb(so);
3784 	if (inp == NULL) {
3785 		return EINVAL;
3786 	}
3787 
3788 	if (fd_cb == NULL) {
3789 		error = flow_divert_pcb_init(so);
3790 		fd_cb  = so->so_fd_pcb;
3791 		if (error != 0 || fd_cb == NULL) {
3792 			goto done;
3793 		}
3794 	}
3795 	return flow_divert_data_out(so, flags, data, to, control, p);
3796 
3797 done:
3798 	if (data) {
3799 		mbuf_freem(data);
3800 	}
3801 	if (control) {
3802 		mbuf_free(control);
3803 	}
3804 
3805 	return error;
3806 }
3807 
3808 static errno_t
flow_divert_pcb_init_internal(struct socket * so,uint32_t ctl_unit,uint32_t aggregate_unit)3809 flow_divert_pcb_init_internal(struct socket *so, uint32_t ctl_unit, uint32_t aggregate_unit)
3810 {
3811 	errno_t error = 0;
3812 	struct flow_divert_pcb *fd_cb;
3813 	uint32_t agg_unit = aggregate_unit;
3814 	bool is_aggregate = false;
3815 	uint32_t group_unit = flow_divert_derive_kernel_control_unit(ctl_unit, &agg_unit, &is_aggregate);
3816 
3817 	if (group_unit == 0) {
3818 		return EINVAL;
3819 	}
3820 
3821 	if (so->so_flags & SOF_FLOW_DIVERT) {
3822 		return EALREADY;
3823 	}
3824 
3825 	fd_cb = flow_divert_pcb_create(so);
3826 	if (fd_cb != NULL) {
3827 		so->so_fd_pcb = fd_cb;
3828 		so->so_flags |= SOF_FLOW_DIVERT;
3829 		fd_cb->control_group_unit = group_unit;
3830 		fd_cb->policy_control_unit = ctl_unit;
3831 		fd_cb->aggregate_unit = agg_unit;
3832 		if (is_aggregate) {
3833 			fd_cb->flags |= FLOW_DIVERT_FLOW_IS_TRANSPARENT;
3834 		} else {
3835 			fd_cb->flags &= ~FLOW_DIVERT_FLOW_IS_TRANSPARENT;
3836 		}
3837 
3838 		error = flow_divert_pcb_insert(fd_cb, group_unit);
3839 		if (error) {
3840 			FDLOG(LOG_ERR, fd_cb, "pcb insert failed: %d", error);
3841 			so->so_fd_pcb = NULL;
3842 			so->so_flags &= ~SOF_FLOW_DIVERT;
3843 			FDRELEASE(fd_cb);
3844 		} else {
3845 			if (SOCK_TYPE(so) == SOCK_STREAM) {
3846 				flow_divert_set_protosw(so);
3847 			} else if (SOCK_TYPE(so) == SOCK_DGRAM) {
3848 				flow_divert_set_udp_protosw(so);
3849 			}
3850 
3851 			FDLOG0(LOG_INFO, fd_cb, "Created");
3852 		}
3853 	} else {
3854 		error = ENOMEM;
3855 	}
3856 
3857 	return error;
3858 }
3859 
3860 errno_t
flow_divert_pcb_init(struct socket * so)3861 flow_divert_pcb_init(struct socket *so)
3862 {
3863 	struct inpcb *inp = sotoinpcb(so);
3864 	uint32_t aggregate_units = 0;
3865 	uint32_t ctl_unit = necp_socket_get_flow_divert_control_unit(inp, &aggregate_units);
3866 	return flow_divert_pcb_init_internal(so, ctl_unit, aggregate_units);
3867 }
3868 
3869 errno_t
flow_divert_token_set(struct socket * so,struct sockopt * sopt)3870 flow_divert_token_set(struct socket *so, struct sockopt *sopt)
3871 {
3872 	uint32_t ctl_unit = 0;
3873 	uint32_t key_unit = 0;
3874 	uint32_t aggregate_unit = 0;
3875 	int error = 0;
3876 	int hmac_error = 0;
3877 	mbuf_t token = NULL;
3878 
3879 	if (so->so_flags & SOF_FLOW_DIVERT) {
3880 		error = EALREADY;
3881 		goto done;
3882 	}
3883 
3884 	if (g_init_result) {
3885 		FDLOG(LOG_ERR, &nil_pcb, "flow_divert_init failed (%d), cannot use flow divert", g_init_result);
3886 		error = ENOPROTOOPT;
3887 		goto done;
3888 	}
3889 
3890 	if ((SOCK_TYPE(so) != SOCK_STREAM && SOCK_TYPE(so) != SOCK_DGRAM) ||
3891 	    (SOCK_PROTO(so) != IPPROTO_TCP && SOCK_PROTO(so) != IPPROTO_UDP) ||
3892 	    (SOCK_DOM(so) != PF_INET && SOCK_DOM(so) != PF_INET6)) {
3893 		error = EINVAL;
3894 		goto done;
3895 	} else {
3896 		if (SOCK_TYPE(so) == SOCK_STREAM && SOCK_PROTO(so) == IPPROTO_TCP) {
3897 			struct tcpcb *tp = sototcpcb(so);
3898 			if (tp == NULL || tp->t_state != TCPS_CLOSED) {
3899 				error = EINVAL;
3900 				goto done;
3901 			}
3902 		}
3903 	}
3904 
3905 	error = soopt_getm(sopt, &token);
3906 	if (error) {
3907 		token = NULL;
3908 		goto done;
3909 	}
3910 
3911 	error = soopt_mcopyin(sopt, token);
3912 	if (error) {
3913 		token = NULL;
3914 		goto done;
3915 	}
3916 
3917 	error = flow_divert_packet_get_tlv(token, 0, FLOW_DIVERT_TLV_KEY_UNIT, sizeof(key_unit), (void *)&key_unit, NULL);
3918 	if (!error) {
3919 		key_unit = ntohl(key_unit);
3920 		if (key_unit >= GROUP_COUNT_MAX) {
3921 			key_unit = 0;
3922 		}
3923 	} else if (error != ENOENT) {
3924 		FDLOG(LOG_ERR, &nil_pcb, "Failed to get the key unit from the token: %d", error);
3925 		goto done;
3926 	} else {
3927 		key_unit = 0;
3928 	}
3929 
3930 	error = flow_divert_packet_get_tlv(token, 0, FLOW_DIVERT_TLV_CTL_UNIT, sizeof(ctl_unit), (void *)&ctl_unit, NULL);
3931 	if (error) {
3932 		FDLOG(LOG_ERR, &nil_pcb, "Failed to get the control socket unit from the token: %d", error);
3933 		goto done;
3934 	}
3935 
3936 	error = flow_divert_packet_get_tlv(token, 0, FLOW_DIVERT_TLV_AGGREGATE_UNIT, sizeof(aggregate_unit), (void *)&aggregate_unit, NULL);
3937 	if (error && error != ENOENT) {
3938 		FDLOG(LOG_ERR, &nil_pcb, "Failed to get the aggregate unit from the token: %d", error);
3939 		goto done;
3940 	}
3941 
3942 	/* A valid kernel control unit is required */
3943 	ctl_unit = ntohl(ctl_unit);
3944 	aggregate_unit = ntohl(aggregate_unit);
3945 
3946 	if (ctl_unit > 0 && ctl_unit < GROUP_COUNT_MAX) {
3947 		socket_unlock(so, 0);
3948 		hmac_error = flow_divert_packet_verify_hmac(token, (key_unit != 0 ? key_unit : ctl_unit));
3949 		socket_lock(so, 0);
3950 
3951 		if (hmac_error && hmac_error != ENOENT) {
3952 			FDLOG(LOG_ERR, &nil_pcb, "HMAC verfication failed: %d", hmac_error);
3953 			error = hmac_error;
3954 			goto done;
3955 		}
3956 	}
3957 
3958 	error = flow_divert_pcb_init_internal(so, ctl_unit, aggregate_unit);
3959 	if (error == 0) {
3960 		struct flow_divert_pcb *fd_cb = so->so_fd_pcb;
3961 		int log_level = LOG_NOTICE;
3962 
3963 		error = flow_divert_packet_get_tlv(token, 0, FLOW_DIVERT_TLV_LOG_LEVEL, sizeof(log_level), &log_level, NULL);
3964 		if (error == 0) {
3965 			fd_cb->log_level = (uint8_t)log_level;
3966 		}
3967 		error = 0;
3968 
3969 		fd_cb->connect_token = token;
3970 		token = NULL;
3971 
3972 		fd_cb->flags |= FLOW_DIVERT_HAS_TOKEN;
3973 	}
3974 
3975 	if (hmac_error == 0) {
3976 		struct flow_divert_pcb *fd_cb = so->so_fd_pcb;
3977 		if (fd_cb != NULL) {
3978 			fd_cb->flags |= FLOW_DIVERT_HAS_HMAC;
3979 		}
3980 	}
3981 
3982 done:
3983 	if (token != NULL) {
3984 		mbuf_freem(token);
3985 	}
3986 
3987 	return error;
3988 }
3989 
3990 errno_t
flow_divert_token_get(struct socket * so,struct sockopt * sopt)3991 flow_divert_token_get(struct socket *so, struct sockopt *sopt)
3992 {
3993 	uint32_t                                        ctl_unit;
3994 	int                                                     error                                           = 0;
3995 	uint8_t                                         hmac[SHA_DIGEST_LENGTH];
3996 	struct flow_divert_pcb          *fd_cb                                          = so->so_fd_pcb;
3997 	mbuf_t                                          token                                           = NULL;
3998 	struct flow_divert_group        *control_group                          = NULL;
3999 
4000 	if (!(so->so_flags & SOF_FLOW_DIVERT)) {
4001 		error = EINVAL;
4002 		goto done;
4003 	}
4004 
4005 	VERIFY((so->so_flags & SOF_FLOW_DIVERT) && so->so_fd_pcb != NULL);
4006 
4007 	if (fd_cb->group == NULL) {
4008 		error = EINVAL;
4009 		goto done;
4010 	}
4011 
4012 	error = mbuf_gethdr(MBUF_DONTWAIT, MBUF_TYPE_HEADER, &token);
4013 	if (error) {
4014 		FDLOG(LOG_ERR, fd_cb, "failed to allocate the header mbuf: %d", error);
4015 		goto done;
4016 	}
4017 
4018 	ctl_unit = htonl(fd_cb->group->ctl_unit);
4019 
4020 	error = flow_divert_packet_append_tlv(token, FLOW_DIVERT_TLV_CTL_UNIT, sizeof(ctl_unit), &ctl_unit);
4021 	if (error) {
4022 		goto done;
4023 	}
4024 
4025 	error = flow_divert_packet_append_tlv(token, FLOW_DIVERT_TLV_FLOW_ID, sizeof(fd_cb->hash), &fd_cb->hash);
4026 	if (error) {
4027 		goto done;
4028 	}
4029 
4030 	if (fd_cb->app_data != NULL) {
4031 		error = flow_divert_packet_append_tlv(token, FLOW_DIVERT_TLV_APP_DATA, (uint32_t)fd_cb->app_data_length, fd_cb->app_data);
4032 		if (error) {
4033 			goto done;
4034 		}
4035 	}
4036 
4037 	socket_unlock(so, 0);
4038 	lck_rw_lock_shared(&g_flow_divert_group_lck);
4039 
4040 	if (g_flow_divert_groups != NULL && g_active_group_count > 0 &&
4041 	    fd_cb->control_group_unit > 0 && fd_cb->control_group_unit < GROUP_COUNT_MAX) {
4042 		control_group = g_flow_divert_groups[fd_cb->control_group_unit];
4043 	}
4044 
4045 	if (control_group != NULL) {
4046 		lck_rw_lock_shared(&control_group->lck);
4047 		ctl_unit = htonl(control_group->ctl_unit);
4048 		error = flow_divert_packet_append_tlv(token, FLOW_DIVERT_TLV_KEY_UNIT, sizeof(ctl_unit), &ctl_unit);
4049 		if (!error) {
4050 			error = flow_divert_packet_compute_hmac(token, control_group, hmac);
4051 		}
4052 		lck_rw_done(&control_group->lck);
4053 	} else {
4054 		error = ENOPROTOOPT;
4055 	}
4056 
4057 	lck_rw_done(&g_flow_divert_group_lck);
4058 	socket_lock(so, 0);
4059 
4060 	if (error) {
4061 		goto done;
4062 	}
4063 
4064 	error = flow_divert_packet_append_tlv(token, FLOW_DIVERT_TLV_HMAC, sizeof(hmac), hmac);
4065 	if (error) {
4066 		goto done;
4067 	}
4068 
4069 	if (sopt->sopt_val == USER_ADDR_NULL) {
4070 		/* If the caller passed NULL to getsockopt, just set the size of the token and return */
4071 		sopt->sopt_valsize = mbuf_pkthdr_len(token);
4072 		goto done;
4073 	}
4074 
4075 	error = soopt_mcopyout(sopt, token);
4076 	if (error) {
4077 		token = NULL;   /* For some reason, soopt_mcopyout() frees the mbuf if it fails */
4078 		goto done;
4079 	}
4080 
4081 done:
4082 	if (token != NULL) {
4083 		mbuf_freem(token);
4084 	}
4085 
4086 	return error;
4087 }
4088 
4089 static errno_t
flow_divert_kctl_connect(kern_ctl_ref kctlref __unused,struct sockaddr_ctl * sac,void ** unitinfo)4090 flow_divert_kctl_connect(kern_ctl_ref kctlref __unused, struct sockaddr_ctl *sac, void **unitinfo)
4091 {
4092 	struct flow_divert_group        *new_group      = NULL;
4093 	int                             error           = 0;
4094 
4095 	if (sac->sc_unit >= GROUP_COUNT_MAX) {
4096 		error = EINVAL;
4097 		goto done;
4098 	}
4099 
4100 	*unitinfo = NULL;
4101 
4102 	new_group = zalloc_flags(flow_divert_group_zone, Z_WAITOK | Z_ZERO);
4103 	lck_rw_init(&new_group->lck, &flow_divert_mtx_grp, &flow_divert_mtx_attr);
4104 	RB_INIT(&new_group->pcb_tree);
4105 	new_group->ctl_unit = sac->sc_unit;
4106 	MBUFQ_INIT(&new_group->send_queue);
4107 	new_group->signing_id_trie.root = NULL_TRIE_IDX;
4108 
4109 	lck_rw_lock_exclusive(&g_flow_divert_group_lck);
4110 
4111 	if (g_flow_divert_groups == NULL) {
4112 		MALLOC(g_flow_divert_groups,
4113 		    struct flow_divert_group **,
4114 		    GROUP_COUNT_MAX * sizeof(struct flow_divert_group *),
4115 		    M_TEMP,
4116 		    M_WAITOK | M_ZERO);
4117 	}
4118 
4119 	if (g_flow_divert_groups == NULL) {
4120 		error = ENOBUFS;
4121 	} else if (g_flow_divert_groups[sac->sc_unit] != NULL) {
4122 		error = EALREADY;
4123 	} else {
4124 		g_flow_divert_groups[sac->sc_unit] = new_group;
4125 		g_active_group_count++;
4126 	}
4127 
4128 	lck_rw_done(&g_flow_divert_group_lck);
4129 
4130 done:
4131 	if (error == 0) {
4132 		*unitinfo = new_group;
4133 	} else if (new_group != NULL) {
4134 		zfree(flow_divert_group_zone, new_group);
4135 	}
4136 	return error;
4137 }
4138 
4139 static errno_t
flow_divert_kctl_disconnect(kern_ctl_ref kctlref __unused,uint32_t unit,void * unitinfo)4140 flow_divert_kctl_disconnect(kern_ctl_ref kctlref __unused, uint32_t unit, void *unitinfo)
4141 {
4142 	struct flow_divert_group        *group  = NULL;
4143 	errno_t                                         error   = 0;
4144 
4145 	if (unit >= GROUP_COUNT_MAX) {
4146 		return EINVAL;
4147 	}
4148 
4149 	if (unitinfo == NULL) {
4150 		return 0;
4151 	}
4152 
4153 	FDLOG(LOG_INFO, &nil_pcb, "disconnecting group %d", unit);
4154 
4155 	lck_rw_lock_exclusive(&g_flow_divert_group_lck);
4156 
4157 	if (g_flow_divert_groups == NULL || g_active_group_count == 0) {
4158 		panic("flow divert group %u is disconnecting, but no groups are active (groups = %p, active count = %u", unit,
4159 		    g_flow_divert_groups, g_active_group_count);
4160 	}
4161 
4162 	group = g_flow_divert_groups[unit];
4163 
4164 	if (group != (struct flow_divert_group *)unitinfo) {
4165 		panic("group with unit %d (%p) != unit info (%p)", unit, group, unitinfo);
4166 	}
4167 
4168 	g_flow_divert_groups[unit] = NULL;
4169 	g_active_group_count--;
4170 
4171 	if (g_active_group_count == 0) {
4172 		FREE(g_flow_divert_groups, M_TEMP);
4173 		g_flow_divert_groups = NULL;
4174 	}
4175 
4176 	lck_rw_done(&g_flow_divert_group_lck);
4177 
4178 	if (group != NULL) {
4179 		flow_divert_close_all(group);
4180 
4181 		lck_rw_lock_exclusive(&group->lck);
4182 
4183 		if (group->token_key != NULL) {
4184 			memset(group->token_key, 0, group->token_key_size);
4185 			kfree_data(group->token_key, group->token_key_size);
4186 			group->token_key = NULL;
4187 			group->token_key_size = 0;
4188 		}
4189 
4190 		/* Re-set the current trie */
4191 		if (group->signing_id_trie.memory != NULL) {
4192 			FREE(group->signing_id_trie.memory, M_TEMP);
4193 		}
4194 		memset(&group->signing_id_trie, 0, sizeof(group->signing_id_trie));
4195 		group->signing_id_trie.root = NULL_TRIE_IDX;
4196 
4197 		lck_rw_done(&group->lck);
4198 
4199 		zfree(flow_divert_group_zone, group);
4200 	} else {
4201 		error = EINVAL;
4202 	}
4203 
4204 	return error;
4205 }
4206 
4207 static errno_t
flow_divert_kctl_send(kern_ctl_ref kctlref __unused,uint32_t unit __unused,void * unitinfo,mbuf_t m,int flags __unused)4208 flow_divert_kctl_send(kern_ctl_ref kctlref __unused, uint32_t unit __unused, void *unitinfo, mbuf_t m, int flags __unused)
4209 {
4210 	return flow_divert_input(m, (struct flow_divert_group *)unitinfo);
4211 }
4212 
4213 static void
flow_divert_kctl_rcvd(kern_ctl_ref kctlref __unused,uint32_t unit __unused,void * unitinfo,int flags __unused)4214 flow_divert_kctl_rcvd(kern_ctl_ref kctlref __unused, uint32_t unit __unused, void *unitinfo, int flags __unused)
4215 {
4216 	struct flow_divert_group        *group  = (struct flow_divert_group *)unitinfo;
4217 
4218 	if (!OSTestAndClear(GROUP_BIT_CTL_ENQUEUE_BLOCKED, &group->atomic_bits)) {
4219 		struct flow_divert_pcb                  *fd_cb;
4220 		SLIST_HEAD(, flow_divert_pcb)   tmp_list;
4221 
4222 		lck_rw_lock_shared(&g_flow_divert_group_lck);
4223 		lck_rw_lock_exclusive(&group->lck);
4224 
4225 		while (!MBUFQ_EMPTY(&group->send_queue)) {
4226 			mbuf_t next_packet;
4227 			FDLOG0(LOG_DEBUG, &nil_pcb, "trying ctl_enqueuembuf again");
4228 			next_packet = MBUFQ_FIRST(&group->send_queue);
4229 			int error = ctl_enqueuembuf(g_flow_divert_kctl_ref, group->ctl_unit, next_packet, CTL_DATA_EOR);
4230 			if (error) {
4231 				FDLOG(LOG_DEBUG, &nil_pcb, "ctl_enqueuembuf returned an error: %d", error);
4232 				OSTestAndSet(GROUP_BIT_CTL_ENQUEUE_BLOCKED, &group->atomic_bits);
4233 				lck_rw_done(&group->lck);
4234 				lck_rw_done(&g_flow_divert_group_lck);
4235 				return;
4236 			}
4237 			MBUFQ_DEQUEUE(&group->send_queue, next_packet);
4238 		}
4239 
4240 		SLIST_INIT(&tmp_list);
4241 
4242 		RB_FOREACH(fd_cb, fd_pcb_tree, &group->pcb_tree) {
4243 			FDRETAIN(fd_cb);
4244 			SLIST_INSERT_HEAD(&tmp_list, fd_cb, tmp_list_entry);
4245 		}
4246 
4247 		lck_rw_done(&group->lck);
4248 
4249 		SLIST_FOREACH(fd_cb, &tmp_list, tmp_list_entry) {
4250 			FDLOCK(fd_cb);
4251 			if (fd_cb->so != NULL) {
4252 				socket_lock(fd_cb->so, 0);
4253 				if (fd_cb->group != NULL) {
4254 					flow_divert_send_buffered_data(fd_cb, FALSE);
4255 				}
4256 				socket_unlock(fd_cb->so, 0);
4257 			}
4258 			FDUNLOCK(fd_cb);
4259 			FDRELEASE(fd_cb);
4260 		}
4261 
4262 		lck_rw_done(&g_flow_divert_group_lck);
4263 	}
4264 }
4265 
4266 static int
flow_divert_kctl_init(void)4267 flow_divert_kctl_init(void)
4268 {
4269 	struct kern_ctl_reg     ctl_reg;
4270 	int                     result;
4271 
4272 	memset(&ctl_reg, 0, sizeof(ctl_reg));
4273 
4274 	strlcpy(ctl_reg.ctl_name, FLOW_DIVERT_CONTROL_NAME, sizeof(ctl_reg.ctl_name));
4275 	ctl_reg.ctl_name[sizeof(ctl_reg.ctl_name) - 1] = '\0';
4276 	ctl_reg.ctl_flags = CTL_FLAG_PRIVILEGED | CTL_FLAG_REG_EXTENDED;
4277 	ctl_reg.ctl_sendsize = FD_CTL_SENDBUFF_SIZE;
4278 	ctl_reg.ctl_recvsize = FD_CTL_RCVBUFF_SIZE;
4279 
4280 	ctl_reg.ctl_connect = flow_divert_kctl_connect;
4281 	ctl_reg.ctl_disconnect = flow_divert_kctl_disconnect;
4282 	ctl_reg.ctl_send = flow_divert_kctl_send;
4283 	ctl_reg.ctl_rcvd = flow_divert_kctl_rcvd;
4284 
4285 	result = ctl_register(&ctl_reg, &g_flow_divert_kctl_ref);
4286 
4287 	if (result) {
4288 		FDLOG(LOG_ERR, &nil_pcb, "flow_divert_kctl_init - ctl_register failed: %d\n", result);
4289 		return result;
4290 	}
4291 
4292 	return 0;
4293 }
4294 
4295 void
flow_divert_init(void)4296 flow_divert_init(void)
4297 {
4298 	memset(&nil_pcb, 0, sizeof(nil_pcb));
4299 	nil_pcb.log_level = LOG_NOTICE;
4300 
4301 	g_tcp_protosw = pffindproto(AF_INET, IPPROTO_TCP, SOCK_STREAM);
4302 
4303 	VERIFY(g_tcp_protosw != NULL);
4304 
4305 	memcpy(&g_flow_divert_in_protosw, g_tcp_protosw, sizeof(g_flow_divert_in_protosw));
4306 	memcpy(&g_flow_divert_in_usrreqs, g_tcp_protosw->pr_usrreqs, sizeof(g_flow_divert_in_usrreqs));
4307 
4308 	g_flow_divert_in_usrreqs.pru_connect = flow_divert_connect_out;
4309 	g_flow_divert_in_usrreqs.pru_connectx = flow_divert_connectx_out;
4310 	g_flow_divert_in_usrreqs.pru_disconnect = flow_divert_close;
4311 	g_flow_divert_in_usrreqs.pru_disconnectx = flow_divert_disconnectx;
4312 	g_flow_divert_in_usrreqs.pru_rcvd = flow_divert_rcvd;
4313 	g_flow_divert_in_usrreqs.pru_send = flow_divert_data_out;
4314 	g_flow_divert_in_usrreqs.pru_shutdown = flow_divert_shutdown;
4315 	g_flow_divert_in_usrreqs.pru_preconnect = flow_divert_preconnect;
4316 
4317 	g_flow_divert_in_protosw.pr_usrreqs = &g_flow_divert_in_usrreqs;
4318 	g_flow_divert_in_protosw.pr_ctloutput = flow_divert_ctloutput;
4319 
4320 	/*
4321 	 * Socket filters shouldn't attach/detach to/from this protosw
4322 	 * since pr_protosw is to be used instead, which points to the
4323 	 * real protocol; if they do, it is a bug and we should panic.
4324 	 */
4325 	g_flow_divert_in_protosw.pr_filter_head.tqh_first =
4326 	    (struct socket_filter *)(uintptr_t)0xdeadbeefdeadbeef;
4327 	g_flow_divert_in_protosw.pr_filter_head.tqh_last =
4328 	    (struct socket_filter **)(uintptr_t)0xdeadbeefdeadbeef;
4329 
4330 	/* UDP */
4331 	g_udp_protosw = pffindproto(AF_INET, IPPROTO_UDP, SOCK_DGRAM);
4332 	VERIFY(g_udp_protosw != NULL);
4333 
4334 	memcpy(&g_flow_divert_in_udp_protosw, g_udp_protosw, sizeof(g_flow_divert_in_udp_protosw));
4335 	memcpy(&g_flow_divert_in_udp_usrreqs, g_udp_protosw->pr_usrreqs, sizeof(g_flow_divert_in_udp_usrreqs));
4336 
4337 	g_flow_divert_in_udp_usrreqs.pru_connect = flow_divert_connect_out;
4338 	g_flow_divert_in_udp_usrreqs.pru_connectx = flow_divert_connectx_out;
4339 	g_flow_divert_in_udp_usrreqs.pru_disconnect = flow_divert_close;
4340 	g_flow_divert_in_udp_usrreqs.pru_disconnectx = flow_divert_disconnectx;
4341 	g_flow_divert_in_udp_usrreqs.pru_rcvd = flow_divert_rcvd;
4342 	g_flow_divert_in_udp_usrreqs.pru_send = flow_divert_data_out;
4343 	g_flow_divert_in_udp_usrreqs.pru_shutdown = flow_divert_shutdown;
4344 	g_flow_divert_in_udp_usrreqs.pru_sosend_list = pru_sosend_list_notsupp;
4345 	g_flow_divert_in_udp_usrreqs.pru_soreceive_list = pru_soreceive_list_notsupp;
4346 	g_flow_divert_in_udp_usrreqs.pru_preconnect = flow_divert_preconnect;
4347 
4348 	g_flow_divert_in_udp_protosw.pr_usrreqs = &g_flow_divert_in_usrreqs;
4349 	g_flow_divert_in_udp_protosw.pr_ctloutput = flow_divert_ctloutput;
4350 
4351 	/*
4352 	 * Socket filters shouldn't attach/detach to/from this protosw
4353 	 * since pr_protosw is to be used instead, which points to the
4354 	 * real protocol; if they do, it is a bug and we should panic.
4355 	 */
4356 	g_flow_divert_in_udp_protosw.pr_filter_head.tqh_first =
4357 	    (struct socket_filter *)(uintptr_t)0xdeadbeefdeadbeef;
4358 	g_flow_divert_in_udp_protosw.pr_filter_head.tqh_last =
4359 	    (struct socket_filter **)(uintptr_t)0xdeadbeefdeadbeef;
4360 
4361 	g_tcp6_protosw = (struct ip6protosw *)pffindproto(AF_INET6, IPPROTO_TCP, SOCK_STREAM);
4362 
4363 	VERIFY(g_tcp6_protosw != NULL);
4364 
4365 	memcpy(&g_flow_divert_in6_protosw, g_tcp6_protosw, sizeof(g_flow_divert_in6_protosw));
4366 	memcpy(&g_flow_divert_in6_usrreqs, g_tcp6_protosw->pr_usrreqs, sizeof(g_flow_divert_in6_usrreqs));
4367 
4368 	g_flow_divert_in6_usrreqs.pru_connect = flow_divert_connect_out;
4369 	g_flow_divert_in6_usrreqs.pru_connectx = flow_divert_connectx6_out;
4370 	g_flow_divert_in6_usrreqs.pru_disconnect = flow_divert_close;
4371 	g_flow_divert_in6_usrreqs.pru_disconnectx = flow_divert_disconnectx;
4372 	g_flow_divert_in6_usrreqs.pru_rcvd = flow_divert_rcvd;
4373 	g_flow_divert_in6_usrreqs.pru_send = flow_divert_data_out;
4374 	g_flow_divert_in6_usrreqs.pru_shutdown = flow_divert_shutdown;
4375 	g_flow_divert_in6_usrreqs.pru_preconnect = flow_divert_preconnect;
4376 
4377 	g_flow_divert_in6_protosw.pr_usrreqs = &g_flow_divert_in6_usrreqs;
4378 	g_flow_divert_in6_protosw.pr_ctloutput = flow_divert_ctloutput;
4379 	/*
4380 	 * Socket filters shouldn't attach/detach to/from this protosw
4381 	 * since pr_protosw is to be used instead, which points to the
4382 	 * real protocol; if they do, it is a bug and we should panic.
4383 	 */
4384 	g_flow_divert_in6_protosw.pr_filter_head.tqh_first =
4385 	    (struct socket_filter *)(uintptr_t)0xdeadbeefdeadbeef;
4386 	g_flow_divert_in6_protosw.pr_filter_head.tqh_last =
4387 	    (struct socket_filter **)(uintptr_t)0xdeadbeefdeadbeef;
4388 
4389 	/* UDP6 */
4390 	g_udp6_protosw = (struct ip6protosw *)pffindproto(AF_INET6, IPPROTO_UDP, SOCK_DGRAM);
4391 
4392 	VERIFY(g_udp6_protosw != NULL);
4393 
4394 	memcpy(&g_flow_divert_in6_udp_protosw, g_udp6_protosw, sizeof(g_flow_divert_in6_udp_protosw));
4395 	memcpy(&g_flow_divert_in6_udp_usrreqs, g_udp6_protosw->pr_usrreqs, sizeof(g_flow_divert_in6_udp_usrreqs));
4396 
4397 	g_flow_divert_in6_udp_usrreqs.pru_connect = flow_divert_connect_out;
4398 	g_flow_divert_in6_udp_usrreqs.pru_connectx = flow_divert_connectx6_out;
4399 	g_flow_divert_in6_udp_usrreqs.pru_disconnect = flow_divert_close;
4400 	g_flow_divert_in6_udp_usrreqs.pru_disconnectx = flow_divert_disconnectx;
4401 	g_flow_divert_in6_udp_usrreqs.pru_rcvd = flow_divert_rcvd;
4402 	g_flow_divert_in6_udp_usrreqs.pru_send = flow_divert_data_out;
4403 	g_flow_divert_in6_udp_usrreqs.pru_shutdown = flow_divert_shutdown;
4404 	g_flow_divert_in6_udp_usrreqs.pru_sosend_list = pru_sosend_list_notsupp;
4405 	g_flow_divert_in6_udp_usrreqs.pru_soreceive_list = pru_soreceive_list_notsupp;
4406 	g_flow_divert_in6_udp_usrreqs.pru_preconnect = flow_divert_preconnect;
4407 
4408 	g_flow_divert_in6_udp_protosw.pr_usrreqs = &g_flow_divert_in6_udp_usrreqs;
4409 	g_flow_divert_in6_udp_protosw.pr_ctloutput = flow_divert_ctloutput;
4410 	/*
4411 	 * Socket filters shouldn't attach/detach to/from this protosw
4412 	 * since pr_protosw is to be used instead, which points to the
4413 	 * real protocol; if they do, it is a bug and we should panic.
4414 	 */
4415 	g_flow_divert_in6_udp_protosw.pr_filter_head.tqh_first =
4416 	    (struct socket_filter *)(uintptr_t)0xdeadbeefdeadbeef;
4417 	g_flow_divert_in6_udp_protosw.pr_filter_head.tqh_last =
4418 	    (struct socket_filter **)(uintptr_t)0xdeadbeefdeadbeef;
4419 
4420 	g_init_result = flow_divert_kctl_init();
4421 	if (g_init_result) {
4422 		goto done;
4423 	}
4424 
4425 done:
4426 	if (g_init_result != 0) {
4427 		if (g_flow_divert_kctl_ref != NULL) {
4428 			ctl_deregister(g_flow_divert_kctl_ref);
4429 			g_flow_divert_kctl_ref = NULL;
4430 		}
4431 	}
4432 }
4433