xref: /xnu-11215.61.5/bsd/skywalk/nexus/nexus_traffic_rule.c (revision 4f1223e81cd707a65cc109d0b8ad6653699da3c4)
1 /*
2  * Copyright (c) 2022 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 #include <skywalk/os_skywalk_private.h>
29 #include <skywalk/nexus/netif/nx_netif.h>
30 #include <netinet/ip.h>
31 #include <netinet/ip6.h>
32 #include <netinet/tcp.h>
33 #include <netinet/udp.h>
34 #include <sys/sdt.h>
35 
36 /*
37  * Implementation of nexus traffic rules APIs.
38  */
39 
40 struct nxctl_traffic_rule_type;
41 struct nxctl_traffic_rule;
42 
43 /*
44  * These callbacks need to be implemented for each rule type.
45  */
46 
47 /* Validate user provided parameters. */
48 typedef int (nxctl_traffic_rule_validate_cb_t)(
49 	struct nxctl_traffic_rule_type *type,
50 	const char *ifname,
51 	struct ifnet_traffic_descriptor_common *td,
52 	struct ifnet_traffic_rule_action *ra);
53 /*
54  * Each rule type has its own global structure for storing rules.
55  * These callbacks access this global structure.
56  */
57 #define NTR_FIND_FLAG_EXACT 0x0001
58 typedef int (nxctl_traffic_rule_find_cb_t)(
59 	struct nxctl_traffic_rule_type *type,
60 	const char *ifname,
61 	struct ifnet_traffic_descriptor_common *td,
62 	uint32_t flags,
63 	struct nxctl_traffic_rule **ntrp);
64 
65 typedef int (nxctl_traffic_rule_find_by_uuid_cb_t)(
66 	struct nxctl_traffic_rule_type *type,
67 	uuid_t uuid,
68 	struct nxctl_traffic_rule **ntrp);
69 
70 typedef void (nxctl_traffic_rule_link_cb_t)(
71 	struct nxctl_traffic_rule *ntr);
72 
73 typedef void (nxctl_traffic_rule_unlink_cb_t)(
74 	struct nxctl_traffic_rule *ntr);
75 
76 /*
77  * Notifies lower layers of the addition/removal of a rule.
78  * This is called outside of nxctl_traffic_rule_lock to avoid potential
79  * locking issues.
80  */
81 #define NTR_NOTIFY_FLAG_ADD 0x0001
82 #define NTR_NOTIFY_FLAG_REMOVE 0x0002
83 typedef int (nxctl_traffic_rule_notify_cb_t)(
84 	struct nxctl_traffic_rule *ntr,
85 	uint32_t flags);
86 
87 /*
88  * Create/Destroy callbacks for a rule type.
89  */
90 typedef int (nxctl_traffic_rule_create_cb_t)(
91 	struct nxctl_traffic_rule_type *type,
92 	const char *ifname,
93 	struct ifnet_traffic_descriptor_common *td,
94 	struct ifnet_traffic_rule_action *ra,
95 	uint32_t flags,
96 	struct nxctl_traffic_rule **ntrp);
97 
98 typedef void (nxctl_traffic_rule_destroy_cb_t)(
99 	struct nxctl_traffic_rule *ntr);
100 
101 /*
102  * This is used for copying all rules for a type (including generic
103  * and type-specific info) to userspace.
104  */
105 typedef int (nxctl_traffic_rule_get_all_cb_t)(
106 	struct nxctl_traffic_rule_type *type,
107 	uint32_t size,
108 	uint32_t *count,
109 	user_addr_t uaddr);
110 
/*
 * Descriptor for one traffic rule type: the type code, the set of callbacks
 * implementing that type, and a pointer to the type's global rule storage.
 */
111 struct nxctl_traffic_rule_type {
112 	uint8_t ntrt_type; /* type code, e.g. IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET */
113 	nxctl_traffic_rule_validate_cb_t *ntrt_validate; /* checks user-supplied descriptor/action */
114 	nxctl_traffic_rule_find_cb_t *ntrt_find; /* lookup by ifname + traffic descriptor */
115 	nxctl_traffic_rule_find_by_uuid_cb_t *ntrt_find_by_uuid; /* lookup by rule uuid */
116 	nxctl_traffic_rule_link_cb_t *ntrt_link; /* insert rule into the type's storage */
117 	nxctl_traffic_rule_unlink_cb_t *ntrt_unlink; /* remove rule from the type's storage */
118 	nxctl_traffic_rule_notify_cb_t *ntrt_notify; /* tell lower layers about add/remove */
119 	nxctl_traffic_rule_create_cb_t *ntrt_create; /* allocate a rule of this type */
120 	nxctl_traffic_rule_destroy_cb_t *ntrt_destroy; /* free a rule of this type */
121 	nxctl_traffic_rule_get_all_cb_t *ntrt_get_all; /* copy all rules out to userspace */
122 	/*
123 	 * -fbounds-safety: Why is this void *? All usage of this was for
124 	 * struct nxctl_traffic_rule_inet_storage *
125 	 */
	/* May be NULL; the find and get-count paths treat NULL as "no rules". */
126 	struct nxctl_traffic_rule_inet_storage *ntrt_storage;
127 };
128 
129 static nxctl_traffic_rule_validate_cb_t inet_traffic_rule_validate;
130 static nxctl_traffic_rule_find_cb_t inet_traffic_rule_find;
131 static nxctl_traffic_rule_find_by_uuid_cb_t inet_traffic_rule_find_by_uuid;
132 static nxctl_traffic_rule_link_cb_t inet_traffic_rule_link;
133 static nxctl_traffic_rule_unlink_cb_t inet_traffic_rule_unlink;
134 static nxctl_traffic_rule_notify_cb_t inet_traffic_rule_notify;
135 static nxctl_traffic_rule_create_cb_t inet_traffic_rule_create;
136 static nxctl_traffic_rule_destroy_cb_t inet_traffic_rule_destroy;
137 static nxctl_traffic_rule_get_all_cb_t inet_traffic_rule_get_all;
138 
/*
 * Table of supported rule types; currently a single entry for inet rules.
 * ntrt_storage is not initialized here, so it starts out NULL.
 */
139 static struct nxctl_traffic_rule_type nxctl_rule_types[] = {
140 	{
141 		.ntrt_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET,
142 		.ntrt_validate = inet_traffic_rule_validate,
143 		.ntrt_find = inet_traffic_rule_find,
144 		.ntrt_find_by_uuid = inet_traffic_rule_find_by_uuid,
145 		.ntrt_link = inet_traffic_rule_link,
146 		.ntrt_unlink = inet_traffic_rule_unlink,
147 		.ntrt_notify = inet_traffic_rule_notify,
148 		.ntrt_create = inet_traffic_rule_create,
149 		.ntrt_destroy = inet_traffic_rule_destroy,
150 		.ntrt_get_all = inet_traffic_rule_get_all,
151 	},
152 };
/* Number of entries in nxctl_rule_types[]. */
153 #define NRULETYPES \
154     (sizeof(nxctl_rule_types)/sizeof(struct nxctl_traffic_rule_type))
155 
156 /*
157  * Generic traffic rule.
158  * Contains fields common to all traffic rules.
159  */
160 #define NTR_FLAG_PERSIST 0x0001
161 #define NTR_FLAG_ON_NXCTL_LIST 0x0002
162 struct nxctl_traffic_rule {
163 	struct nxctl_traffic_rule_type *ntr_type; /* type descriptor (callbacks) of this rule */
164 	uint32_t ntr_flags; /* NTR_FLAG_* bits */
165 	os_refcnt_t ntr_refcnt; /* presumably managed by retain/release_traffic_rule() -- TODO confirm */
166 	uuid_t ntr_uuid; /* identity used by the find-by-uuid lookup */
167 	char ntr_procname[NTR_PROCNAME_SZ]; /* NOTE(review): looks like the creating process name -- confirm */
168 	char ntr_ifname[IFNAMSIZ]; /* interface the rule applies to */
169 	SLIST_ENTRY(nxctl_traffic_rule) ntr_storage_link; /* linkage on the per-nxctl rs_list */
170 };
171 
172 /*
173  * Inet-specific traffic rule.
174  */
175 struct nxctl_traffic_rule_inet {
	/* Must stay the first member: code casts between the two struct pointers. */
176 	struct nxctl_traffic_rule ntri_common;
177 	SLIST_ENTRY(nxctl_traffic_rule_inet) ntri_storage_link; /* linkage on a per-interface rii_lists[] bucket */
178 	struct ifnet_traffic_descriptor_inet ntri_td; /* tuple this rule matches on */
179 	struct ifnet_traffic_rule_action_steer ntri_ra; /* steering action; carries ras_qset_id */
180 };
181 
182 /*
183  * Currently supported tuple types.
184  */
185 #define ITDBIT(set, bit) (((set) != 0) ? (bit) : 0)
186 #define ITRM(proto, laddr, raddr, lport, rport) \
187 	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER | \
188 	ITDBIT(proto, IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) | \
189 	ITDBIT(laddr, IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) | \
190 	ITDBIT(raddr, IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) | \
191 	ITDBIT(lport, IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) | \
192 	ITDBIT(rport, IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT))
193 
/*
 * Every descriptor mask accepted for an inet rule. The index of a mask in
 * this table is also the index of the matching bucket in
 * nxctl_traffic_rule_inet_if.rii_lists[]. Lookups walk the table in order and
 * return the first matching rule, so the order here is the search order.
 */
194 static uint8_t nxctl_inet_traffic_rule_masks[] = {
195 	ITRM(1, 1, 1, 1, 1),
196 	ITRM(1, 1, 1, 1, 0),
197 	ITRM(1, 1, 1, 0, 1),
198 	ITRM(1, 1, 1, 0, 0),
199 	ITRM(1, 1, 0, 1, 1),
200 	ITRM(1, 1, 0, 1, 0),
201 	ITRM(1, 1, 0, 0, 1),
202 	ITRM(1, 1, 0, 0, 0),
203 	ITRM(1, 0, 1, 1, 1),
204 	ITRM(1, 0, 1, 1, 0),
205 	ITRM(1, 0, 1, 0, 1),
206 	ITRM(1, 0, 1, 0, 0),
207 	ITRM(1, 0, 0, 1, 1),
208 	ITRM(1, 0, 0, 1, 0),
209 	ITRM(1, 0, 0, 0, 1),
210 	// ITRM(1, 0, 0, 0, 0), addr or port is required
211 	ITRM(0, 1, 1, 1, 1),
212 	ITRM(0, 1, 1, 1, 0),
213 	ITRM(0, 1, 1, 0, 1),
214 	ITRM(0, 1, 1, 0, 0),
215 	ITRM(0, 1, 0, 1, 1),
216 	ITRM(0, 1, 0, 1, 0),
217 	ITRM(0, 1, 0, 0, 1),
218 	ITRM(0, 1, 0, 0, 0),
219 	ITRM(0, 0, 1, 1, 1),
220 	ITRM(0, 0, 1, 1, 0),
221 	ITRM(0, 0, 1, 0, 1),
222 	ITRM(0, 0, 1, 0, 0),
223 	ITRM(0, 0, 0, 1, 1),
224 	ITRM(0, 0, 0, 1, 0),
225 	ITRM(0, 0, 0, 0, 1),
226 	// ITRM(0, 0, 0, 0, 0),
227 };
/* Number of supported masks, and therefore of per-interface rule buckets. */
228 #define NINETRULEMASKS \
229     (sizeof(nxctl_inet_traffic_rule_masks)/sizeof(uint8_t))
230 
231 /* Per-interface lists of traffic rules */
232 SLIST_HEAD(nxctl_traffic_rule_inet_head, nxctl_traffic_rule_inet);
233 struct nxctl_traffic_rule_inet_if {
234 	char rii_ifname[IFNAMSIZ]; /* interface name these rules belong to */
235 	struct nxctl_traffic_rule_inet_head rii_lists[NINETRULEMASKS]; /* one rule list per supported mask */
236 	uint32_t rii_count; /* value reported by nxctl_inet_traffic_rule_get_count() */
237 	SLIST_ENTRY(nxctl_traffic_rule_inet_if) rii_link; /* linkage on ris_if_list */
238 };
239 
240 /* List of per-interface lists */
241 SLIST_HEAD(nxctl_traffic_rule_inet_if_head, nxctl_traffic_rule_inet_if);
242 struct nxctl_traffic_rule_inet_storage {
243 	struct nxctl_traffic_rule_inet_if_head ris_if_list; /* per-interface rule containers */
244 	uint32_t ris_count; /* must be 0 when the storage is destroyed */
245 };
246 
247 /* Per-fd list kept at the nxctl */
248 SLIST_HEAD(nxctl_traffic_rule_head, nxctl_traffic_rule);
249 struct nxctl_traffic_rule_storage {
250 	struct nxctl_traffic_rule_head rs_list; /* rules linked via ntr_storage_link */
251 	uint32_t rs_count; /* number of rules on rs_list */
252 };
253 
254 static LCK_RW_DECLARE_ATTR(nxctl_traffic_rule_lock, &sk_lock_group, &sk_lock_attr);
255 #define NXTR_WLOCK() \
256     lck_rw_lock_exclusive(&nxctl_traffic_rule_lock)
257 #define NXTR_WUNLOCK() \
258     lck_rw_unlock_exclusive(&nxctl_traffic_rule_lock)
259 #define NXTR_RLOCK() \
260     lck_rw_lock_shared(&nxctl_traffic_rule_lock)
261 #define NXTR_RUNLOCK() \
262     lck_rw_unlock_shared(&nxctl_traffic_rule_lock)
263 
264 static struct nxctl_traffic_rule_type *find_traffic_rule_type(uint8_t type);
265 static void retain_traffic_rule(struct nxctl_traffic_rule *ntr);
266 static void release_traffic_rule(struct nxctl_traffic_rule *ntr);
267 static int remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
268     struct nxctl_traffic_rule **ntrp);
269 static boolean_t inet_v6addr_cmp(struct ifnet_ip_addr *a1,
270     struct ifnet_ip_addr *a2);
271 static int notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags);
272 
273 #define NXCTL_TRAFFIC_RULE_TAG "com.apple.skywalk.nexus.traffic_rule"
274 static kern_allocation_name_t nxctl_traffic_rule_tag;
275 static struct nxctl_traffic_rule_type *inet_traffic_rule_type = NULL;
276 
277 /*
278  * If an interface attaches after rule(s) are added, this function is used
279  * retrieve the current rule count for that interface.
280  */
281 int
nxctl_inet_traffic_rule_get_count(const char * ifname,uint32_t * count)282 nxctl_inet_traffic_rule_get_count(const char *ifname, uint32_t *count)
283 {
284 	struct nxctl_traffic_rule_inet_storage *rs;
285 	struct nxctl_traffic_rule_inet_if *rif;
286 	int err;
287 
288 	NXTR_RLOCK();
289 	rs = inet_traffic_rule_type->ntrt_storage;
290 	if (rs == NULL) {
291 		err = ENOENT;
292 		goto fail;
293 	}
294 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
295 		if (strlcmp(rif->rii_ifname, ifname, sizeof(rif->rii_ifname)) == 0) {
296 			break;
297 		}
298 	}
299 	if (rif == NULL) {
300 		err = ENOENT;
301 		goto fail;
302 	}
303 	*count = rif->rii_count;
304 	NXTR_RUNLOCK();
305 	return 0;
306 fail:
307 	NXTR_RUNLOCK();
308 	return err;
309 }
310 
311 /*
312  * Used for finding the qset id associated with a traffic descriptor.
313  */
314 int
nxctl_inet_traffic_rule_find_qset_id(const char * ifname,struct ifnet_traffic_descriptor_inet * td,uint64_t * qset_id)315 nxctl_inet_traffic_rule_find_qset_id(const char *ifname,
316     struct ifnet_traffic_descriptor_inet *td, uint64_t *qset_id)
317 {
318 	struct nxctl_traffic_rule_inet *__single ntri = NULL;
319 	struct nxctl_traffic_rule *__single ntr = NULL;
320 	int err;
321 
322 	NXTR_RLOCK();
323 	ASSERT(inet_traffic_rule_type != NULL);
324 	err = inet_traffic_rule_type->ntrt_find(inet_traffic_rule_type, ifname,
325 	    (struct ifnet_traffic_descriptor_common *)td, 0, &ntr);
326 	if (err != 0) {
327 		SK_ERR("rule find failed: %d", err);
328 		goto fail;
329 	}
330 	ntri = __container_of(ntr, struct nxctl_traffic_rule_inet, ntri_common);
331 	*qset_id = ntri->ntri_ra.ras_qset_id;
332 	NXTR_RUNLOCK();
333 	return 0;
334 fail:
335 	NXTR_RUNLOCK();
336 	return err;
337 }
338 
339 /*
340  * Based on flow_pkt_classify().
341  * This function populates struct ifnet_traffic_descriptor_inet instead of struct __flow.
342  */
343 static int
fill_inet_td(struct __kern_packet * pkt,struct ifnet_traffic_descriptor_inet * td)344 fill_inet_td(struct __kern_packet *pkt, struct ifnet_traffic_descriptor_inet *td)
345 {
346 	union {
347 		volatile struct ip *__indexable _iph;
348 		volatile struct ip6_hdr *__indexable _ip6;
349 	} _l3;
350 	#define iph _l3._iph
351 	#define ip6 _l3._ip6
352 	union {
353 		volatile struct tcphdr *_tcph;
354 		volatile struct udphdr *_udph;
355 	} _l4;
356 	#define tcph _l4._tcph
357 	#define udph _l4._udph
358 	uint8_t *pkt_buf, *l3_hdr;
359 	uint32_t bdlen, bdlim, bdoff, cls_len;
360 	size_t pkt_len;
361 	uint8_t ipv, l3hlen = 0; /* IP header length */
362 	uint16_t l3tlen = 0;     /* total length of IP packet */
363 	uint8_t l4hlen = 0;      /* TCP/UDP header length */
364 	uint16_t ulen = 0;       /* user data length */
365 	int err;
366 
367 	ASSERT(pkt->pkt_l2_len <= pkt->pkt_length);
368 	pkt_len = pkt->pkt_length - pkt->pkt_l2_len;
369 
370 	MD_BUFLET_ADDR_ABS_DLEN(pkt, pkt_buf, bdlen, bdlim, bdoff);
371 	cls_len = bdlim - bdoff;
372 	cls_len -= pkt->pkt_l2_len;
373 	cls_len = (uint32_t)MIN(cls_len, pkt_len);
374 	VERIFY(pkt_len >= cls_len);
375 	if (cls_len == 0) {
376 		SK_ERR("cls_len == 0");
377 		err = EINVAL;
378 		goto fail;
379 	}
380 	l3_hdr = pkt_buf + pkt->pkt_headroom + pkt->pkt_l2_len;
381 	iph = (volatile struct ip *)(void *)l3_hdr;
382 	ipv = iph->ip_v;
383 
384 	switch (ipv) {
385 	case 4:
386 		if (cls_len < sizeof(struct ip)) {
387 			SK_ERR("cls_len < sizeof(struct ip) (%d < %d)",
388 			    cls_len, sizeof(struct ip));
389 			err = EINVAL;
390 			goto fail;
391 		}
392 		l3hlen = (uint8_t)(iph->ip_hl << 2);
393 		if (l3hlen < sizeof(struct ip)) {
394 			SK_ERR("l3hlen < sizeof(struct ip) (%d < %d)",
395 			    l3hlen, sizeof(struct ip));
396 			err = EINVAL;
397 			goto fail;
398 		}
399 		if (cls_len < l3hlen) {
400 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
401 			err = EINVAL;
402 			goto fail;
403 		}
404 		l3tlen = ntohs(iph->ip_len);
405 		if (l3tlen < l3hlen) {
406 			SK_ERR("l3tlen < l3hlen (%d < %d)", l3tlen, l3hlen);
407 			err = EINVAL;
408 			goto fail;
409 		}
410 		if (pkt_len < l3tlen) {
411 			SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
412 			err = EINVAL;
413 			goto fail;
414 		}
415 		td->inet_ipver = IPVERSION;
416 		td->inet_proto = iph->ip_p;
417 		bcopy(__DECONST(void *, &iph->ip_src), &td->inet_laddr.iia_v4addr,
418 		    sizeof(iph->ip_src));
419 		bcopy(__DECONST(void *, &iph->ip_dst), &td->inet_raddr.iia_v4addr,
420 		    sizeof(iph->ip_dst));
421 		break;
422 	case 6:
423 		l3hlen = sizeof(struct ip6_hdr);
424 		if (cls_len < l3hlen) {
425 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
426 			err = EINVAL;
427 			goto fail;
428 		}
429 		l3tlen = l3hlen + ntohs(ip6->ip6_plen);
430 		if (pkt_len < l3tlen) {
431 			SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
432 			err = EINVAL;
433 			goto fail;
434 		}
435 		td->inet_ipver = IPV6_VERSION;
436 		td->inet_proto = ip6->ip6_nxt;
437 		bcopy(__DECONST(void *, &ip6->ip6_src), &td->inet_laddr,
438 		    sizeof(ip6->ip6_src));
439 		bcopy(__DECONST(void *, &ip6->ip6_dst), &td->inet_raddr,
440 		    sizeof(ip6->ip6_dst));
441 		break;
442 	default:
443 		SK_ERR("ipv == %d", ipv);
444 		err = EINVAL;
445 		goto fail;
446 	}
447 	tcph = __DECONST(volatile struct tcphdr *, (volatile uint8_t *)iph + l3hlen);
448 	ulen = (l3tlen - l3hlen);
449 	if (td->inet_proto == IPPROTO_TCP) {
450 		if (cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph)) {
451 			SK_ERR("cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph) "
452 			    "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*tcph),
453 			    ulen, sizeof(*tcph));
454 			err = EINVAL;
455 			goto fail;
456 		}
457 		l4hlen = (uint8_t)(tcph->th_off << 2);
458 		if (l4hlen < sizeof(*tcph)) {
459 			SK_ERR("l4hlen < sizeof(*tcph) (%d < %d)", l4hlen, sizeof(*tcph));
460 			err = EINVAL;
461 			goto fail;
462 		}
463 		if (l4hlen > ulen) {
464 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
465 			err = EINVAL;
466 			goto fail;
467 		}
468 		bcopy(__DECONST(void *, &tcph->th_sport), &td->inet_lport,
469 		    sizeof(td->inet_lport));
470 		bcopy(__DECONST(void *, &tcph->th_dport), &td->inet_rport,
471 		    sizeof(td->inet_rport));
472 	} else if (td->inet_proto == IPPROTO_UDP) {
473 		if (cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph)) {
474 			SK_ERR("cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph) "
475 			    "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*udph),
476 			    ulen, sizeof(*udph));
477 			err = EINVAL;
478 			goto fail;
479 		}
480 		l4hlen = sizeof(*udph);
481 		if (l4hlen > ulen) {
482 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
483 			err = EINVAL;
484 			goto fail;
485 		}
486 		bcopy(__DECONST(void *, &udph->uh_sport), &td->inet_lport,
487 		    sizeof(td->inet_lport));
488 		bcopy(__DECONST(void *, &udph->uh_dport), &td->inet_rport,
489 		    sizeof(td->inet_rport));
490 	} else {
491 		err = ENOTSUP;
492 		goto fail;
493 	}
494 
495 	td->inet_common.itd_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET;
496 	td->inet_common.itd_len = sizeof(*td);
497 	td->inet_common.itd_flags = IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
498 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND;
499 	td->inet_mask |= (IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
500 	    IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
501 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
502 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
503 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
504 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT);
505 	return 0;
506 fail:
507 	DTRACE_SKYWALK5(classify__failed, struct ip *, iph, size_t, pkt_len,
508 	    uint8_t, pkt->pkt_l2_len, struct ifnet_traffic_descriptor_inet *, td,
509 	    int, err);
510 	bzero(td, sizeof(*td));
511 	return err;
512 	#undef iph
513 	#undef ip6
514 	#undef tcph
515 	#undef udph
516 }
517 
518 int
nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char * ifname,struct __kern_packet * pkt,uint64_t * qset_id)519 nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char *ifname,
520     struct __kern_packet *pkt, uint64_t *qset_id)
521 {
522 	struct ifnet_traffic_descriptor_inet td;
523 	int err;
524 
525 	err = fill_inet_td(pkt, &td);
526 	if (err != 0) {
527 		return err;
528 	}
529 	return nxctl_inet_traffic_rule_find_qset_id(ifname, &td, qset_id);
530 }
531 
532 void
nxctl_traffic_rule_init(void)533 nxctl_traffic_rule_init(void)
534 {
535 	ASSERT(nxctl_traffic_rule_tag == NULL);
536 	nxctl_traffic_rule_tag =
537 	    kern_allocation_name_allocate(NXCTL_TRAFFIC_RULE_TAG, 0);
538 	ASSERT(nxctl_traffic_rule_tag != NULL);
539 
540 	ASSERT(inet_traffic_rule_type == NULL);
541 	inet_traffic_rule_type =
542 	    find_traffic_rule_type(IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET);
543 	ASSERT(inet_traffic_rule_type != NULL);
544 }
545 
546 void
nxctl_traffic_rule_fini(void)547 nxctl_traffic_rule_fini(void)
548 {
549 	if (nxctl_traffic_rule_tag != NULL) {
550 		kern_allocation_name_release(nxctl_traffic_rule_tag);
551 		nxctl_traffic_rule_tag = NULL;
552 	}
553 	inet_traffic_rule_type = NULL;
554 }
555 
556 static struct ifnet_ip_addr v6_zeros_addr = {0};
557 static boolean_t
inet_v6addr_cmp(struct ifnet_ip_addr * a1,struct ifnet_ip_addr * a2)558 inet_v6addr_cmp(struct ifnet_ip_addr *a1, struct ifnet_ip_addr *a2)
559 {
560 	return memcmp(a1, a2, sizeof(*a1)) == 0;
561 }
562 
563 SK_NO_INLINE_ATTRIBUTE
564 static struct nxctl_traffic_rule_storage *
nxctl_traffic_rule_storage_create(void)565 nxctl_traffic_rule_storage_create(void)
566 {
567 	struct nxctl_traffic_rule_storage *rs;
568 
569 	rs = sk_alloc_type(struct nxctl_traffic_rule_storage,
570 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
571 	SLIST_INIT(&rs->rs_list);
572 	rs->rs_count = 0;
573 	return rs;
574 }
575 
576 SK_NO_INLINE_ATTRIBUTE
577 static void
nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage * rs)578 nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage *rs)
579 {
580 	ASSERT(rs->rs_count == 0);
581 	ASSERT(SLIST_EMPTY(&rs->rs_list));
582 	sk_free_type(struct nxctl_traffic_rule_storage, rs);
583 }
584 
585 /*
586  * This is meant to be called during closure of the nxctl's fd.
587  * This will cleanup all rules linked to this nxctl. Rules that
588  * are marked persistent won't be added to the nxctl list.
589  */
590 void
nxctl_traffic_rule_clean(struct nxctl * nxctl)591 nxctl_traffic_rule_clean(struct nxctl *nxctl)
592 {
593 	struct nxctl_traffic_rule_storage *rs;
594 	struct nxctl_traffic_rule *ntr, *next;
595 	int err;
596 
597 	lck_mtx_lock(&nxctl->nxctl_lock);
598 	if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
599 		lck_mtx_unlock(&nxctl->nxctl_lock);
600 		return;
601 	}
602 	ntr = SLIST_FIRST(&rs->rs_list);
603 	SLIST_INIT(&rs->rs_list);
604 	rs->rs_count = 0;
605 	nxctl_traffic_rule_storage_destroy(rs);
606 	nxctl->nxctl_traffic_rule_storage = NULL;
607 	lck_mtx_unlock(&nxctl->nxctl_lock);
608 
609 	while (ntr != NULL) {
610 		next = SLIST_NEXT(ntr, ntr_storage_link);
611 		/*
612 		 * Clearing the flag to tell remove_traffic_rule() not to
613 		 * remove from the nxctl list again.
614 		 */
615 		ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
616 
617 		/* Passing NULL because we already hold a reference */
618 		err = remove_traffic_rule(nxctl, ntr->ntr_uuid, NULL);
619 		if (err == 0) {
620 			(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
621 		}
622 		release_traffic_rule(ntr);
623 		ntr = next;
624 	}
625 }
626 
627 SK_NO_INLINE_ATTRIBUTE
628 static void
add_traffic_rule_to_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)629 add_traffic_rule_to_nxctl(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
630 {
631 	struct nxctl_traffic_rule_storage *rs;
632 
633 	lck_mtx_lock(&nxctl->nxctl_lock);
634 	if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
635 		rs = nxctl_traffic_rule_storage_create();
636 		nxctl->nxctl_traffic_rule_storage = rs;
637 	}
638 	ntr->ntr_flags |= NTR_FLAG_ON_NXCTL_LIST;
639 	retain_traffic_rule(ntr);
640 	SLIST_INSERT_HEAD(&rs->rs_list, ntr, ntr_storage_link);
641 	rs->rs_count++;
642 	lck_mtx_unlock(&nxctl->nxctl_lock);
643 }
644 
645 SK_NO_INLINE_ATTRIBUTE
646 static void
remove_traffic_rule_from_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)647 remove_traffic_rule_from_nxctl(struct nxctl *nxctl,
648     struct nxctl_traffic_rule *ntr)
649 {
650 	struct nxctl_traffic_rule_storage *rs;
651 
652 	lck_mtx_lock(&nxctl->nxctl_lock);
653 	if ((ntr->ntr_flags & NTR_FLAG_ON_NXCTL_LIST) == 0) {
654 		lck_mtx_unlock(&nxctl->nxctl_lock);
655 		return;
656 	}
657 	rs = nxctl->nxctl_traffic_rule_storage;
658 	SLIST_REMOVE(&rs->rs_list, ntr, nxctl_traffic_rule, ntr_storage_link);
659 	rs->rs_count--;
660 	ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
661 	release_traffic_rule(ntr);
662 	if (rs->rs_count == 0) {
663 		nxctl_traffic_rule_storage_destroy(rs);
664 		nxctl->nxctl_traffic_rule_storage = NULL;
665 	}
666 	lck_mtx_unlock(&nxctl->nxctl_lock);
667 }
668 
669 static int
inet_traffic_rule_validate(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra)670 inet_traffic_rule_validate(struct nxctl_traffic_rule_type *type,
671     const char *ifname,
672     struct ifnet_traffic_descriptor_common *td,
673     struct ifnet_traffic_rule_action *ra)
674 {
675 #pragma unused(type)
676 	char buf[IFNAMSIZ];
677 	int unit, i;
678 	struct ifnet_traffic_descriptor_inet *tdi;
679 	uint8_t mask = 0, ipver, proto;
680 
681 	if (ifunit_extract(ifname, buf, sizeof(buf), &unit) < 0) {
682 		SK_ERR("invalid ifname: %s", ifname);
683 		return EINVAL;
684 	}
685 	if (td->itd_len != sizeof(*tdi)) {
686 		SK_ERR("invalid td len: expected %d, actual %d",
687 		    sizeof(*tdi), td->itd_len);
688 		return EINVAL;
689 	}
690 	if (td->itd_flags == 0 ||
691 	    (td->itd_flags &
692 	    ~(IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
693 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND)) != 0) {
694 		SK_ERR("invalid td flags: 0x%x", td->itd_flags);
695 		return EINVAL;
696 	}
697 	tdi = (struct ifnet_traffic_descriptor_inet *)td;
698 	for (i = 0; i < NINETRULEMASKS; i++) {
699 		if (tdi->inet_mask == nxctl_inet_traffic_rule_masks[i]) {
700 			mask = tdi->inet_mask;
701 			break;
702 		}
703 	}
704 	if (mask == 0) {
705 		SK_ERR("invalid inet mask: 0x%x", tdi->inet_mask);
706 		return EINVAL;
707 	}
708 	ipver = tdi->inet_ipver;
709 	if (ipver != IPVERSION && ipver != IPV6_VERSION) {
710 		SK_ERR("invalid inet ipver: 0x%x", ipver);
711 		return EINVAL;
712 	}
713 	proto = tdi->inet_proto;
714 	if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
715 		SK_ERR("invalid inet proto: %d", proto);
716 		return EINVAL;
717 	}
718 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0) {
719 		if (ipver == IPVERSION) {
720 			if (tdi->inet_laddr.iia_v4addr == INADDR_ANY) {
721 				SK_ERR("inet laddr v4 cannot be unspecified");
722 				return EINVAL;
723 			}
724 		} else {
725 			if (inet_v6addr_cmp(&tdi->inet_laddr, &v6_zeros_addr)) {
726 				SK_ERR("inet laddr v4 cannot be unspecified");
727 				return EINVAL;
728 			}
729 		}
730 	}
731 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0) {
732 		if (ipver == IPVERSION) {
733 			if (tdi->inet_raddr.iia_v4addr == INADDR_ANY) {
734 				SK_ERR("inet raddr v6 cannot be unspecified");
735 				return EINVAL;
736 			}
737 		} else {
738 			if (inet_v6addr_cmp(&tdi->inet_raddr, &v6_zeros_addr)) {
739 				SK_ERR("inet raddr v6 cannot be unspecified");
740 				return EINVAL;
741 			}
742 		}
743 	}
744 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0) {
745 		if (tdi->inet_lport == 0) {
746 			SK_ERR("inet lport cannot be unspecified");
747 			return EINVAL;
748 		}
749 	}
750 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0) {
751 		if (tdi->inet_rport == 0) {
752 			SK_ERR("inet rport cannot be unspecified");
753 			return EINVAL;
754 		}
755 	}
756 	if (ra->ra_len != sizeof(struct ifnet_traffic_rule_action_steer)) {
757 		SK_ERR("invalid ra len: expected %d, actual %d",
758 		    sizeof(struct ifnet_traffic_rule_action_steer), ra->ra_len);
759 		return EINVAL;
760 	}
761 	return 0;
762 }
763 
764 SK_NO_INLINE_ATTRIBUTE
765 static struct nxctl_traffic_rule_inet_storage *
inet_traffic_rule_storage_create(void)766 inet_traffic_rule_storage_create(void)
767 {
768 	struct nxctl_traffic_rule_inet_storage *rs;
769 
770 	rs = sk_alloc_type(struct nxctl_traffic_rule_inet_storage,
771 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
772 	SLIST_INIT(&rs->ris_if_list);
773 	rs->ris_count = 0;
774 	return rs;
775 }
776 
777 SK_NO_INLINE_ATTRIBUTE
778 static void
inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage * rs)779 inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage *rs)
780 {
781 	ASSERT(rs->ris_count == 0);
782 	ASSERT(SLIST_EMPTY(&rs->ris_if_list));
783 	sk_free_type(struct nxctl_traffic_rule_inet_storage, rs);
784 }
785 
786 SK_NO_INLINE_ATTRIBUTE
787 static struct nxctl_traffic_rule_inet_if *
inet_traffic_rule_if_create(const char * ifname)788 inet_traffic_rule_if_create(const char *ifname)
789 {
790 	struct nxctl_traffic_rule_inet_if *rif;
791 	int i;
792 
793 	rif = sk_alloc_type(struct nxctl_traffic_rule_inet_if,
794 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
795 	for (i = 0; i < NINETRULEMASKS; i++) {
796 		SLIST_INIT(&rif->rii_lists[i]);
797 	}
798 	strlcpy(rif->rii_ifname, ifname, sizeof(rif->rii_ifname));
799 	rif->rii_count = 0;
800 	return rif;
801 }
802 
803 SK_NO_INLINE_ATTRIBUTE
804 static void
inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if * rif)805 inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if *rif)
806 {
807 	int i;
808 
809 	for (i = 0; i < NINETRULEMASKS; i++) {
810 		ASSERT(SLIST_EMPTY(&rif->rii_lists[i]));
811 	}
812 	ASSERT(rif->rii_count == 0);
813 	sk_free_type(struct nxctl_traffic_rule_inet_if, rif);
814 }
815 
816 SK_NO_INLINE_ATTRIBUTE
817 static boolean_t
inet_traffic_rule_match(struct nxctl_traffic_rule_inet * ntri,const char * ifname,uint32_t flags,struct ifnet_traffic_descriptor_inet * tdi)818 inet_traffic_rule_match(struct nxctl_traffic_rule_inet *ntri, const char *ifname,
819     uint32_t flags, struct ifnet_traffic_descriptor_inet *tdi)
820 {
821 	struct nxctl_traffic_rule *ntr = (struct nxctl_traffic_rule *)ntri;
822 	struct ifnet_traffic_descriptor_inet *tdi0;
823 	uint8_t mask;
824 	boolean_t exact;
825 
826 	VERIFY(strlcmp(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname)) == 0);
827 	tdi0 = &ntri->ntri_td;
828 
829 	exact = ((flags & NTR_FIND_FLAG_EXACT) != 0);
830 	mask = tdi0->inet_mask & tdi->inet_mask;
831 	if (exact) {
832 		ASSERT(tdi0->inet_mask == tdi->inet_mask);
833 	}
834 	ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER) != 0);
835 	if (tdi0->inet_ipver != tdi->inet_ipver) {
836 		DTRACE_SKYWALK2(ipver__mismatch,
837 		    uint8_t, tdi0->inet_ipver, uint8_t, tdi->inet_ipver);
838 		return FALSE;
839 	}
840 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) != 0 &&
841 	    tdi0->inet_proto != tdi->inet_proto) {
842 		DTRACE_SKYWALK2(proto__mismatch,
843 		    uint8_t, tdi0->inet_proto, uint8_t, tdi->inet_proto);
844 		return FALSE;
845 	}
846 	if (tdi0->inet_ipver == IPVERSION) {
847 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
848 		    tdi0->inet_laddr.iia_v4addr != tdi->inet_laddr.iia_v4addr) {
849 			DTRACE_SKYWALK2(v4laddr__mismatch,
850 			    in_addr_t, tdi0->inet_laddr.iia_v4addr,
851 			    in_addr_t, tdi->inet_laddr.iia_v4addr);
852 			return FALSE;
853 		}
854 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
855 		    tdi0->inet_raddr.iia_v4addr != tdi->inet_raddr.iia_v4addr) {
856 			DTRACE_SKYWALK2(v4raddr__mismatch,
857 			    in_addr_t, tdi0->inet_raddr.iia_v4addr,
858 			    in_addr_t, tdi->inet_raddr.iia_v4addr);
859 			return FALSE;
860 		}
861 	} else {
862 		ASSERT(tdi0->inet_ipver == IPV6_VERSION);
863 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
864 		    !inet_v6addr_cmp(&tdi0->inet_laddr, &tdi->inet_laddr)) {
865 			DTRACE_SKYWALK2(v6laddr__mismatch,
866 			    struct in6_addr *, &tdi0->inet_laddr,
867 			    struct in6_addr *, &tdi->inet_laddr);
868 			return FALSE;
869 		}
870 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
871 		    !inet_v6addr_cmp(&tdi0->inet_raddr, &tdi->inet_raddr)) {
872 			DTRACE_SKYWALK2(v6raddr__mismatch,
873 			    struct in6_addr *, &tdi0->inet_raddr,
874 			    struct in6_addr *, &tdi->inet_raddr);
875 			return FALSE;
876 		}
877 	}
878 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0 &&
879 	    tdi0->inet_lport != tdi->inet_lport) {
880 		DTRACE_SKYWALK2(lport__mismatch,
881 		    uint8_t, tdi0->inet_lport, uint8_t, tdi->inet_lport);
882 		return FALSE;
883 	}
884 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0 &&
885 	    tdi0->inet_rport != tdi->inet_rport) {
886 		DTRACE_SKYWALK2(rport__mismatch,
887 		    uint8_t, tdi0->inet_rport, uint8_t, tdi->inet_rport);
888 		return FALSE;
889 	}
890 	return TRUE;
891 }
892 
893 static int
inet_traffic_rule_find(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,uint32_t flags,struct nxctl_traffic_rule ** ntrp)894 inet_traffic_rule_find(struct nxctl_traffic_rule_type *type, const char *ifname,
895     struct ifnet_traffic_descriptor_common *td, uint32_t flags,
896     struct nxctl_traffic_rule **ntrp)
897 {
898 	struct nxctl_traffic_rule_inet *ntri = NULL;
899 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
900 	struct nxctl_traffic_rule_inet_if *rif;
901 	struct ifnet_traffic_descriptor_inet *tdi =
902 	    (struct ifnet_traffic_descriptor_inet *)td;
903 	int i;
904 
905 	if (rs == NULL) {
906 		return ENOENT;
907 	}
908 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
909 		if (strlcmp(rif->rii_ifname, ifname, sizeof(rif->rii_ifname)) != 0) {
910 			continue;
911 		}
912 		for (i = 0; i < NINETRULEMASKS; i++) {
913 			if ((flags & NTR_FIND_FLAG_EXACT) != 0 &&
914 			    tdi->inet_mask != nxctl_inet_traffic_rule_masks[i]) {
915 				continue;
916 			}
917 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
918 				if (inet_traffic_rule_match(ntri, ifname, flags, tdi)) {
919 					*ntrp = (struct nxctl_traffic_rule *)ntri;
920 					return 0;
921 				}
922 			}
923 		}
924 	}
925 	return ENOENT;
926 }
927 
928 static int
inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type * type,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)929 inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type *type,
930     uuid_t uuid, struct nxctl_traffic_rule **ntrp)
931 {
932 	struct nxctl_traffic_rule_inet *ntri;
933 	struct nxctl_traffic_rule *ntr;
934 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
935 	struct nxctl_traffic_rule_inet_if *rif;
936 	int i;
937 
938 	if (rs == NULL) {
939 		return ENOENT;
940 	}
941 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
942 		for (i = 0; i < NINETRULEMASKS; i++) {
943 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
944 				ntr = &ntri->ntri_common;
945 				if (uuid_compare(ntr->ntr_uuid, uuid) == 0) {
946 					*ntrp = ntr;
947 					return 0;
948 				}
949 			}
950 		}
951 	}
952 	return ENOENT;
953 }
954 
955 static void
inet_update_ifnet_traffic_rule_count(const char * ifname,uint32_t count)956 inet_update_ifnet_traffic_rule_count(const char *ifname, uint32_t count)
957 {
958 	struct ifnet *ifp;
959 
960 	ifp = ifunit_ref(ifname);
961 	if (ifp == NULL) {
962 		DTRACE_SKYWALK1(ifname__not__found, char *, ifname);
963 		return;
964 	}
965 	ifnet_update_traffic_rule_count(ifp, count);
966 	ifnet_decr_iorefcnt(ifp);
967 }
968 
969 static void
inet_traffic_rule_link(struct nxctl_traffic_rule * ntr)970 inet_traffic_rule_link(struct nxctl_traffic_rule *ntr)
971 {
972 	struct nxctl_traffic_rule_type *type = ntr->ntr_type;
973 	struct nxctl_traffic_rule_inet_storage *rs;
974 	struct nxctl_traffic_rule_inet_if *rif;
975 	struct nxctl_traffic_rule_inet *ntri =
976 	    (struct nxctl_traffic_rule_inet *)ntr;
977 	struct nxctl_traffic_rule_inet_head *list = NULL;
978 	int i;
979 	char *__null_terminated ntr_ifname = NULL;
980 	char *__null_terminated rii_ifname = NULL;
981 
982 	if ((rs = type->ntrt_storage) == NULL) {
983 		rs = inet_traffic_rule_storage_create();
984 		type->ntrt_storage = rs;
985 	}
986 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
987 		if (strbufcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
988 			break;
989 		}
990 	}
991 	if (rif == NULL) {
992 		ntr_ifname = __unsafe_null_terminated_from_indexable(ntr->ntr_ifname);
993 		rif = inet_traffic_rule_if_create(ntr_ifname);
994 		SLIST_INSERT_HEAD(&rs->ris_if_list, rif, rii_link);
995 	}
996 	for (i = 0; i < NINETRULEMASKS; i++) {
997 		if (ntri->ntri_td.inet_mask ==
998 		    nxctl_inet_traffic_rule_masks[i]) {
999 			list = &rif->rii_lists[i];
1000 			break;
1001 		}
1002 	}
1003 	retain_traffic_rule(ntr);
1004 	ASSERT(list != NULL);
1005 	SLIST_INSERT_HEAD(list, ntri, ntri_storage_link);
1006 	/* per-interface count */
1007 	rif->rii_count++;
1008 	rii_ifname = __unsafe_null_terminated_from_indexable(rif->rii_ifname);
1009 	inet_update_ifnet_traffic_rule_count(rii_ifname, rif->rii_count);
1010 
1011 	/* global count */
1012 	rs->ris_count++;
1013 }
1014 
1015 static void
inet_traffic_rule_unlink(struct nxctl_traffic_rule * ntr)1016 inet_traffic_rule_unlink(struct nxctl_traffic_rule *ntr)
1017 {
1018 	struct nxctl_traffic_rule_inet_storage *rs;
1019 	struct nxctl_traffic_rule_inet_if *rif;
1020 	struct nxctl_traffic_rule_inet *ntri =
1021 	    (struct nxctl_traffic_rule_inet *)ntr;
1022 	struct nxctl_traffic_rule_inet_head *list = NULL;
1023 	struct nxctl_traffic_rule_type *type;
1024 	int i;
1025 	char *__null_terminated rii_ifname = NULL;
1026 
1027 	type = ntr->ntr_type;
1028 	rs = type->ntrt_storage;
1029 	ASSERT(rs != NULL);
1030 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1031 		if (strbufcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
1032 			break;
1033 		}
1034 	}
1035 	ASSERT(rif != NULL);
1036 	for (i = 0; i < NINETRULEMASKS; i++) {
1037 		if (ntri->ntri_td.inet_mask ==
1038 		    nxctl_inet_traffic_rule_masks[i]) {
1039 			list = &rif->rii_lists[i];
1040 			break;
1041 		}
1042 	}
1043 	ASSERT(list != NULL);
1044 	SLIST_REMOVE(list, ntri, nxctl_traffic_rule_inet, ntri_storage_link);
1045 	rif->rii_count--;
1046 	rii_ifname = __unsafe_null_terminated_from_indexable(rif->rii_ifname);
1047 	inet_update_ifnet_traffic_rule_count(rii_ifname, rif->rii_count);
1048 
1049 	rs->ris_count--;
1050 	release_traffic_rule(ntr);
1051 
1052 	if (rif->rii_count == 0) {
1053 		SLIST_REMOVE(&rs->ris_if_list, rif, nxctl_traffic_rule_inet_if, rii_link);
1054 		inet_traffic_rule_if_destroy(rif);
1055 	}
1056 	if (rs->ris_count == 0) {
1057 		type->ntrt_storage = NULL;
1058 		inet_traffic_rule_storage_destroy(rs);
1059 	}
1060 }
1061 
1062 /*
1063  * XXX
1064  * This may need additional changes to ensure safety against detach/attach.
1065  * This is not an issue for the first consumer of llink interfaces, cellular,
1066  * which does not detach.
1067  */
1068 static int
inet_traffic_rule_notify(struct nxctl_traffic_rule * ntr,uint32_t flags)1069 inet_traffic_rule_notify(struct nxctl_traffic_rule *ntr, uint32_t flags)
1070 {
1071 	struct ifnet *ifp;
1072 	struct nx_netif *nif;
1073 	struct netif_qset *__single qset = NULL;
1074 	struct nxctl_traffic_rule_inet *ntri;
1075 	int err = 0;
1076 	char *__null_terminated ntr_ifname = NULL;
1077 
1078 	ntr_ifname = __unsafe_null_terminated_from_indexable(ntr->ntr_ifname);
1079 	ifp = ifunit_ref(ntr_ifname);
1080 	if (ifp == NULL) {
1081 		DTRACE_SKYWALK1(ifname__not__found, char *, ntr->ntr_ifname);
1082 		err = ENXIO;
1083 		goto done;
1084 	}
1085 	nif = NA(ifp)->nifna_netif;
1086 	if (!NX_LLINK_PROV(nif->nif_nx)) {
1087 		DTRACE_SKYWALK1(llink__not__enabled, struct ifnet *, ifp);
1088 		err = ENOTSUP;
1089 		goto done;
1090 	}
1091 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
1092 	qset = nx_netif_find_qset(nif, ntri->ntri_ra.ras_qset_id);
1093 	err = nx_netif_notify_steering_info(nif, qset,
1094 	    (struct ifnet_traffic_descriptor_common *)&ntri->ntri_td,
1095 	    ((flags & NTR_NOTIFY_FLAG_ADD) != 0));
1096 done:
1097 	if (qset != NULL) {
1098 		nx_netif_qset_release(&qset);
1099 	}
1100 	if (ifp != NULL) {
1101 		ifnet_decr_iorefcnt(ifp);
1102 	}
1103 	return err;
1104 }
1105 
1106 static int
inet_traffic_rule_create(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1107 inet_traffic_rule_create(struct nxctl_traffic_rule_type *type,
1108     const char *ifname, struct ifnet_traffic_descriptor_common *td,
1109     struct ifnet_traffic_rule_action *ra, uint32_t flags,
1110     struct nxctl_traffic_rule **ntrp)
1111 {
1112 	struct nxctl_traffic_rule_inet *ntri;
1113 	struct nxctl_traffic_rule *ntr;
1114 	struct ifnet_traffic_descriptor_inet *tdi;
1115 	struct ifnet_traffic_rule_action_steer *ras;
1116 
1117 	ntri = sk_alloc_type(struct nxctl_traffic_rule_inet,
1118 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
1119 	ntr = &ntri->ntri_common;
1120 
1121 	ntr->ntr_type = type;
1122 	ntr->ntr_flags = flags;
1123 	uuid_generate(ntr->ntr_uuid);
1124 	os_ref_init(&ntr->ntr_refcnt, NULL);
1125 
1126 	strlcpy(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname));
1127 	proc_selfname(ntr->ntr_procname, sizeof(ntr->ntr_procname));
1128 
1129 	tdi = __container_of(td, struct ifnet_traffic_descriptor_inet, inet_common);
1130 	ras = __container_of(ra, struct ifnet_traffic_rule_action_steer, ras_common);
1131 	bcopy(tdi, &ntri->ntri_td, sizeof(ntri->ntri_td));
1132 	bcopy(ras, &ntri->ntri_ra, sizeof(ntri->ntri_ra));
1133 
1134 	*ntrp = ntr;
1135 	return 0;
1136 }
1137 
1138 static void
inet_traffic_rule_destroy(struct nxctl_traffic_rule * ntr)1139 inet_traffic_rule_destroy(struct nxctl_traffic_rule *ntr)
1140 {
1141 	struct nxctl_traffic_rule_inet *ntri;
1142 
1143 	ASSERT(os_ref_get_count(&ntr->ntr_refcnt) == 0);
1144 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
1145 	sk_free_type(struct nxctl_traffic_rule_inet, ntri);
1146 }
1147 
1148 static void
convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet * ntri,struct nxctl_traffic_rule_inet_iocinfo * info)1149 convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet *ntri,
1150     struct nxctl_traffic_rule_inet_iocinfo *info)
1151 {
1152 	struct nxctl_traffic_rule *ntr;
1153 	struct nxctl_traffic_rule_generic_iocinfo *ginfo;
1154 
1155 	bzero(info, sizeof(*info));
1156 	ntr = &ntri->ntri_common;
1157 	ginfo = &info->tri_common;
1158 	_CASSERT(sizeof(ntr->ntr_procname) == sizeof(ginfo->trg_procname));
1159 	_CASSERT(sizeof(ntr->ntr_ifname) == sizeof(ginfo->trg_ifname));
1160 	uuid_copy(ginfo->trg_uuid, ntr->ntr_uuid);
1161 	strbufcpy(ginfo->trg_procname, ntr->ntr_procname);
1162 	strbufcpy(ginfo->trg_ifname, ntr->ntr_ifname);
1163 	bcopy(&ntri->ntri_td, &info->tri_td, sizeof(info->tri_td));
1164 	bcopy(&ntri->ntri_ra, &info->tri_ra, sizeof(info->tri_ra));
1165 }
1166 
1167 static int
inet_traffic_rule_get_all(struct nxctl_traffic_rule_type * type,uint32_t size,uint32_t * count,user_addr_t uaddr)1168 inet_traffic_rule_get_all(struct nxctl_traffic_rule_type *type, uint32_t size,
1169     uint32_t *count, user_addr_t uaddr)
1170 {
1171 	struct nxctl_traffic_rule_inet *ntri = NULL;
1172 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
1173 	struct nxctl_traffic_rule_inet_if *rif;
1174 	struct nxctl_traffic_rule_inet_iocinfo info;
1175 	int i, err;
1176 
1177 	if (size != sizeof(info)) {
1178 		SK_ERR("size: actual %d, expected %d", size, sizeof(info));
1179 		return EINVAL;
1180 	}
1181 	if (rs == NULL) {
1182 		*count = 0;
1183 		return 0;
1184 	}
1185 	if (*count < rs->ris_count) {
1186 		SK_ERR("count: given %d, require: %d", *count, rs->ris_count);
1187 		return ENOBUFS;
1188 	}
1189 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1190 		for (i = 0; i < NINETRULEMASKS; i++) {
1191 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
1192 				convert_ntri_to_iocinfo(ntri, &info);
1193 				err = copyout(&info, uaddr, sizeof(info));
1194 				if (err != 0) {
1195 					SK_ERR("copyout failed: %d", err);
1196 					return err;
1197 				}
1198 				uaddr += sizeof(info);
1199 			}
1200 		}
1201 	}
1202 	*count = rs->ris_count;
1203 	return 0;
1204 }
1205 
1206 SK_NO_INLINE_ATTRIBUTE
1207 static void
retain_traffic_rule(struct nxctl_traffic_rule * ntr)1208 retain_traffic_rule(struct nxctl_traffic_rule *ntr)
1209 {
1210 #if (DEVELOPMENT || DEBUG)
1211 	os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1212 	DTRACE_SKYWALK2(ntr__retain, struct nxctl_traffic_rule *, ntr,
1213 	    os_ref_count_t, count);
1214 #endif
1215 	os_ref_retain(&ntr->ntr_refcnt);
1216 }
1217 
1218 SK_NO_INLINE_ATTRIBUTE
1219 static void
release_traffic_rule(struct nxctl_traffic_rule * ntr)1220 release_traffic_rule(struct nxctl_traffic_rule *ntr)
1221 {
1222 #if (DEVELOPMENT || DEBUG)
1223 	os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1224 	DTRACE_SKYWALK2(ntr__release, struct nxctl_traffic_rule *, ntr,
1225 	    os_ref_count_t, count);
1226 #endif
1227 	if (os_ref_release(&ntr->ntr_refcnt) == 0) {
1228 		ntr->ntr_type->ntrt_destroy(ntr);
1229 	}
1230 }
1231 
1232 SK_NO_INLINE_ATTRIBUTE
1233 static int
notify_traffic_rule(struct nxctl_traffic_rule * ntr,uint32_t flags)1234 notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags)
1235 {
1236 	return ntr->ntr_type->ntrt_notify(ntr, flags);
1237 }
1238 
1239 static void
link_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1240 link_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1241 {
1242 	/*
1243 	 * The persist flag means: do not clean up rule upon nxctl fd close.
1244 	 * This means we only add the rule to the nxctl list if persist
1245 	 * is not set.
1246 	 */
1247 	if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1248 		add_traffic_rule_to_nxctl(nxctl, ntr);
1249 	}
1250 	ntr->ntr_type->ntrt_link(ntr);
1251 }
1252 
1253 static void
unlink_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1254 unlink_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1255 {
1256 	if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1257 		remove_traffic_rule_from_nxctl(nxctl, ntr);
1258 	}
1259 	ntr->ntr_type->ntrt_unlink(ntr);
1260 }
1261 
1262 static int
find_traffic_rule_by_uuid(uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1263 find_traffic_rule_by_uuid(uuid_t uuid, struct nxctl_traffic_rule **ntrp)
1264 {
1265 	int i, err;
1266 	struct nxctl_traffic_rule_type *ntrt;
1267 	struct nxctl_traffic_rule *__single ntr = NULL;
1268 
1269 	for (i = 0; i < NRULETYPES; i++) {
1270 		ntrt = &nxctl_rule_types[i];
1271 		err = ntrt->ntrt_find_by_uuid(ntrt, uuid, &ntr);
1272 		if (err == 0) {
1273 			ASSERT(ntr != NULL);
1274 			*ntrp = ntr;
1275 			return 0;
1276 		}
1277 	}
1278 	return ENOENT;
1279 }
1280 
1281 static struct nxctl_traffic_rule_type *
find_traffic_rule_type(uint8_t type)1282 find_traffic_rule_type(uint8_t type)
1283 {
1284 	int i;
1285 	struct nxctl_traffic_rule_type *ntrt;
1286 
1287 	for (i = 0; i < NRULETYPES; i++) {
1288 		ntrt = &nxctl_rule_types[i];
1289 		if (ntrt->ntrt_type == type) {
1290 			return ntrt;
1291 		}
1292 	}
1293 	return NULL;
1294 }
1295 
1296 SK_NO_INLINE_ATTRIBUTE
1297 static int
add_traffic_rule(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1298 add_traffic_rule(struct nxctl *nxctl, const char *ifname,
1299     struct ifnet_traffic_descriptor_common *td,
1300     struct ifnet_traffic_rule_action *ra,
1301     uint32_t flags,
1302     struct nxctl_traffic_rule **ntrp)
1303 {
1304 	struct nxctl_traffic_rule_type *type = NULL;
1305 	struct nxctl_traffic_rule *__single ntr = NULL;
1306 	int err;
1307 
1308 	NXTR_WLOCK();
1309 	type = find_traffic_rule_type(td->itd_type);
1310 	if (type == NULL) {
1311 		SK_ERR("rule type %x not found", td->itd_type);
1312 		err = EINVAL;
1313 		goto fail;
1314 	}
1315 	err = type->ntrt_validate(type, ifname, td, ra);
1316 	if (err != 0) {
1317 		SK_ERR("rule validate failed: %d", err);
1318 		goto fail;
1319 	}
1320 	err = type->ntrt_find(type, ifname, td, NTR_FIND_FLAG_EXACT, &ntr);
1321 	if (err == 0) {
1322 		SK_ERR("rule already exists");
1323 		ASSERT(ntr != NULL);
1324 		err = EEXIST;
1325 		goto fail;
1326 	} else if (err != ENOENT) {
1327 		SK_ERR("rule find failed: %d", err);
1328 		goto fail;
1329 	}
1330 	err = type->ntrt_create(type, ifname, td, ra, flags, &ntr);
1331 	if (err != 0) {
1332 		SK_ERR("rule create failed: %d", err);
1333 		goto fail;
1334 	}
1335 	link_traffic_rule(nxctl, ntr);
1336 	if (ntrp != NULL) {
1337 		retain_traffic_rule(ntr);
1338 		*ntrp = ntr;
1339 	}
1340 	NXTR_WUNLOCK();
1341 	return 0;
1342 fail:
1343 	NXTR_WUNLOCK();
1344 	return err;
1345 }
1346 
1347 
1348 SK_NO_INLINE_ATTRIBUTE
1349 static int
remove_traffic_rule(struct nxctl * nxctl,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1350 remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
1351     struct nxctl_traffic_rule **ntrp)
1352 {
1353 	struct nxctl_traffic_rule *__single ntr;
1354 	int err;
1355 
1356 	NXTR_WLOCK();
1357 	err = find_traffic_rule_by_uuid(uuid, &ntr);
1358 	if (err != 0) {
1359 		SK_ERR("traffic rule not found");
1360 		NXTR_WUNLOCK();
1361 		return err;
1362 	}
1363 	if (ntrp != NULL) {
1364 		retain_traffic_rule(ntr);
1365 		*ntrp = ntr;
1366 	}
1367 	unlink_traffic_rule(nxctl, ntr);
1368 	/* release initial reference */
1369 	release_traffic_rule(ntr);
1370 	NXTR_WUNLOCK();
1371 	return 0;
1372 }
1373 
1374 static uint32_t
convert_traffic_rule_ioc_flags(uint32_t flags)1375 convert_traffic_rule_ioc_flags(uint32_t flags)
1376 {
1377 	uint32_t f = 0;
1378 
1379 	if ((flags & NXIOC_ADD_TRAFFIC_RULE_FLAG_PERSIST) != 0) {
1380 		f |= NTR_FLAG_PERSIST;
1381 	}
1382 	return f;
1383 }
1384 
1385 SK_NO_INLINE_ATTRIBUTE
1386 static int
add_traffic_rule_generic(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,uuid_t * uuid)1387 add_traffic_rule_generic(struct nxctl *nxctl, const char *ifname,
1388     struct ifnet_traffic_descriptor_common *td,
1389     struct ifnet_traffic_rule_action *ra, uint32_t flags, uuid_t *uuid)
1390 {
1391 	struct nxctl_traffic_rule *__single ntr;
1392 	int err;
1393 
1394 	err = add_traffic_rule(nxctl, ifname, td, ra, flags, &ntr);
1395 	if (err != 0) {
1396 		return err;
1397 	}
1398 	(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_ADD);
1399 	uuid_copy(*uuid, ntr->ntr_uuid);
1400 	release_traffic_rule(ntr);
1401 	return 0;
1402 }
1403 
1404 int
nxioctl_add_traffic_rule_inet(struct nxctl * nxctl,caddr_t data,proc_t procp)1405 nxioctl_add_traffic_rule_inet(struct nxctl *nxctl, caddr_t data, proc_t procp)
1406 {
1407 #pragma unused(procp)
1408 	struct nxctl_add_traffic_rule_inet_iocargs *args =
1409 	    (struct nxctl_add_traffic_rule_inet_iocargs *)(void *)data;
1410 	char *__null_terminated atri_ifname = NULL;
1411 
1412 	atri_ifname = __unsafe_null_terminated_from_indexable(args->atri_ifname);
1413 
1414 	return add_traffic_rule_generic(nxctl, atri_ifname,
1415 	           (struct ifnet_traffic_descriptor_common *)&args->atri_td,
1416 	           (struct ifnet_traffic_rule_action *)&args->atri_ra,
1417 	           convert_traffic_rule_ioc_flags(args->atri_flags),
1418 	           &args->atri_uuid);
1419 }
1420 
1421 int
nxioctl_remove_traffic_rule(struct nxctl * nxctl,caddr_t data,proc_t procp)1422 nxioctl_remove_traffic_rule(struct nxctl *nxctl, caddr_t data, proc_t procp)
1423 {
1424 #pragma unused(procp)
1425 	struct nxctl_remove_traffic_rule_iocargs *args =
1426 	    (struct nxctl_remove_traffic_rule_iocargs *)(void *)data;
1427 	struct nxctl_traffic_rule *__single ntr;
1428 	int err;
1429 
1430 	err = remove_traffic_rule(nxctl, args->rtr_uuid, &ntr);
1431 	if (err != 0) {
1432 		return err;
1433 	}
1434 	(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
1435 	release_traffic_rule(ntr);
1436 	return 0;
1437 }
1438 
1439 int
nxioctl_get_traffic_rules(struct nxctl * nxctl,caddr_t data,proc_t procp)1440 nxioctl_get_traffic_rules(struct nxctl *nxctl, caddr_t data, proc_t procp)
1441 {
1442 #pragma unused(nxctl)
1443 	struct nxctl_get_traffic_rules_iocargs *args =
1444 	    (struct nxctl_get_traffic_rules_iocargs *)(void *)data;
1445 	struct nxctl_traffic_rule_type *type;
1446 	user_addr_t uaddr;
1447 	int err;
1448 
1449 	NXTR_RLOCK();
1450 	type = find_traffic_rule_type(args->gtr_type);
1451 	if (type == NULL) {
1452 		SK_ERR("rule type %x not found", args->gtr_type);
1453 		err = EINVAL;
1454 		goto fail;
1455 	}
1456 	uaddr = proc_is64bit(procp) ? args->gtr_buf64 :
1457 	    CAST_USER_ADDR_T(args->gtr_buf);
1458 	err = type->ntrt_get_all(type, args->gtr_size, &args->gtr_count, uaddr);
1459 	if (err != 0) {
1460 		goto fail;
1461 	}
1462 	NXTR_RUNLOCK();
1463 	return 0;
1464 fail:
1465 	NXTR_RUNLOCK();
1466 	return err;
1467 }
1468