xref: /xnu-8796.121.2/bsd/skywalk/nexus/nexus_traffic_rule.c (revision c54f35ca767986246321eb901baf8f5ff7923f6a)
1 /*
2  * Copyright (c) 2022 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 #include <skywalk/os_skywalk_private.h>
29 #include <skywalk/nexus/netif/nx_netif.h>
30 #include <netinet/ip.h>
31 #include <netinet/ip6.h>
32 #include <netinet/tcp.h>
33 #include <netinet/udp.h>
34 #include <sys/sdt.h>
35 
36 /*
37  * Implementation of nexus traffic rules APIs.
38  */
39 
40 struct nxctl_traffic_rule_type;
41 struct nxctl_traffic_rule;
42 
43 /*
44  * These callbacks need to be implemented for each rule type.
45  */
46 
47 /* Validate user provided parameters. */
48 typedef int (nxctl_traffic_rule_validate_cb_t)(
49 	struct nxctl_traffic_rule_type *type,
50 	const char *ifname,
51 	struct ifnet_traffic_descriptor_common *td,
52 	struct ifnet_traffic_rule_action *ra);
53 /*
54  * Each rule type has its own global structure for storing rules.
55  * These callbacks access this global structure.
56  */
57 #define NTR_FIND_FLAG_EXACT 0x0001
58 typedef int (nxctl_traffic_rule_find_cb_t)(
59 	struct nxctl_traffic_rule_type *type,
60 	const char *ifname,
61 	struct ifnet_traffic_descriptor_common *td,
62 	uint32_t flags,
63 	struct nxctl_traffic_rule **ntrp);
64 
65 typedef int (nxctl_traffic_rule_find_by_uuid_cb_t)(
66 	struct nxctl_traffic_rule_type *type,
67 	uuid_t uuid,
68 	struct nxctl_traffic_rule **ntrp);
69 
70 typedef void (nxctl_traffic_rule_link_cb_t)(
71 	struct nxctl_traffic_rule *ntr);
72 
73 typedef void (nxctl_traffic_rule_unlink_cb_t)(
74 	struct nxctl_traffic_rule *ntr);
75 
76 /*
77  * Notifies lower layers of the addition/removal of a rule.
78  * This is called outside of nxctl_traffic_rule_lock to avoid potential
79  * locking issues.
80  */
81 #define NTR_NOTIFY_FLAG_ADD 0x0001
82 #define NTR_NOTIFY_FLAG_REMOVE 0x0002
83 typedef int (nxctl_traffic_rule_notify_cb_t)(
84 	struct nxctl_traffic_rule *ntr,
85 	uint32_t flags);
86 
87 /*
88  * Create/Destroy callbacks for a rule type.
89  */
90 typedef int (nxctl_traffic_rule_create_cb_t)(
91 	struct nxctl_traffic_rule_type *type,
92 	const char *ifname,
93 	struct ifnet_traffic_descriptor_common *td,
94 	struct ifnet_traffic_rule_action *ra,
95 	uint32_t flags,
96 	struct nxctl_traffic_rule **ntrp);
97 
98 typedef void (nxctl_traffic_rule_destroy_cb_t)(
99 	struct nxctl_traffic_rule *ntr);
100 
101 /*
102  * This is used for copying all rules for a type (including generic
103  * and type-specific info) to userspace.
104  */
105 typedef int (nxctl_traffic_rule_get_all_cb_t)(
106 	struct nxctl_traffic_rule_type *type,
107 	uint32_t size,
108 	uint32_t *count,
109 	user_addr_t uaddr);
110 
111 struct nxctl_traffic_rule_type {
112 	uint8_t ntrt_type;
113 	nxctl_traffic_rule_validate_cb_t *ntrt_validate;
114 	nxctl_traffic_rule_find_cb_t *ntrt_find;
115 	nxctl_traffic_rule_find_by_uuid_cb_t *ntrt_find_by_uuid;
116 	nxctl_traffic_rule_link_cb_t *ntrt_link;
117 	nxctl_traffic_rule_unlink_cb_t *ntrt_unlink;
118 	nxctl_traffic_rule_notify_cb_t *ntrt_notify;
119 	nxctl_traffic_rule_create_cb_t *ntrt_create;
120 	nxctl_traffic_rule_destroy_cb_t *ntrt_destroy;
121 	nxctl_traffic_rule_get_all_cb_t *ntrt_get_all;
122 	void *ntrt_storage;
123 };
124 
125 static nxctl_traffic_rule_validate_cb_t inet_traffic_rule_validate;
126 static nxctl_traffic_rule_find_cb_t inet_traffic_rule_find;
127 static nxctl_traffic_rule_find_by_uuid_cb_t inet_traffic_rule_find_by_uuid;
128 static nxctl_traffic_rule_link_cb_t inet_traffic_rule_link;
129 static nxctl_traffic_rule_unlink_cb_t inet_traffic_rule_unlink;
130 static nxctl_traffic_rule_notify_cb_t inet_traffic_rule_notify;
131 static nxctl_traffic_rule_create_cb_t inet_traffic_rule_create;
132 static nxctl_traffic_rule_destroy_cb_t inet_traffic_rule_destroy;
133 static nxctl_traffic_rule_get_all_cb_t inet_traffic_rule_get_all;
134 
135 static struct nxctl_traffic_rule_type nxctl_rule_types[] = {
136 	{
137 		.ntrt_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET,
138 		.ntrt_validate = inet_traffic_rule_validate,
139 		.ntrt_find = inet_traffic_rule_find,
140 		.ntrt_find_by_uuid = inet_traffic_rule_find_by_uuid,
141 		.ntrt_link = inet_traffic_rule_link,
142 		.ntrt_unlink = inet_traffic_rule_unlink,
143 		.ntrt_notify = inet_traffic_rule_notify,
144 		.ntrt_create = inet_traffic_rule_create,
145 		.ntrt_destroy = inet_traffic_rule_destroy,
146 		.ntrt_get_all = inet_traffic_rule_get_all,
147 	},
148 };
149 #define NRULETYPES \
150     (sizeof(nxctl_rule_types)/sizeof(struct nxctl_traffic_rule_type))
151 
152 /*
153  * Generic traffic rule.
154  * Contains fields common to all traffic rules.
155  */
156 #define NTR_FLAG_PERSIST 0x0001
157 #define NTR_FLAG_ON_NXCTL_LIST 0x0002
158 struct nxctl_traffic_rule {
159 	struct nxctl_traffic_rule_type *ntr_type;
160 	uint32_t ntr_flags;
161 	os_refcnt_t ntr_refcnt;
162 	uuid_t ntr_uuid;
163 	char ntr_procname[NTR_PROCNAME_SZ];
164 	char ntr_ifname[IFNAMSIZ];
165 	SLIST_ENTRY(nxctl_traffic_rule) ntr_storage_link;
166 };
167 
168 /*
169  * Inet-specific traffic rule.
170  */
171 struct nxctl_traffic_rule_inet {
172 	struct nxctl_traffic_rule ntri_common;
173 	SLIST_ENTRY(nxctl_traffic_rule_inet) ntri_storage_link;
174 	struct ifnet_traffic_descriptor_inet ntri_td;
175 	struct ifnet_traffic_rule_action_steer ntri_ra;
176 };
177 
178 /*
179  * Currently supported tuple types.
180  */
181 static uint8_t nxctl_inet_traffic_rule_masks[] = {
182 	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
183 	IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
184 	IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
185 	IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
186 	IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
187 	IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT),
188 
189 	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
190 	IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
191 	IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
192 	IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT),
193 };
194 #define NINETRULEMASKS \
195     (sizeof(nxctl_inet_traffic_rule_masks)/sizeof(uint8_t))
196 
197 /* Per-interface lists of traffic rules */
198 SLIST_HEAD(nxctl_traffic_rule_inet_head, nxctl_traffic_rule_inet);
199 struct nxctl_traffic_rule_inet_if {
200 	char rii_ifname[IFNAMSIZ];
201 	struct nxctl_traffic_rule_inet_head rii_lists[NINETRULEMASKS];
202 	uint32_t rii_count;
203 	SLIST_ENTRY(nxctl_traffic_rule_inet_if) rii_link;
204 };
205 
206 /* List of per-interface lists */
207 SLIST_HEAD(nxctl_traffic_rule_inet_if_head, nxctl_traffic_rule_inet_if);
208 struct nxctl_traffic_rule_inet_storage {
209 	struct nxctl_traffic_rule_inet_if_head ris_if_list;
210 	uint32_t ris_count;
211 };
212 
213 /* Per-fd list kept at the nxctl */
214 SLIST_HEAD(nxctl_traffic_rule_head, nxctl_traffic_rule);
215 struct nxctl_traffic_rule_storage {
216 	struct nxctl_traffic_rule_head rs_list;
217 	uint32_t rs_count;
218 };
219 
220 static LCK_RW_DECLARE_ATTR(nxctl_traffic_rule_lock, &sk_lock_group, &sk_lock_attr);
221 #define NXTR_WLOCK() \
222     lck_rw_lock_exclusive(&nxctl_traffic_rule_lock)
223 #define NXTR_WUNLOCK() \
224     lck_rw_unlock_exclusive(&nxctl_traffic_rule_lock)
225 #define NXTR_RLOCK() \
226     lck_rw_lock_shared(&nxctl_traffic_rule_lock)
227 #define NXTR_RUNLOCK() \
228     lck_rw_unlock_shared(&nxctl_traffic_rule_lock)
229 
230 static struct nxctl_traffic_rule_type *find_traffic_rule_type(uint8_t type);
231 static void retain_traffic_rule(struct nxctl_traffic_rule *ntr);
232 static void release_traffic_rule(struct nxctl_traffic_rule *ntr);
233 static int remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
234     struct nxctl_traffic_rule **ntrp);
235 static boolean_t inet_v6addr_cmp(struct ifnet_ip_addr *a1,
236     struct ifnet_ip_addr *a2);
237 static int notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags);
238 
239 #define NXCTL_TRAFFIC_RULE_TAG "com.apple.skywalk.nexus.traffic_rule"
240 static kern_allocation_name_t nxctl_traffic_rule_tag;
241 static struct nxctl_traffic_rule_type *inet_traffic_rule_type = NULL;
242 
243 /*
244  * If a interface attaches after rule(s) are added, this function is used
245  * retrieve the current rule count for that interface.
246  */
247 int
nxctl_inet_traffic_rule_get_count(const char * ifname,uint32_t * count)248 nxctl_inet_traffic_rule_get_count(const char *ifname, uint32_t *count)
249 {
250 	struct nxctl_traffic_rule_inet_storage *rs;
251 	struct nxctl_traffic_rule_inet_if *rif;
252 	int err;
253 
254 	NXTR_RLOCK();
255 	rs = inet_traffic_rule_type->ntrt_storage;
256 	if (rs == NULL) {
257 		err = ENOENT;
258 		goto fail;
259 	}
260 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
261 		if (strcmp(rif->rii_ifname, ifname) == 0) {
262 			break;
263 		}
264 	}
265 	if (rif == NULL) {
266 		err = ENOENT;
267 		goto fail;
268 	}
269 	*count = rif->rii_count;
270 	NXTR_RUNLOCK();
271 	return 0;
272 fail:
273 	NXTR_RUNLOCK();
274 	return err;
275 }
276 
277 /*
278  * Used for finding the qset id associated with a traffic descriptor.
279  */
280 int
nxctl_inet_traffic_rule_find_qset_id(const char * ifname,struct ifnet_traffic_descriptor_inet * td,uint64_t * qset_id)281 nxctl_inet_traffic_rule_find_qset_id(const char *ifname,
282     struct ifnet_traffic_descriptor_inet *td, uint64_t *qset_id)
283 {
284 	struct nxctl_traffic_rule_inet *ntri = NULL;
285 	int err;
286 
287 	NXTR_RLOCK();
288 	ASSERT(inet_traffic_rule_type != NULL);
289 	err = inet_traffic_rule_type->ntrt_find(inet_traffic_rule_type, ifname,
290 	    (struct ifnet_traffic_descriptor_common *)td, 0,
291 	    (struct nxctl_traffic_rule **)&ntri);
292 	if (err != 0) {
293 		SK_ERR("rule find failed: %d", err);
294 		goto fail;
295 	}
296 	*qset_id = ntri->ntri_ra.ras_qset_id;
297 	NXTR_RUNLOCK();
298 	return 0;
299 fail:
300 	NXTR_RUNLOCK();
301 	return err;
302 }
303 
304 /*
305  * Based on flow_pkt_classify().
306  * This function populates struct ifnet_traffic_descriptor_inet instead of struct __flow.
307  */
308 static int
fill_inet_td(struct __kern_packet * pkt,struct ifnet_traffic_descriptor_inet * td)309 fill_inet_td(struct __kern_packet *pkt, struct ifnet_traffic_descriptor_inet *td)
310 {
311 	union {
312 		volatile struct ip *_iph;
313 		volatile struct ip6_hdr *_ip6;
314 	} _l3;
315 	#define iph _l3._iph
316 	#define ip6 _l3._ip6
317 	union {
318 		volatile struct tcphdr *_tcph;
319 		volatile struct udphdr *_udph;
320 	} _l4;
321 	#define tcph _l4._tcph
322 	#define udph _l4._udph
323 	uint8_t *pkt_buf, *l3_hdr;
324 	uint16_t bdlen, bdlim, bdoff, cls_len;
325 	size_t pkt_len;
326 	uint8_t ipv, l3hlen = 0; /* IP header length */
327 	uint16_t l3tlen = 0;     /* total length of IP packet */
328 	uint8_t l4hlen = 0;      /* TCP/UDP header length */
329 	uint16_t ulen = 0;       /* user data length */
330 	int err;
331 
332 	ASSERT(pkt->pkt_l2_len <= pkt->pkt_length);
333 	pkt_len = pkt->pkt_length - pkt->pkt_l2_len;
334 
335 	MD_BUFLET_ADDR_ABS_DLEN(pkt, pkt_buf, bdlen, bdlim, bdoff);
336 	cls_len = bdlim - bdoff;
337 	cls_len -= pkt->pkt_l2_len;
338 	cls_len = (uint16_t)MIN(cls_len, pkt_len);
339 	VERIFY(pkt_len >= cls_len);
340 	if (cls_len == 0) {
341 		SK_ERR("cls_len == 0");
342 		err = EINVAL;
343 		goto fail;
344 	}
345 	l3_hdr = pkt_buf + pkt->pkt_headroom + pkt->pkt_l2_len;
346 	iph = (volatile struct ip *)(void *)l3_hdr;
347 	ipv = iph->ip_v;
348 
349 	switch (ipv) {
350 	case 4:
351 		if (cls_len < sizeof(struct ip)) {
352 			SK_ERR("cls_len < sizeof(struct ip) (%d < %d)",
353 			    cls_len, sizeof(struct ip));
354 			err = EINVAL;
355 			goto fail;
356 		}
357 		l3hlen = (uint8_t)(iph->ip_hl << 2);
358 		if (l3hlen < sizeof(struct ip)) {
359 			SK_ERR("l3hlen < sizeof(struct ip) (%d < %d)",
360 			    l3hlen, sizeof(struct ip));
361 			err = EINVAL;
362 			goto fail;
363 		}
364 		if (cls_len < l3hlen) {
365 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
366 			err = EINVAL;
367 			goto fail;
368 		}
369 		l3tlen = ntohs(iph->ip_len);
370 		if (l3tlen < l3hlen) {
371 			SK_ERR("l3tlen < l3hlen (%d < %d)", l3tlen, l3hlen);
372 			err = EINVAL;
373 			goto fail;
374 		}
375 		if (pkt_len < l3tlen) {
376 			SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
377 			err = EINVAL;
378 			goto fail;
379 		}
380 		td->inet_ipver = IPVERSION;
381 		td->inet_proto = iph->ip_p;
382 		bcopy(__DECONST(void *, &iph->ip_src), &td->inet_laddr.iia_v4addr,
383 		    sizeof(iph->ip_src));
384 		bcopy(__DECONST(void *, &iph->ip_dst), &td->inet_raddr.iia_v4addr,
385 		    sizeof(iph->ip_dst));
386 		break;
387 	case 6:
388 		l3hlen = sizeof(struct ip6_hdr);
389 		if (cls_len < l3hlen) {
390 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
391 			err = EINVAL;
392 			goto fail;
393 		}
394 		l3tlen = l3hlen + ntohs(ip6->ip6_plen);
395 		if (pkt_len < l3tlen) {
396 			SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
397 			err = EINVAL;
398 			goto fail;
399 		}
400 		td->inet_ipver = IPV6_VERSION;
401 		td->inet_proto = ip6->ip6_nxt;
402 		bcopy(__DECONST(void *, &ip6->ip6_src), &td->inet_laddr,
403 		    sizeof(ip6->ip6_src));
404 		bcopy(__DECONST(void *, &ip6->ip6_dst), &td->inet_raddr,
405 		    sizeof(ip6->ip6_dst));
406 		break;
407 	default:
408 		SK_ERR("ipv == %d", ipv);
409 		err = EINVAL;
410 		goto fail;
411 	}
412 	tcph = __DECONST(volatile struct tcphdr *, (volatile uint8_t *)iph + l3hlen);
413 	ulen = (l3tlen - l3hlen);
414 	if (td->inet_proto == IPPROTO_TCP) {
415 		if (cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph)) {
416 			SK_ERR("cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph) "
417 			    "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*tcph),
418 			    ulen, sizeof(*tcph));
419 			err = EINVAL;
420 			goto fail;
421 		}
422 		l4hlen = (uint8_t)(tcph->th_off << 2);
423 		if (l4hlen < sizeof(*tcph)) {
424 			SK_ERR("l4hlen < sizeof(*tcph) (%d < %d)", l4hlen, sizeof(*tcph));
425 			err = EINVAL;
426 			goto fail;
427 		}
428 		if (l4hlen > ulen) {
429 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
430 			err = EINVAL;
431 			goto fail;
432 		}
433 		bcopy(__DECONST(void *, &tcph->th_sport), &td->inet_lport,
434 		    sizeof(td->inet_lport));
435 		bcopy(__DECONST(void *, &tcph->th_dport), &td->inet_rport,
436 		    sizeof(td->inet_rport));
437 	} else if (td->inet_proto == IPPROTO_UDP) {
438 		if (cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph)) {
439 			SK_ERR("cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph) "
440 			    "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*udph),
441 			    ulen, sizeof(*udph));
442 			err = EINVAL;
443 			goto fail;
444 		}
445 		l4hlen = sizeof(*udph);
446 		if (l4hlen > ulen) {
447 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
448 			err = EINVAL;
449 			goto fail;
450 		}
451 		bcopy(__DECONST(void *, &udph->uh_sport), &td->inet_lport,
452 		    sizeof(td->inet_lport));
453 		bcopy(__DECONST(void *, &udph->uh_dport), &td->inet_rport,
454 		    sizeof(td->inet_rport));
455 	} else {
456 		err = ENOTSUP;
457 		goto fail;
458 	}
459 
460 	td->inet_common.itd_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET;
461 	td->inet_common.itd_len = sizeof(*td);
462 	td->inet_common.itd_flags = IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
463 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND;
464 	td->inet_mask |= (IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
465 	    IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
466 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
467 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
468 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
469 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT);
470 	return 0;
471 fail:
472 	DTRACE_SKYWALK5(classify__failed, struct ip *, iph, size_t, pkt_len,
473 	    uint8_t, pkt->pkt_l2_len, struct ifnet_traffic_descriptor_inet *, td,
474 	    int, err);
475 	bzero(td, sizeof(*td));
476 	return err;
477 	#undef iph
478 	#undef ip6
479 	#undef tcph
480 	#undef udph
481 }
482 
483 int
nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char * ifname,struct __kern_packet * pkt,uint64_t * qset_id)484 nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char *ifname,
485     struct __kern_packet *pkt, uint64_t *qset_id)
486 {
487 	struct ifnet_traffic_descriptor_inet td;
488 	int err;
489 
490 	err = fill_inet_td(pkt, &td);
491 	if (err != 0) {
492 		return err;
493 	}
494 	return nxctl_inet_traffic_rule_find_qset_id(ifname, &td, qset_id);
495 }
496 
497 void
nxctl_traffic_rule_init(void)498 nxctl_traffic_rule_init(void)
499 {
500 	ASSERT(nxctl_traffic_rule_tag == NULL);
501 	nxctl_traffic_rule_tag =
502 	    kern_allocation_name_allocate(NXCTL_TRAFFIC_RULE_TAG, 0);
503 	ASSERT(nxctl_traffic_rule_tag != NULL);
504 
505 	ASSERT(inet_traffic_rule_type == NULL);
506 	inet_traffic_rule_type =
507 	    find_traffic_rule_type(IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET);
508 	ASSERT(inet_traffic_rule_type != NULL);
509 }
510 
511 void
nxctl_traffic_rule_fini(void)512 nxctl_traffic_rule_fini(void)
513 {
514 	if (nxctl_traffic_rule_tag != NULL) {
515 		kern_allocation_name_release(nxctl_traffic_rule_tag);
516 		nxctl_traffic_rule_tag = NULL;
517 	}
518 	inet_traffic_rule_type = NULL;
519 }
520 
521 static struct ifnet_ip_addr v6_zeros_addr = {0};
522 static boolean_t
inet_v6addr_cmp(struct ifnet_ip_addr * a1,struct ifnet_ip_addr * a2)523 inet_v6addr_cmp(struct ifnet_ip_addr *a1, struct ifnet_ip_addr *a2)
524 {
525 	return memcmp(a1, a2, sizeof(*a1)) == 0;
526 }
527 
528 SK_NO_INLINE_ATTRIBUTE
529 static struct nxctl_traffic_rule_storage *
nxctl_traffic_rule_storage_create(void)530 nxctl_traffic_rule_storage_create(void)
531 {
532 	struct nxctl_traffic_rule_storage *rs;
533 
534 	rs = sk_alloc_type(struct nxctl_traffic_rule_storage,
535 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
536 	SLIST_INIT(&rs->rs_list);
537 	rs->rs_count = 0;
538 	return rs;
539 }
540 
541 SK_NO_INLINE_ATTRIBUTE
542 static void
nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage * rs)543 nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage *rs)
544 {
545 	ASSERT(rs->rs_count == 0);
546 	ASSERT(SLIST_EMPTY(&rs->rs_list));
547 	sk_free_type(struct nxctl_traffic_rule_storage, rs);
548 }
549 
550 /*
551  * This is meant to be called during closure of the nxctl's fd.
552  * This will cleanup all rules linked to this nxctl. Rules that
553  * are marked persistent won't be added to the nxctl list.
554  */
555 void
nxctl_traffic_rule_clean(struct nxctl * nxctl)556 nxctl_traffic_rule_clean(struct nxctl *nxctl)
557 {
558 	struct nxctl_traffic_rule_storage *rs;
559 	struct nxctl_traffic_rule *ntr, *next;
560 	int err;
561 
562 	lck_mtx_lock(&nxctl->nxctl_lock);
563 	if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
564 		lck_mtx_unlock(&nxctl->nxctl_lock);
565 		return;
566 	}
567 	ntr = SLIST_FIRST(&rs->rs_list);
568 	SLIST_INIT(&rs->rs_list);
569 	rs->rs_count = 0;
570 	nxctl_traffic_rule_storage_destroy(rs);
571 	nxctl->nxctl_traffic_rule_storage = NULL;
572 	lck_mtx_unlock(&nxctl->nxctl_lock);
573 
574 	while (ntr != NULL) {
575 		next = SLIST_NEXT(ntr, ntr_storage_link);
576 		/*
577 		 * Clearing the flag to tell remove_traffic_rule() not to
578 		 * remove from the nxctl list again.
579 		 */
580 		ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
581 
582 		/* Passing NULL because we already hold a reference */
583 		err = remove_traffic_rule(nxctl, ntr->ntr_uuid, NULL);
584 		if (err == 0) {
585 			(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
586 		}
587 		release_traffic_rule(ntr);
588 		ntr = next;
589 	}
590 }
591 
592 SK_NO_INLINE_ATTRIBUTE
593 static void
add_traffic_rule_to_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)594 add_traffic_rule_to_nxctl(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
595 {
596 	struct nxctl_traffic_rule_storage *rs;
597 
598 	lck_mtx_lock(&nxctl->nxctl_lock);
599 	if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
600 		rs = nxctl_traffic_rule_storage_create();
601 		nxctl->nxctl_traffic_rule_storage = rs;
602 	}
603 	ntr->ntr_flags |= NTR_FLAG_ON_NXCTL_LIST;
604 	retain_traffic_rule(ntr);
605 	SLIST_INSERT_HEAD(&rs->rs_list, ntr, ntr_storage_link);
606 	rs->rs_count++;
607 	lck_mtx_unlock(&nxctl->nxctl_lock);
608 }
609 
610 SK_NO_INLINE_ATTRIBUTE
611 static void
remove_traffic_rule_from_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)612 remove_traffic_rule_from_nxctl(struct nxctl *nxctl,
613     struct nxctl_traffic_rule *ntr)
614 {
615 	struct nxctl_traffic_rule_storage *rs;
616 
617 	lck_mtx_lock(&nxctl->nxctl_lock);
618 	if ((ntr->ntr_flags & NTR_FLAG_ON_NXCTL_LIST) == 0) {
619 		lck_mtx_unlock(&nxctl->nxctl_lock);
620 		return;
621 	}
622 	rs = nxctl->nxctl_traffic_rule_storage;
623 	SLIST_REMOVE(&rs->rs_list, ntr, nxctl_traffic_rule, ntr_storage_link);
624 	rs->rs_count--;
625 	ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
626 	release_traffic_rule(ntr);
627 	if (rs->rs_count == 0) {
628 		nxctl_traffic_rule_storage_destroy(rs);
629 		nxctl->nxctl_traffic_rule_storage = NULL;
630 	}
631 	lck_mtx_unlock(&nxctl->nxctl_lock);
632 }
633 
634 static int
inet_traffic_rule_validate(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra)635 inet_traffic_rule_validate(struct nxctl_traffic_rule_type *type,
636     const char *ifname,
637     struct ifnet_traffic_descriptor_common *td,
638     struct ifnet_traffic_rule_action *ra)
639 {
640 #pragma unused(type)
641 	char buf[IFNAMSIZ];
642 	int unit, i;
643 	struct ifnet_traffic_descriptor_inet *tdi;
644 	uint8_t mask = 0, ipver, proto;
645 
646 	if (ifunit_extract(ifname, buf, sizeof(buf), &unit) < 0) {
647 		SK_ERR("invalid ifname: %s", ifname);
648 		return EINVAL;
649 	}
650 	if (td->itd_len != sizeof(*tdi)) {
651 		SK_ERR("invalid td len: expected %d, actual %d",
652 		    sizeof(*tdi), td->itd_len);
653 		return EINVAL;
654 	}
655 	if (td->itd_flags == 0 ||
656 	    (td->itd_flags &
657 	    ~(IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
658 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND)) != 0) {
659 		SK_ERR("invalid td flags: 0x%x", td->itd_flags);
660 		return EINVAL;
661 	}
662 	tdi = (struct ifnet_traffic_descriptor_inet *)td;
663 	for (i = 0; i < NINETRULEMASKS; i++) {
664 		if (tdi->inet_mask == nxctl_inet_traffic_rule_masks[i]) {
665 			mask = tdi->inet_mask;
666 			break;
667 		}
668 	}
669 	if (mask == 0) {
670 		SK_ERR("invalid inet mask: 0x%x", tdi->inet_mask);
671 		return EINVAL;
672 	}
673 	ipver = tdi->inet_ipver;
674 	if (ipver != IPVERSION && ipver != IPV6_VERSION) {
675 		SK_ERR("invalid inet ipver: 0x%x", ipver);
676 		return EINVAL;
677 	}
678 	proto = tdi->inet_proto;
679 	if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
680 		SK_ERR("invalid inet proto: %d", proto);
681 		return EINVAL;
682 	}
683 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0) {
684 		if (ipver == IPVERSION) {
685 			if (tdi->inet_laddr.iia_v4addr == INADDR_ANY) {
686 				SK_ERR("inet laddr v4 cannot be unspecified");
687 				return EINVAL;
688 			}
689 		} else {
690 			if (inet_v6addr_cmp(&tdi->inet_laddr, &v6_zeros_addr)) {
691 				SK_ERR("inet laddr v4 cannot be unspecified");
692 				return EINVAL;
693 			}
694 		}
695 	}
696 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0) {
697 		if (ipver == IPVERSION) {
698 			if (tdi->inet_raddr.iia_v4addr == INADDR_ANY) {
699 				SK_ERR("inet raddr v6 cannot be unspecified");
700 				return EINVAL;
701 			}
702 		} else {
703 			if (inet_v6addr_cmp(&tdi->inet_raddr, &v6_zeros_addr)) {
704 				SK_ERR("inet raddr v6 cannot be unspecified");
705 				return EINVAL;
706 			}
707 		}
708 	}
709 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0) {
710 		if (tdi->inet_lport == 0) {
711 			SK_ERR("inet lport cannot be unspecified");
712 			return EINVAL;
713 		}
714 	}
715 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0) {
716 		if (tdi->inet_rport == 0) {
717 			SK_ERR("inet rport cannot be unspecified");
718 			return EINVAL;
719 		}
720 	}
721 	if (ra->ra_len != sizeof(struct ifnet_traffic_rule_action_steer)) {
722 		SK_ERR("invalid ra len: expected %d, actual %d",
723 		    sizeof(struct ifnet_traffic_rule_action_steer), ra->ra_len);
724 		return EINVAL;
725 	}
726 	return 0;
727 }
728 
729 SK_NO_INLINE_ATTRIBUTE
730 static struct nxctl_traffic_rule_inet_storage *
inet_traffic_rule_storage_create(void)731 inet_traffic_rule_storage_create(void)
732 {
733 	struct nxctl_traffic_rule_inet_storage *rs;
734 
735 	rs = sk_alloc_type(struct nxctl_traffic_rule_inet_storage,
736 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
737 	SLIST_INIT(&rs->ris_if_list);
738 	rs->ris_count = 0;
739 	return rs;
740 }
741 
742 SK_NO_INLINE_ATTRIBUTE
743 static void
inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage * rs)744 inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage *rs)
745 {
746 	ASSERT(rs->ris_count == 0);
747 	ASSERT(SLIST_EMPTY(&rs->ris_if_list));
748 	sk_free_type(struct nxctl_traffic_rule_inet_storage, rs);
749 }
750 
751 SK_NO_INLINE_ATTRIBUTE
752 static struct nxctl_traffic_rule_inet_if *
inet_traffic_rule_if_create(const char * ifname)753 inet_traffic_rule_if_create(const char *ifname)
754 {
755 	struct nxctl_traffic_rule_inet_if *rif;
756 	int i;
757 
758 	rif = sk_alloc_type(struct nxctl_traffic_rule_inet_if,
759 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
760 	for (i = 0; i < NINETRULEMASKS; i++) {
761 		SLIST_INIT(&rif->rii_lists[i]);
762 	}
763 	strlcpy(rif->rii_ifname, ifname, sizeof(rif->rii_ifname));
764 	rif->rii_count = 0;
765 	return rif;
766 }
767 
768 SK_NO_INLINE_ATTRIBUTE
769 static void
inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if * rif)770 inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if *rif)
771 {
772 	int i;
773 
774 	for (i = 0; i < NINETRULEMASKS; i++) {
775 		ASSERT(SLIST_EMPTY(&rif->rii_lists[i]));
776 	}
777 	ASSERT(rif->rii_count == 0);
778 	sk_free_type(struct nxctl_traffic_rule_inet_if, rif);
779 }
780 
781 SK_NO_INLINE_ATTRIBUTE
782 static boolean_t
inet_traffic_rule_match(struct nxctl_traffic_rule_inet * ntri,const char * ifname,uint32_t flags,struct ifnet_traffic_descriptor_inet * tdi)783 inet_traffic_rule_match(struct nxctl_traffic_rule_inet *ntri, const char *ifname,
784     uint32_t flags, struct ifnet_traffic_descriptor_inet *tdi)
785 {
786 	struct nxctl_traffic_rule *ntr = (struct nxctl_traffic_rule *)ntri;
787 	struct ifnet_traffic_descriptor_inet *tdi0;
788 	uint8_t mask;
789 	boolean_t exact;
790 
791 	VERIFY(strcmp(ntr->ntr_ifname, ifname) == 0);
792 	tdi0 = &ntri->ntri_td;
793 
794 	exact = ((flags & NTR_FIND_FLAG_EXACT) != 0);
795 	mask = tdi0->inet_mask & tdi->inet_mask;
796 	if (exact) {
797 		ASSERT(tdi0->inet_mask == tdi->inet_mask);
798 	}
799 	ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER) != 0);
800 	if (tdi0->inet_ipver != tdi->inet_ipver) {
801 		DTRACE_SKYWALK2(ipver__mismatch,
802 		    uint8_t, tdi0->inet_ipver, uint8_t, tdi->inet_ipver);
803 		return FALSE;
804 	}
805 	ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) != 0);
806 	if (tdi0->inet_proto != tdi->inet_proto) {
807 		DTRACE_SKYWALK2(proto__mismatch,
808 		    uint8_t, tdi0->inet_proto, uint8_t, tdi->inet_proto);
809 		return FALSE;
810 	}
811 	if (tdi0->inet_ipver == IPVERSION) {
812 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
813 		    tdi0->inet_laddr.iia_v4addr != tdi->inet_laddr.iia_v4addr) {
814 			DTRACE_SKYWALK2(v4laddr__mismatch,
815 			    in_addr_t, tdi0->inet_laddr.iia_v4addr,
816 			    in_addr_t, tdi->inet_laddr.iia_v4addr);
817 			return FALSE;
818 		}
819 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
820 		    tdi0->inet_raddr.iia_v4addr != tdi->inet_raddr.iia_v4addr) {
821 			DTRACE_SKYWALK2(v4raddr__mismatch,
822 			    in_addr_t, tdi0->inet_raddr.iia_v4addr,
823 			    in_addr_t, tdi->inet_raddr.iia_v4addr);
824 			return FALSE;
825 		}
826 	} else {
827 		ASSERT(tdi0->inet_ipver == IPV6_VERSION);
828 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
829 		    !inet_v6addr_cmp(&tdi0->inet_laddr, &tdi->inet_laddr)) {
830 			DTRACE_SKYWALK2(v6laddr__mismatch,
831 			    struct in6_addr *, &tdi0->inet_laddr,
832 			    struct in6_addr *, &tdi->inet_laddr);
833 			return FALSE;
834 		}
835 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
836 		    !inet_v6addr_cmp(&tdi0->inet_raddr, &tdi->inet_raddr)) {
837 			DTRACE_SKYWALK2(v6raddr__mismatch,
838 			    struct in6_addr *, &tdi0->inet_raddr,
839 			    struct in6_addr *, &tdi->inet_raddr);
840 			return FALSE;
841 		}
842 	}
843 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0 &&
844 	    tdi0->inet_lport != tdi->inet_lport) {
845 		DTRACE_SKYWALK2(lport__mismatch,
846 		    uint8_t, tdi0->inet_lport, uint8_t, tdi->inet_lport);
847 		return FALSE;
848 	}
849 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0 &&
850 	    tdi0->inet_rport != tdi->inet_rport) {
851 		DTRACE_SKYWALK2(rport__mismatch,
852 		    uint8_t, tdi0->inet_rport, uint8_t, tdi->inet_rport);
853 		return FALSE;
854 	}
855 	return TRUE;
856 }
857 
858 static int
inet_traffic_rule_find(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,uint32_t flags,struct nxctl_traffic_rule ** ntrp)859 inet_traffic_rule_find(struct nxctl_traffic_rule_type *type, const char *ifname,
860     struct ifnet_traffic_descriptor_common *td, uint32_t flags,
861     struct nxctl_traffic_rule **ntrp)
862 {
863 	struct nxctl_traffic_rule_inet *ntri = NULL;
864 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
865 	struct nxctl_traffic_rule_inet_if *rif;
866 	struct ifnet_traffic_descriptor_inet *tdi =
867 	    (struct ifnet_traffic_descriptor_inet *)td;
868 	int i;
869 
870 	if (rs == NULL) {
871 		return ENOENT;
872 	}
873 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
874 		if (strcmp(rif->rii_ifname, ifname) != 0) {
875 			continue;
876 		}
877 		for (i = 0; i < NINETRULEMASKS; i++) {
878 			if ((flags & NTR_FIND_FLAG_EXACT) != 0 &&
879 			    tdi->inet_mask != nxctl_inet_traffic_rule_masks[i]) {
880 				continue;
881 			}
882 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
883 				if (inet_traffic_rule_match(ntri, ifname, flags, tdi)) {
884 					*ntrp = (struct nxctl_traffic_rule *)ntri;
885 					return 0;
886 				}
887 			}
888 		}
889 	}
890 	return ENOENT;
891 }
892 
893 static int
inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type * type,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)894 inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type *type,
895     uuid_t uuid, struct nxctl_traffic_rule **ntrp)
896 {
897 	struct nxctl_traffic_rule_inet *ntri;
898 	struct nxctl_traffic_rule *ntr;
899 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
900 	struct nxctl_traffic_rule_inet_if *rif;
901 	int i;
902 
903 	if (rs == NULL) {
904 		return ENOENT;
905 	}
906 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
907 		for (i = 0; i < NINETRULEMASKS; i++) {
908 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
909 				ntr = &ntri->ntri_common;
910 				if (uuid_compare(ntr->ntr_uuid, uuid) == 0) {
911 					*ntrp = ntr;
912 					return 0;
913 				}
914 			}
915 		}
916 	}
917 	return ENOENT;
918 }
919 
920 static void
inet_update_ifnet_traffic_rule_count(const char * ifname,uint32_t count)921 inet_update_ifnet_traffic_rule_count(const char *ifname, uint32_t count)
922 {
923 	struct ifnet *ifp;
924 
925 	ifp = ifunit_ref(ifname);
926 	if (ifp == NULL) {
927 		DTRACE_SKYWALK1(ifname__not__found, char *, ifname);
928 		return;
929 	}
930 	ifnet_update_traffic_rule_count(ifp, count);
931 	ifnet_decr_iorefcnt(ifp);
932 }
933 
934 static void
inet_traffic_rule_link(struct nxctl_traffic_rule * ntr)935 inet_traffic_rule_link(struct nxctl_traffic_rule *ntr)
936 {
937 	struct nxctl_traffic_rule_type *type = ntr->ntr_type;
938 	struct nxctl_traffic_rule_inet_storage *rs;
939 	struct nxctl_traffic_rule_inet_if *rif;
940 	struct nxctl_traffic_rule_inet *ntri =
941 	    (struct nxctl_traffic_rule_inet *)ntr;
942 	struct nxctl_traffic_rule_inet_head *list = NULL;
943 	int i;
944 
945 	if ((rs = type->ntrt_storage) == NULL) {
946 		rs = inet_traffic_rule_storage_create();
947 		type->ntrt_storage = rs;
948 	}
949 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
950 		if (strcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
951 			break;
952 		}
953 	}
954 	if (rif == NULL) {
955 		rif = inet_traffic_rule_if_create(ntr->ntr_ifname);
956 		SLIST_INSERT_HEAD(&rs->ris_if_list, rif, rii_link);
957 	}
958 	for (i = 0; i < NINETRULEMASKS; i++) {
959 		if (ntri->ntri_td.inet_mask ==
960 		    nxctl_inet_traffic_rule_masks[i]) {
961 			list = &rif->rii_lists[i];
962 			break;
963 		}
964 	}
965 	retain_traffic_rule(ntr);
966 	ASSERT(list != NULL);
967 	SLIST_INSERT_HEAD(list, ntri, ntri_storage_link);
968 	/* per-interface count */
969 	rif->rii_count++;
970 	inet_update_ifnet_traffic_rule_count(rif->rii_ifname, rif->rii_count);
971 
972 	/* global count */
973 	rs->ris_count++;
974 }
975 
976 static void
inet_traffic_rule_unlink(struct nxctl_traffic_rule * ntr)977 inet_traffic_rule_unlink(struct nxctl_traffic_rule *ntr)
978 {
979 	struct nxctl_traffic_rule_inet_storage *rs;
980 	struct nxctl_traffic_rule_inet_if *rif;
981 	struct nxctl_traffic_rule_inet *ntri =
982 	    (struct nxctl_traffic_rule_inet *)ntr;
983 	struct nxctl_traffic_rule_inet_head *list = NULL;
984 	struct nxctl_traffic_rule_type *type;
985 	int i;
986 
987 	type = ntr->ntr_type;
988 	rs = type->ntrt_storage;
989 	ASSERT(rs != NULL);
990 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
991 		if (strcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
992 			break;
993 		}
994 	}
995 	ASSERT(rif != NULL);
996 	for (i = 0; i < NINETRULEMASKS; i++) {
997 		if (ntri->ntri_td.inet_mask ==
998 		    nxctl_inet_traffic_rule_masks[i]) {
999 			list = &rif->rii_lists[i];
1000 			break;
1001 		}
1002 	}
1003 	ASSERT(list != NULL);
1004 	SLIST_REMOVE(list, ntri, nxctl_traffic_rule_inet, ntri_storage_link);
1005 	rif->rii_count--;
1006 	inet_update_ifnet_traffic_rule_count(rif->rii_ifname, rif->rii_count);
1007 
1008 	rs->ris_count--;
1009 	release_traffic_rule(ntr);
1010 
1011 	if (rif->rii_count == 0) {
1012 		SLIST_REMOVE(&rs->ris_if_list, rif, nxctl_traffic_rule_inet_if, rii_link);
1013 		inet_traffic_rule_if_destroy(rif);
1014 	}
1015 	if (rs->ris_count == 0) {
1016 		type->ntrt_storage = NULL;
1017 		inet_traffic_rule_storage_destroy(rs);
1018 	}
1019 }
1020 
1021 /*
1022  * XXX
1023  * This may need additional changes to ensure safety against detach/attach.
1024  * This is not an issue for the first consumer of llink interfaces, cellular,
1025  * which does not detach.
1026  */
1027 static int
inet_traffic_rule_notify(struct nxctl_traffic_rule * ntr,uint32_t flags)1028 inet_traffic_rule_notify(struct nxctl_traffic_rule *ntr, uint32_t flags)
1029 {
1030 	struct ifnet *ifp;
1031 	struct nx_netif *nif;
1032 	struct netif_qset *qset = NULL;
1033 	struct nxctl_traffic_rule_inet *ntri;
1034 	int err = 0;
1035 
1036 	ifp = ifunit_ref(ntr->ntr_ifname);
1037 	if (ifp == NULL) {
1038 		DTRACE_SKYWALK1(ifname__not__found, char *, ntr->ntr_ifname);
1039 		err = ENXIO;
1040 		goto done;
1041 	}
1042 	nif = NA(ifp)->nifna_netif;
1043 	if (!NX_LLINK_PROV(nif->nif_nx)) {
1044 		DTRACE_SKYWALK1(llink__not__enabled, struct ifnet *, ifp);
1045 		err = ENOTSUP;
1046 		goto done;
1047 	}
1048 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
1049 	qset = nx_netif_find_qset(nif, ntri->ntri_ra.ras_qset_id);
1050 	err = nx_netif_notify_steering_info(nif, qset,
1051 	    (struct ifnet_traffic_descriptor_common *)&ntri->ntri_td,
1052 	    ((flags & NTR_NOTIFY_FLAG_ADD) != 0));
1053 done:
1054 	if (qset != NULL) {
1055 		nx_netif_qset_release(&qset);
1056 	}
1057 	if (ifp != NULL) {
1058 		ifnet_decr_iorefcnt(ifp);
1059 	}
1060 	return err;
1061 }
1062 
1063 static int
inet_traffic_rule_create(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1064 inet_traffic_rule_create(struct nxctl_traffic_rule_type *type,
1065     const char *ifname, struct ifnet_traffic_descriptor_common *td,
1066     struct ifnet_traffic_rule_action *ra, uint32_t flags,
1067     struct nxctl_traffic_rule **ntrp)
1068 {
1069 	struct nxctl_traffic_rule_inet *ntri;
1070 	struct nxctl_traffic_rule *ntr;
1071 
1072 	ntri = sk_alloc_type(struct nxctl_traffic_rule_inet,
1073 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
1074 	ntr = &ntri->ntri_common;
1075 
1076 	ntr->ntr_type = type;
1077 	ntr->ntr_flags = flags;
1078 	uuid_generate(ntr->ntr_uuid);
1079 	os_ref_init(&ntr->ntr_refcnt, NULL);
1080 
1081 	strlcpy(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname));
1082 	proc_selfname(ntr->ntr_procname, sizeof(ntr->ntr_procname));
1083 	bcopy(td, &ntri->ntri_td, sizeof(ntri->ntri_td));
1084 	bcopy(ra, &ntri->ntri_ra, sizeof(ntri->ntri_ra));
1085 
1086 	*ntrp = ntr;
1087 	return 0;
1088 }
1089 
1090 static void
inet_traffic_rule_destroy(struct nxctl_traffic_rule * ntr)1091 inet_traffic_rule_destroy(struct nxctl_traffic_rule *ntr)
1092 {
1093 	struct nxctl_traffic_rule_inet *ntri;
1094 
1095 	ASSERT(os_ref_get_count(&ntr->ntr_refcnt) == 0);
1096 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
1097 	sk_free_type(struct nxctl_traffic_rule_inet, ntri);
1098 }
1099 
1100 static void
convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet * ntri,struct nxctl_traffic_rule_inet_iocinfo * info)1101 convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet *ntri,
1102     struct nxctl_traffic_rule_inet_iocinfo *info)
1103 {
1104 	struct nxctl_traffic_rule *ntr;
1105 	struct nxctl_traffic_rule_generic_iocinfo *ginfo;
1106 
1107 	bzero(info, sizeof(*info));
1108 	ntr = &ntri->ntri_common;
1109 	ginfo = &info->tri_common;
1110 	_CASSERT(sizeof(ntr->ntr_procname) == sizeof(ginfo->trg_procname));
1111 	_CASSERT(sizeof(ntr->ntr_ifname) == sizeof(ginfo->trg_ifname));
1112 	uuid_copy(ginfo->trg_uuid, ntr->ntr_uuid);
1113 	strlcpy(ginfo->trg_procname, ntr->ntr_procname,
1114 	    sizeof(ginfo->trg_procname));
1115 	strlcpy(ginfo->trg_ifname, ntr->ntr_ifname,
1116 	    sizeof(ginfo->trg_ifname));
1117 	bcopy(&ntri->ntri_td, &info->tri_td, sizeof(info->tri_td));
1118 	bcopy(&ntri->ntri_ra, &info->tri_ra, sizeof(info->tri_ra));
1119 }
1120 
1121 static int
inet_traffic_rule_get_all(struct nxctl_traffic_rule_type * type,uint32_t size,uint32_t * count,user_addr_t uaddr)1122 inet_traffic_rule_get_all(struct nxctl_traffic_rule_type *type, uint32_t size,
1123     uint32_t *count, user_addr_t uaddr)
1124 {
1125 	struct nxctl_traffic_rule_inet *ntri = NULL;
1126 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
1127 	struct nxctl_traffic_rule_inet_if *rif;
1128 	struct nxctl_traffic_rule_inet_iocinfo info;
1129 	int i, err;
1130 
1131 	if (size != sizeof(info)) {
1132 		SK_ERR("size: actual %d, expected %d", size, sizeof(info));
1133 		return EINVAL;
1134 	}
1135 	if (rs == NULL) {
1136 		*count = 0;
1137 		return 0;
1138 	}
1139 	if (*count < rs->ris_count) {
1140 		SK_ERR("count: given %d, require: %d", *count, rs->ris_count);
1141 		return ENOBUFS;
1142 	}
1143 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1144 		for (i = 0; i < NINETRULEMASKS; i++) {
1145 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
1146 				convert_ntri_to_iocinfo(ntri, &info);
1147 				err = copyout(&info, uaddr, sizeof(info));
1148 				if (err != 0) {
1149 					SK_ERR("copyout failed: %d", err);
1150 					return err;
1151 				}
1152 				uaddr += sizeof(info);
1153 			}
1154 		}
1155 	}
1156 	*count = rs->ris_count;
1157 	return 0;
1158 }
1159 
1160 SK_NO_INLINE_ATTRIBUTE
1161 static void
retain_traffic_rule(struct nxctl_traffic_rule * ntr)1162 retain_traffic_rule(struct nxctl_traffic_rule *ntr)
1163 {
1164 #if (DEVELOPMENT || DEBUG)
1165 	os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1166 	DTRACE_SKYWALK2(ntr__retain, struct nxctl_traffic_rule *, ntr,
1167 	    os_ref_count_t, count);
1168 #endif
1169 	os_ref_retain(&ntr->ntr_refcnt);
1170 }
1171 
1172 SK_NO_INLINE_ATTRIBUTE
1173 static void
release_traffic_rule(struct nxctl_traffic_rule * ntr)1174 release_traffic_rule(struct nxctl_traffic_rule *ntr)
1175 {
1176 #if (DEVELOPMENT || DEBUG)
1177 	os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1178 	DTRACE_SKYWALK2(ntr__release, struct nxctl_traffic_rule *, ntr,
1179 	    os_ref_count_t, count);
1180 #endif
1181 	if (os_ref_release(&ntr->ntr_refcnt) == 0) {
1182 		ntr->ntr_type->ntrt_destroy(ntr);
1183 	}
1184 }
1185 
1186 SK_NO_INLINE_ATTRIBUTE
1187 static int
notify_traffic_rule(struct nxctl_traffic_rule * ntr,uint32_t flags)1188 notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags)
1189 {
1190 	return ntr->ntr_type->ntrt_notify(ntr, flags);
1191 }
1192 
1193 static void
link_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1194 link_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1195 {
1196 	/*
1197 	 * The persist flag means: do not clean up rule upon nxctl fd close.
1198 	 * This means we only add the rule to the nxctl list if persist
1199 	 * is not set.
1200 	 */
1201 	if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1202 		add_traffic_rule_to_nxctl(nxctl, ntr);
1203 	}
1204 	ntr->ntr_type->ntrt_link(ntr);
1205 }
1206 
1207 static void
unlink_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1208 unlink_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1209 {
1210 	if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1211 		remove_traffic_rule_from_nxctl(nxctl, ntr);
1212 	}
1213 	ntr->ntr_type->ntrt_unlink(ntr);
1214 }
1215 
1216 static int
find_traffic_rule_by_uuid(uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1217 find_traffic_rule_by_uuid(uuid_t uuid, struct nxctl_traffic_rule **ntrp)
1218 {
1219 	int i, err;
1220 	struct nxctl_traffic_rule_type *ntrt;
1221 	struct nxctl_traffic_rule *ntr = NULL;
1222 
1223 	for (i = 0; i < NRULETYPES; i++) {
1224 		ntrt = &nxctl_rule_types[i];
1225 		err = ntrt->ntrt_find_by_uuid(ntrt, uuid, &ntr);
1226 		if (err == 0) {
1227 			ASSERT(ntr != NULL);
1228 			*ntrp = ntr;
1229 			return 0;
1230 		}
1231 	}
1232 	return ENOENT;
1233 }
1234 
1235 static struct nxctl_traffic_rule_type *
find_traffic_rule_type(uint8_t type)1236 find_traffic_rule_type(uint8_t type)
1237 {
1238 	int i;
1239 	struct nxctl_traffic_rule_type *ntrt;
1240 
1241 	for (i = 0; i < NRULETYPES; i++) {
1242 		ntrt = &nxctl_rule_types[i];
1243 		if (ntrt->ntrt_type == type) {
1244 			return ntrt;
1245 		}
1246 	}
1247 	return NULL;
1248 }
1249 
1250 SK_NO_INLINE_ATTRIBUTE
1251 static int
add_traffic_rule(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1252 add_traffic_rule(struct nxctl *nxctl, const char *ifname,
1253     struct ifnet_traffic_descriptor_common *td,
1254     struct ifnet_traffic_rule_action *ra,
1255     uint32_t flags,
1256     struct nxctl_traffic_rule **ntrp)
1257 {
1258 	struct nxctl_traffic_rule_type *type = NULL;
1259 	struct nxctl_traffic_rule *ntr = NULL;
1260 	int err;
1261 
1262 	NXTR_WLOCK();
1263 	type = find_traffic_rule_type(td->itd_type);
1264 	if (type == NULL) {
1265 		SK_ERR("rule type %x not found", td->itd_type);
1266 		err = EINVAL;
1267 		goto fail;
1268 	}
1269 	err = type->ntrt_validate(type, ifname, td, ra);
1270 	if (err != 0) {
1271 		SK_ERR("rule validate failed: %d", err);
1272 		goto fail;
1273 	}
1274 	err = type->ntrt_find(type, ifname, td, NTR_FIND_FLAG_EXACT, &ntr);
1275 	if (err == 0) {
1276 		SK_ERR("rule already exists");
1277 		ASSERT(ntr != NULL);
1278 		err = EEXIST;
1279 		goto fail;
1280 	} else if (err != ENOENT) {
1281 		SK_ERR("rule find failed: %d", err);
1282 		goto fail;
1283 	}
1284 	err = type->ntrt_create(type, ifname, td, ra, flags, &ntr);
1285 	if (err != 0) {
1286 		SK_ERR("rule create failed: %d", err);
1287 		goto fail;
1288 	}
1289 	link_traffic_rule(nxctl, ntr);
1290 	if (ntrp != NULL) {
1291 		retain_traffic_rule(ntr);
1292 		*ntrp = ntr;
1293 	}
1294 	NXTR_WUNLOCK();
1295 	return 0;
1296 fail:
1297 	NXTR_WUNLOCK();
1298 	return err;
1299 }
1300 
1301 
1302 SK_NO_INLINE_ATTRIBUTE
1303 static int
remove_traffic_rule(struct nxctl * nxctl,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1304 remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
1305     struct nxctl_traffic_rule **ntrp)
1306 {
1307 	struct nxctl_traffic_rule *ntr;
1308 	int err;
1309 
1310 	NXTR_WLOCK();
1311 	err = find_traffic_rule_by_uuid(uuid, &ntr);
1312 	if (err != 0) {
1313 		SK_ERR("traffic rule not found");
1314 		NXTR_WUNLOCK();
1315 		return err;
1316 	}
1317 	if (ntrp != NULL) {
1318 		retain_traffic_rule(ntr);
1319 		*ntrp = ntr;
1320 	}
1321 	unlink_traffic_rule(nxctl, ntr);
1322 	/* release initial reference */
1323 	release_traffic_rule(ntr);
1324 	NXTR_WUNLOCK();
1325 	return 0;
1326 }
1327 
1328 static uint32_t
convert_traffic_rule_ioc_flags(uint32_t flags)1329 convert_traffic_rule_ioc_flags(uint32_t flags)
1330 {
1331 	uint32_t f = 0;
1332 
1333 	if ((flags & NXIOC_ADD_TRAFFIC_RULE_FLAG_PERSIST) != 0) {
1334 		f |= NTR_FLAG_PERSIST;
1335 	}
1336 	return f;
1337 }
1338 
1339 SK_NO_INLINE_ATTRIBUTE
1340 static int
add_traffic_rule_generic(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,uuid_t * uuid)1341 add_traffic_rule_generic(struct nxctl *nxctl, const char *ifname,
1342     struct ifnet_traffic_descriptor_common *td,
1343     struct ifnet_traffic_rule_action *ra, uint32_t flags, uuid_t *uuid)
1344 {
1345 	struct nxctl_traffic_rule *ntr;
1346 	int err;
1347 
1348 	err = add_traffic_rule(nxctl, ifname, td, ra, flags, &ntr);
1349 	if (err != 0) {
1350 		return err;
1351 	}
1352 	(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_ADD);
1353 	uuid_copy(*uuid, ntr->ntr_uuid);
1354 	release_traffic_rule(ntr);
1355 	return 0;
1356 }
1357 
1358 int
nxioctl_add_traffic_rule_inet(struct nxctl * nxctl,caddr_t data,proc_t procp)1359 nxioctl_add_traffic_rule_inet(struct nxctl *nxctl, caddr_t data, proc_t procp)
1360 {
1361 #pragma unused(procp)
1362 	struct nxctl_add_traffic_rule_inet_iocargs *args =
1363 	    (struct nxctl_add_traffic_rule_inet_iocargs *)(void *)data;
1364 
1365 	return add_traffic_rule_generic(nxctl, args->atri_ifname,
1366 	           (struct ifnet_traffic_descriptor_common *)&args->atri_td,
1367 	           (struct ifnet_traffic_rule_action *)&args->atri_ra,
1368 	           convert_traffic_rule_ioc_flags(args->atri_flags),
1369 	           &args->atri_uuid);
1370 }
1371 
1372 int
nxioctl_remove_traffic_rule(struct nxctl * nxctl,caddr_t data,proc_t procp)1373 nxioctl_remove_traffic_rule(struct nxctl *nxctl, caddr_t data, proc_t procp)
1374 {
1375 #pragma unused(procp)
1376 	struct nxctl_remove_traffic_rule_iocargs *args =
1377 	    (struct nxctl_remove_traffic_rule_iocargs *)(void *)data;
1378 	struct nxctl_traffic_rule *ntr;
1379 	int err;
1380 
1381 	err = remove_traffic_rule(nxctl, args->rtr_uuid, &ntr);
1382 	if (err != 0) {
1383 		return err;
1384 	}
1385 	(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
1386 	release_traffic_rule(ntr);
1387 	return 0;
1388 }
1389 
1390 int
nxioctl_get_traffic_rules(struct nxctl * nxctl,caddr_t data,proc_t procp)1391 nxioctl_get_traffic_rules(struct nxctl *nxctl, caddr_t data, proc_t procp)
1392 {
1393 #pragma unused(nxctl)
1394 	struct nxctl_get_traffic_rules_iocargs *args =
1395 	    (struct nxctl_get_traffic_rules_iocargs *)(void *)data;
1396 	struct nxctl_traffic_rule_type *type;
1397 	user_addr_t uaddr;
1398 	int err;
1399 
1400 	NXTR_RLOCK();
1401 	type = find_traffic_rule_type(args->gtr_type);
1402 	if (type == NULL) {
1403 		SK_ERR("rule type %x not found", args->gtr_type);
1404 		err = EINVAL;
1405 		goto fail;
1406 	}
1407 	uaddr = proc_is64bit(procp) ? args->gtr_buf64 :
1408 	    CAST_USER_ADDR_T(args->gtr_buf);
1409 	err = type->ntrt_get_all(type, args->gtr_size, &args->gtr_count, uaddr);
1410 	if (err != 0) {
1411 		goto fail;
1412 	}
1413 	NXTR_RUNLOCK();
1414 	return 0;
1415 fail:
1416 	NXTR_RUNLOCK();
1417 	return err;
1418 }
1419