xref: /xnu-11417.121.6/bsd/skywalk/nexus/nexus_traffic_rule.c (revision a1e26a70f38d1d7daa7b49b258e2f8538ad81650)
1 /*
2  * Copyright (c) 2022 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 #include <skywalk/os_skywalk_private.h>
29 #include <skywalk/nexus/netif/nx_netif.h>
30 #include <netinet/ip.h>
31 #include <netinet/ip6.h>
32 #include <netinet/tcp.h>
33 #include <netinet/udp.h>
34 #include <sys/sdt.h>
35 
36 /*
37  * Implementation of nexus traffic rules APIs.
38  */
39 
40 struct nxctl_traffic_rule_type;
41 struct nxctl_traffic_rule;
42 
43 /*
44  * These callbacks need to be implemented for each rule type.
45  */
46 
47 /* Validate user provided parameters. */
48 typedef int (nxctl_traffic_rule_validate_cb_t)(
49 	struct nxctl_traffic_rule_type *type,
50 	const char *ifname,
51 	struct ifnet_traffic_descriptor_common *td,
52 	struct ifnet_traffic_rule_action *ra);
53 /*
54  * Each rule type has its own global structure for storing rules.
55  * These callbacks access this global structure.
56  */
57 #define NTR_FIND_FLAG_EXACT 0x0001
58 typedef int (nxctl_traffic_rule_find_cb_t)(
59 	struct nxctl_traffic_rule_type *type,
60 	const char *ifname,
61 	struct ifnet_traffic_descriptor_common *td,
62 	uint32_t flags,
63 	struct nxctl_traffic_rule **ntrp);
64 
65 typedef int (nxctl_traffic_rule_find_by_uuid_cb_t)(
66 	struct nxctl_traffic_rule_type *type,
67 	uuid_t uuid,
68 	struct nxctl_traffic_rule **ntrp);
69 
70 typedef void (nxctl_traffic_rule_link_cb_t)(
71 	struct nxctl_traffic_rule *ntr);
72 
73 typedef void (nxctl_traffic_rule_unlink_cb_t)(
74 	struct nxctl_traffic_rule *ntr);
75 
76 /*
77  * Notifies lower layers of the addition/removal of a rule.
78  * This is called outside of nxctl_traffic_rule_lock to avoid potential
79  * locking issues.
80  */
81 #define NTR_NOTIFY_FLAG_ADD 0x0001
82 #define NTR_NOTIFY_FLAG_REMOVE 0x0002
83 typedef int (nxctl_traffic_rule_notify_cb_t)(
84 	struct nxctl_traffic_rule *ntr,
85 	uint32_t flags);
86 
87 /*
88  * Create/Destroy callbacks for a rule type.
89  */
90 typedef int (nxctl_traffic_rule_create_cb_t)(
91 	struct nxctl_traffic_rule_type *type,
92 	const char *ifname,
93 	struct ifnet_traffic_descriptor_common *td,
94 	struct ifnet_traffic_rule_action *ra,
95 	uint32_t flags,
96 	struct nxctl_traffic_rule **ntrp);
97 
98 typedef void (nxctl_traffic_rule_destroy_cb_t)(
99 	struct nxctl_traffic_rule *ntr);
100 
101 /*
102  * This is used for copying all rules for a type (including generic
103  * and type-specific info) to userspace.
104  */
105 typedef int (nxctl_traffic_rule_get_all_cb_t)(
106 	struct nxctl_traffic_rule_type *type,
107 	uint32_t size,
108 	uint32_t *count,
109 	user_addr_t uaddr);
110 
111 struct nxctl_traffic_rule_type {
112 	uint8_t ntrt_type;
113 	nxctl_traffic_rule_validate_cb_t *ntrt_validate;
114 	nxctl_traffic_rule_find_cb_t *ntrt_find;
115 	nxctl_traffic_rule_find_by_uuid_cb_t *ntrt_find_by_uuid;
116 	nxctl_traffic_rule_link_cb_t *ntrt_link;
117 	nxctl_traffic_rule_unlink_cb_t *ntrt_unlink;
118 	nxctl_traffic_rule_notify_cb_t *ntrt_notify;
119 	nxctl_traffic_rule_create_cb_t *ntrt_create;
120 	nxctl_traffic_rule_destroy_cb_t *ntrt_destroy;
121 	nxctl_traffic_rule_get_all_cb_t *ntrt_get_all;
122 	/*
123 	 * -fbounds-safety: Why is this void *? All usage of this was for
124 	 * struct nxctl_traffic_rule_inet_storage *
125 	 */
126 	struct nxctl_traffic_rule_inet_storage *ntrt_storage;
127 };
128 
129 static nxctl_traffic_rule_validate_cb_t inet_traffic_rule_validate;
130 static nxctl_traffic_rule_find_cb_t inet_traffic_rule_find;
131 static nxctl_traffic_rule_find_by_uuid_cb_t inet_traffic_rule_find_by_uuid;
132 static nxctl_traffic_rule_link_cb_t inet_traffic_rule_link;
133 static nxctl_traffic_rule_unlink_cb_t inet_traffic_rule_unlink;
134 static nxctl_traffic_rule_notify_cb_t inet_traffic_rule_notify;
135 static nxctl_traffic_rule_create_cb_t inet_traffic_rule_create;
136 static nxctl_traffic_rule_destroy_cb_t inet_traffic_rule_destroy;
137 static nxctl_traffic_rule_get_all_cb_t inet_traffic_rule_get_all;
138 
139 static struct nxctl_traffic_rule_type nxctl_rule_types[] = {
140 	{
141 		.ntrt_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET,
142 		.ntrt_validate = inet_traffic_rule_validate,
143 		.ntrt_find = inet_traffic_rule_find,
144 		.ntrt_find_by_uuid = inet_traffic_rule_find_by_uuid,
145 		.ntrt_link = inet_traffic_rule_link,
146 		.ntrt_unlink = inet_traffic_rule_unlink,
147 		.ntrt_notify = inet_traffic_rule_notify,
148 		.ntrt_create = inet_traffic_rule_create,
149 		.ntrt_destroy = inet_traffic_rule_destroy,
150 		.ntrt_get_all = inet_traffic_rule_get_all,
151 	},
152 };
153 #define NRULETYPES \
154     (sizeof(nxctl_rule_types)/sizeof(struct nxctl_traffic_rule_type))
155 
156 /*
157  * Generic traffic rule.
158  * Contains fields common to all traffic rules.
159  */
160 #define NTR_FLAG_PERSIST 0x0001
161 #define NTR_FLAG_ON_NXCTL_LIST 0x0002
162 struct nxctl_traffic_rule {
163 	struct nxctl_traffic_rule_type *ntr_type;
164 	uint32_t ntr_flags;
165 	os_refcnt_t ntr_refcnt;
166 	uuid_t ntr_uuid;
167 	char ntr_procname[NTR_PROCNAME_SZ];
168 	char ntr_ifname[IFNAMSIZ];
169 	SLIST_ENTRY(nxctl_traffic_rule) ntr_storage_link;
170 };
171 
172 /*
173  * Inet-specific traffic rule.
174  */
175 struct nxctl_traffic_rule_inet {
176 	struct nxctl_traffic_rule ntri_common;
177 	SLIST_ENTRY(nxctl_traffic_rule_inet) ntri_storage_link;
178 	struct ifnet_traffic_descriptor_inet ntri_td;
179 	struct ifnet_traffic_rule_action_steer ntri_ra;
180 };
181 
182 /*
183  * Currently supported tuple types.
184  */
185 #define ITDBIT(set, bit) (((set) != 0) ? (bit) : 0)
186 #define ITRM(proto, laddr, raddr, lport, rport) \
187 	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER | \
188 	ITDBIT(proto, IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) | \
189 	ITDBIT(laddr, IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) | \
190 	ITDBIT(raddr, IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) | \
191 	ITDBIT(lport, IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) | \
192 	ITDBIT(rport, IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT))
193 
194 static uint8_t nxctl_inet_traffic_rule_masks[] = {
195 	ITRM(1, 1, 1, 1, 1),
196 	ITRM(1, 1, 1, 1, 0),
197 	ITRM(1, 1, 1, 0, 1),
198 	ITRM(1, 1, 1, 0, 0),
199 	ITRM(1, 1, 0, 1, 1),
200 	ITRM(1, 1, 0, 1, 0),
201 	ITRM(1, 1, 0, 0, 1),
202 	ITRM(1, 1, 0, 0, 0),
203 	ITRM(1, 0, 1, 1, 1),
204 	ITRM(1, 0, 1, 1, 0),
205 	ITRM(1, 0, 1, 0, 1),
206 	ITRM(1, 0, 1, 0, 0),
207 	ITRM(1, 0, 0, 1, 1),
208 	ITRM(1, 0, 0, 1, 0),
209 	ITRM(1, 0, 0, 0, 1),
210 	// ITRM(1, 0, 0, 0, 0), addr or port is required
211 	ITRM(0, 1, 1, 1, 1),
212 	ITRM(0, 1, 1, 1, 0),
213 	ITRM(0, 1, 1, 0, 1),
214 	ITRM(0, 1, 1, 0, 0),
215 	ITRM(0, 1, 0, 1, 1),
216 	ITRM(0, 1, 0, 1, 0),
217 	ITRM(0, 1, 0, 0, 1),
218 	ITRM(0, 1, 0, 0, 0),
219 	ITRM(0, 0, 1, 1, 1),
220 	ITRM(0, 0, 1, 1, 0),
221 	ITRM(0, 0, 1, 0, 1),
222 	ITRM(0, 0, 1, 0, 0),
223 	ITRM(0, 0, 0, 1, 1),
224 	ITRM(0, 0, 0, 1, 0),
225 	ITRM(0, 0, 0, 0, 1),
226 	// ITRM(0, 0, 0, 0, 0),
227 };
228 #define NINETRULEMASKS \
229     (sizeof(nxctl_inet_traffic_rule_masks)/sizeof(uint8_t))
230 
231 /* Per-interface lists of traffic rules */
232 SLIST_HEAD(nxctl_traffic_rule_inet_head, nxctl_traffic_rule_inet);
233 struct nxctl_traffic_rule_inet_if {
234 	char rii_ifname[IFNAMSIZ];
235 	struct nxctl_traffic_rule_inet_head rii_lists[NINETRULEMASKS];
236 	uint32_t rii_count;
237 	SLIST_ENTRY(nxctl_traffic_rule_inet_if) rii_link;
238 };
239 
240 /* List of per-interface lists */
241 SLIST_HEAD(nxctl_traffic_rule_inet_if_head, nxctl_traffic_rule_inet_if);
242 struct nxctl_traffic_rule_inet_storage {
243 	struct nxctl_traffic_rule_inet_if_head ris_if_list;
244 	uint32_t ris_count;
245 };
246 
247 /* Per-fd list kept at the nxctl */
248 SLIST_HEAD(nxctl_traffic_rule_head, nxctl_traffic_rule);
249 struct nxctl_traffic_rule_storage {
250 	struct nxctl_traffic_rule_head rs_list;
251 	uint32_t rs_count;
252 };
253 
254 static LCK_RW_DECLARE_ATTR(nxctl_traffic_rule_lock, &sk_lock_group, &sk_lock_attr);
255 #define NXTR_WLOCK() \
256     lck_rw_lock_exclusive(&nxctl_traffic_rule_lock)
257 #define NXTR_WUNLOCK() \
258     lck_rw_unlock_exclusive(&nxctl_traffic_rule_lock)
259 #define NXTR_RLOCK() \
260     lck_rw_lock_shared(&nxctl_traffic_rule_lock)
261 #define NXTR_RUNLOCK() \
262     lck_rw_unlock_shared(&nxctl_traffic_rule_lock)
263 
264 static struct nxctl_traffic_rule_type *find_traffic_rule_type(uint8_t type);
265 static void retain_traffic_rule(struct nxctl_traffic_rule *ntr);
266 static void release_traffic_rule(struct nxctl_traffic_rule *ntr);
267 static int remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
268     struct nxctl_traffic_rule **ntrp);
269 static boolean_t inet_v6addr_cmp(struct ifnet_ip_addr *a1,
270     struct ifnet_ip_addr *a2);
271 static int notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags);
272 
273 #define NXCTL_TRAFFIC_RULE_TAG "com.apple.skywalk.nexus.traffic_rule"
274 static kern_allocation_name_t nxctl_traffic_rule_tag;
275 static struct nxctl_traffic_rule_type *inet_traffic_rule_type = NULL;
276 
277 /*
278  * If an interface attaches after rule(s) are added, this function is used
279  * retrieve the current rule count for that interface.
280  */
281 int
nxctl_inet_traffic_rule_get_count(const char * ifname,uint32_t * count)282 nxctl_inet_traffic_rule_get_count(const char *ifname, uint32_t *count)
283 {
284 	struct nxctl_traffic_rule_inet_storage *rs;
285 	struct nxctl_traffic_rule_inet_if *rif;
286 	int err;
287 
288 	NXTR_RLOCK();
289 	rs = inet_traffic_rule_type->ntrt_storage;
290 	if (rs == NULL) {
291 		err = ENOENT;
292 		goto fail;
293 	}
294 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
295 		if (strlcmp(rif->rii_ifname, ifname, sizeof(rif->rii_ifname)) == 0) {
296 			break;
297 		}
298 	}
299 	if (rif == NULL) {
300 		err = ENOENT;
301 		goto fail;
302 	}
303 	*count = rif->rii_count;
304 	NXTR_RUNLOCK();
305 	return 0;
306 fail:
307 	NXTR_RUNLOCK();
308 	return err;
309 }
310 
311 /*
312  * Used for finding the qset id associated with a traffic descriptor.
313  */
314 int
nxctl_inet_traffic_rule_find_qset_id(const char * ifname,struct ifnet_traffic_descriptor_inet * td,uint64_t * qset_id)315 nxctl_inet_traffic_rule_find_qset_id(const char *ifname,
316     struct ifnet_traffic_descriptor_inet *td, uint64_t *qset_id)
317 {
318 	struct nxctl_traffic_rule_inet *__single ntri = NULL;
319 	struct nxctl_traffic_rule *__single ntr = NULL;
320 	int err;
321 
322 	NXTR_RLOCK();
323 	ASSERT(inet_traffic_rule_type != NULL);
324 	err = inet_traffic_rule_type->ntrt_find(inet_traffic_rule_type, ifname,
325 	    (struct ifnet_traffic_descriptor_common *)td, 0, &ntr);
326 	if (err != 0) {
327 		SK_ERR("rule find failed: %d", err);
328 		goto fail;
329 	}
330 	ntri = __container_of(ntr, struct nxctl_traffic_rule_inet, ntri_common);
331 	*qset_id = ntri->ntri_ra.ras_qset_id;
332 	NXTR_RUNLOCK();
333 	return 0;
334 fail:
335 	NXTR_RUNLOCK();
336 	return err;
337 }
338 
339 /*
340  * Based on flow_pkt_classify().
341  * This function populates struct ifnet_traffic_descriptor_inet instead of struct __flow.
342  */
343 static int
fill_inet_td(struct __kern_packet * pkt,struct ifnet_traffic_descriptor_inet * td)344 fill_inet_td(struct __kern_packet *pkt, struct ifnet_traffic_descriptor_inet *td)
345 {
346 	union {
347 		volatile struct ip *__indexable _iph;
348 		volatile struct ip6_hdr *__indexable _ip6;
349 	} _l3;
350 	#define iph _l3._iph
351 	#define ip6 _l3._ip6
352 	union {
353 		volatile struct tcphdr *_tcph;
354 		volatile struct udphdr *_udph;
355 	} _l4;
356 	#define tcph _l4._tcph
357 	#define udph _l4._udph
358 	uint8_t *pkt_buf, *l3_hdr;
359 	uint32_t bdlen, bdlim, bdoff, cls_len;
360 	size_t pkt_len;
361 	uint8_t ipv, l3hlen = 0; /* IP header length */
362 	uint16_t l3tlen = 0;     /* total length of IP packet */
363 	uint8_t l4hlen = 0;      /* TCP/UDP header length */
364 	uint16_t ulen = 0;       /* user data length */
365 	int err;
366 
367 	ASSERT(pkt->pkt_l2_len <= pkt->pkt_length);
368 	pkt_len = pkt->pkt_length - pkt->pkt_l2_len;
369 
370 	MD_BUFLET_ADDR_ABS_DLEN(pkt, pkt_buf, bdlen, bdlim, bdoff);
371 	cls_len = bdlim - bdoff;
372 	cls_len -= pkt->pkt_l2_len;
373 	cls_len = (uint32_t)MIN(cls_len, pkt_len);
374 	VERIFY(pkt_len >= cls_len);
375 	if (cls_len == 0) {
376 		SK_ERR("cls_len == 0");
377 		err = EINVAL;
378 		goto fail;
379 	}
380 	l3_hdr = pkt_buf + pkt->pkt_headroom + pkt->pkt_l2_len;
381 	iph = (volatile struct ip *)(void *)l3_hdr;
382 	ipv = iph->ip_v;
383 
384 	switch (ipv) {
385 	case 4:
386 		if (cls_len < sizeof(struct ip)) {
387 			SK_ERR("cls_len < sizeof(struct ip) (%d < %d)",
388 			    cls_len, sizeof(struct ip));
389 			err = EINVAL;
390 			goto fail;
391 		}
392 		l3hlen = (uint8_t)(iph->ip_hl << 2);
393 		if (l3hlen < sizeof(struct ip)) {
394 			SK_ERR("l3hlen < sizeof(struct ip) (%d < %d)",
395 			    l3hlen, sizeof(struct ip));
396 			err = EINVAL;
397 			goto fail;
398 		}
399 		if (cls_len < l3hlen) {
400 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
401 			err = EINVAL;
402 			goto fail;
403 		}
404 		l3tlen = ntohs(iph->ip_len);
405 		if (l3tlen < l3hlen) {
406 			SK_ERR("l3tlen < l3hlen (%d < %d)", l3tlen, l3hlen);
407 			err = EINVAL;
408 			goto fail;
409 		}
410 		if (pkt_len < l3tlen) {
411 			SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
412 			err = EINVAL;
413 			goto fail;
414 		}
415 		td->inet_ipver = IPVERSION;
416 		td->inet_proto = iph->ip_p;
417 		bcopy(__DECONST(void *, &iph->ip_src), &td->inet_laddr.iia_v4addr,
418 		    sizeof(iph->ip_src));
419 		bcopy(__DECONST(void *, &iph->ip_dst), &td->inet_raddr.iia_v4addr,
420 		    sizeof(iph->ip_dst));
421 		break;
422 	case 6:
423 		l3hlen = sizeof(struct ip6_hdr);
424 		if (cls_len < l3hlen) {
425 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
426 			err = EINVAL;
427 			goto fail;
428 		}
429 		l3tlen = l3hlen + ntohs(ip6->ip6_plen);
430 		if (pkt_len < l3tlen) {
431 			SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
432 			err = EINVAL;
433 			goto fail;
434 		}
435 		td->inet_ipver = IPV6_VERSION;
436 		td->inet_proto = ip6->ip6_nxt;
437 		bcopy(__DECONST(void *, &ip6->ip6_src), &td->inet_laddr,
438 		    sizeof(ip6->ip6_src));
439 		bcopy(__DECONST(void *, &ip6->ip6_dst), &td->inet_raddr,
440 		    sizeof(ip6->ip6_dst));
441 		break;
442 	default:
443 		SK_ERR("ipv == %d", ipv);
444 		err = EINVAL;
445 		goto fail;
446 	}
447 	tcph = __DECONST(volatile struct tcphdr *, (volatile uint8_t *)iph + l3hlen);
448 	ulen = (l3tlen - l3hlen);
449 	if (td->inet_proto == IPPROTO_TCP) {
450 		if (cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph)) {
451 			SK_ERR("cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph) "
452 			    "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*tcph),
453 			    ulen, sizeof(*tcph));
454 			err = EINVAL;
455 			goto fail;
456 		}
457 		l4hlen = (uint8_t)(tcph->th_off << 2);
458 		if (l4hlen < sizeof(*tcph)) {
459 			SK_ERR("l4hlen < sizeof(*tcph) (%d < %d)", l4hlen, sizeof(*tcph));
460 			err = EINVAL;
461 			goto fail;
462 		}
463 		if (l4hlen > ulen) {
464 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
465 			err = EINVAL;
466 			goto fail;
467 		}
468 		bcopy(__DECONST(void *, &tcph->th_sport), &td->inet_lport,
469 		    sizeof(td->inet_lport));
470 		bcopy(__DECONST(void *, &tcph->th_dport), &td->inet_rport,
471 		    sizeof(td->inet_rport));
472 	} else if (td->inet_proto == IPPROTO_UDP) {
473 		if (cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph)) {
474 			SK_ERR("cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph) "
475 			    "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*udph),
476 			    ulen, sizeof(*udph));
477 			err = EINVAL;
478 			goto fail;
479 		}
480 		l4hlen = sizeof(*udph);
481 		if (l4hlen > ulen) {
482 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
483 			err = EINVAL;
484 			goto fail;
485 		}
486 		bcopy(__DECONST(void *, &udph->uh_sport), &td->inet_lport,
487 		    sizeof(td->inet_lport));
488 		bcopy(__DECONST(void *, &udph->uh_dport), &td->inet_rport,
489 		    sizeof(td->inet_rport));
490 	} else {
491 		err = ENOTSUP;
492 		goto fail;
493 	}
494 
495 	td->inet_common.itd_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET;
496 	td->inet_common.itd_len = sizeof(*td);
497 	td->inet_common.itd_flags = IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
498 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND;
499 	td->inet_mask |= (IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
500 	    IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
501 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
502 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
503 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
504 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT);
505 	return 0;
506 fail:
507 	DTRACE_SKYWALK5(classify__failed, struct ip *, iph, size_t, pkt_len,
508 	    uint8_t, pkt->pkt_l2_len, struct ifnet_traffic_descriptor_inet *, td,
509 	    int, err);
510 	bzero(td, sizeof(*td));
511 	return err;
512 	#undef iph
513 	#undef ip6
514 	#undef tcph
515 	#undef udph
516 }
517 
518 int
nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char * ifname,struct __kern_packet * pkt,uint64_t * qset_id)519 nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char *ifname,
520     struct __kern_packet *pkt, uint64_t *qset_id)
521 {
522 	struct ifnet_traffic_descriptor_inet td;
523 	int err;
524 
525 	err = fill_inet_td(pkt, &td);
526 	if (err != 0) {
527 		return err;
528 	}
529 	return nxctl_inet_traffic_rule_find_qset_id(ifname, &td, qset_id);
530 }
531 
532 void
nxctl_traffic_rule_init(void)533 nxctl_traffic_rule_init(void)
534 {
535 	ASSERT(nxctl_traffic_rule_tag == NULL);
536 	nxctl_traffic_rule_tag =
537 	    kern_allocation_name_allocate(NXCTL_TRAFFIC_RULE_TAG, 0);
538 	ASSERT(nxctl_traffic_rule_tag != NULL);
539 
540 	ASSERT(inet_traffic_rule_type == NULL);
541 	inet_traffic_rule_type =
542 	    find_traffic_rule_type(IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET);
543 	ASSERT(inet_traffic_rule_type != NULL);
544 }
545 
546 void
nxctl_traffic_rule_fini(void)547 nxctl_traffic_rule_fini(void)
548 {
549 	if (nxctl_traffic_rule_tag != NULL) {
550 		kern_allocation_name_release(nxctl_traffic_rule_tag);
551 		nxctl_traffic_rule_tag = NULL;
552 	}
553 	inet_traffic_rule_type = NULL;
554 }
555 
556 static struct ifnet_ip_addr v6_zeros_addr = {0};
557 static boolean_t
inet_v6addr_cmp(struct ifnet_ip_addr * a1,struct ifnet_ip_addr * a2)558 inet_v6addr_cmp(struct ifnet_ip_addr *a1, struct ifnet_ip_addr *a2)
559 {
560 	return memcmp(a1, a2, sizeof(*a1)) == 0;
561 }
562 
563 SK_NO_INLINE_ATTRIBUTE
564 static struct nxctl_traffic_rule_storage *
nxctl_traffic_rule_storage_create(void)565 nxctl_traffic_rule_storage_create(void)
566 {
567 	struct nxctl_traffic_rule_storage *rs;
568 
569 	rs = sk_alloc_type(struct nxctl_traffic_rule_storage,
570 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
571 	SLIST_INIT(&rs->rs_list);
572 	rs->rs_count = 0;
573 	return rs;
574 }
575 
576 SK_NO_INLINE_ATTRIBUTE
577 static void
nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage * rs)578 nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage *rs)
579 {
580 	ASSERT(rs->rs_count == 0);
581 	ASSERT(SLIST_EMPTY(&rs->rs_list));
582 	sk_free_type(struct nxctl_traffic_rule_storage, rs);
583 }
584 
585 /*
586  * This is meant to be called during closure of the nxctl's fd.
587  * This will cleanup all rules linked to this nxctl. Rules that
588  * are marked persistent won't be added to the nxctl list.
589  */
590 void
nxctl_traffic_rule_clean(struct nxctl * nxctl)591 nxctl_traffic_rule_clean(struct nxctl *nxctl)
592 {
593 	struct nxctl_traffic_rule_storage *rs;
594 	struct nxctl_traffic_rule *ntr, *next;
595 	int err;
596 
597 	lck_mtx_lock(&nxctl->nxctl_lock);
598 	if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
599 		lck_mtx_unlock(&nxctl->nxctl_lock);
600 		return;
601 	}
602 	ntr = SLIST_FIRST(&rs->rs_list);
603 	SLIST_INIT(&rs->rs_list);
604 	rs->rs_count = 0;
605 	nxctl_traffic_rule_storage_destroy(rs);
606 	nxctl->nxctl_traffic_rule_storage = NULL;
607 	lck_mtx_unlock(&nxctl->nxctl_lock);
608 
609 	while (ntr != NULL) {
610 		next = SLIST_NEXT(ntr, ntr_storage_link);
611 		/*
612 		 * Clearing the flag to tell remove_traffic_rule() not to
613 		 * remove from the nxctl list again.
614 		 */
615 		ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
616 
617 		/* Passing NULL because we already hold a reference */
618 		err = remove_traffic_rule(nxctl, ntr->ntr_uuid, NULL);
619 		if (err == 0) {
620 			(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
621 		}
622 		release_traffic_rule(ntr);
623 		ntr = next;
624 	}
625 }
626 
627 SK_NO_INLINE_ATTRIBUTE
628 static void
add_traffic_rule_to_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)629 add_traffic_rule_to_nxctl(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
630 {
631 	struct nxctl_traffic_rule_storage *rs;
632 
633 	lck_mtx_lock(&nxctl->nxctl_lock);
634 	if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
635 		rs = nxctl_traffic_rule_storage_create();
636 		nxctl->nxctl_traffic_rule_storage = rs;
637 	}
638 	ntr->ntr_flags |= NTR_FLAG_ON_NXCTL_LIST;
639 	retain_traffic_rule(ntr);
640 	SLIST_INSERT_HEAD(&rs->rs_list, ntr, ntr_storage_link);
641 	rs->rs_count++;
642 	lck_mtx_unlock(&nxctl->nxctl_lock);
643 }
644 
645 SK_NO_INLINE_ATTRIBUTE
646 static void
remove_traffic_rule_from_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)647 remove_traffic_rule_from_nxctl(struct nxctl *nxctl,
648     struct nxctl_traffic_rule *ntr)
649 {
650 	struct nxctl_traffic_rule_storage *rs;
651 
652 	lck_mtx_lock(&nxctl->nxctl_lock);
653 	if ((ntr->ntr_flags & NTR_FLAG_ON_NXCTL_LIST) == 0) {
654 		lck_mtx_unlock(&nxctl->nxctl_lock);
655 		return;
656 	}
657 	rs = nxctl->nxctl_traffic_rule_storage;
658 	SLIST_REMOVE(&rs->rs_list, ntr, nxctl_traffic_rule, ntr_storage_link);
659 	rs->rs_count--;
660 	ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
661 	release_traffic_rule(ntr);
662 	if (rs->rs_count == 0) {
663 		nxctl_traffic_rule_storage_destroy(rs);
664 		nxctl->nxctl_traffic_rule_storage = NULL;
665 	}
666 	lck_mtx_unlock(&nxctl->nxctl_lock);
667 }
668 
669 static int
inet_traffic_rule_validate(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra)670 inet_traffic_rule_validate(struct nxctl_traffic_rule_type *type,
671     const char *ifname,
672     struct ifnet_traffic_descriptor_common *td,
673     struct ifnet_traffic_rule_action *ra)
674 {
675 #pragma unused(type)
676 	char buf[IFNAMSIZ];
677 	int unit, i;
678 	struct ifnet_traffic_descriptor_inet *tdi;
679 	uint8_t mask = 0, ipver, proto;
680 
681 	if (ifunit_extract(ifname, buf, sizeof(buf), &unit) < 0) {
682 		SK_ERR("invalid ifname: %s", ifname);
683 		return EINVAL;
684 	}
685 	if (td->itd_len != sizeof(*tdi)) {
686 		SK_ERR("invalid td len: expected %d, actual %d",
687 		    sizeof(*tdi), td->itd_len);
688 		return EINVAL;
689 	}
690 	if (td->itd_flags == 0 ||
691 	    (td->itd_flags &
692 	    ~(IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
693 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND)) != 0) {
694 		SK_ERR("invalid td flags: 0x%x", td->itd_flags);
695 		return EINVAL;
696 	}
697 	tdi = (struct ifnet_traffic_descriptor_inet *)td;
698 	for (i = 0; i < NINETRULEMASKS; i++) {
699 		if (tdi->inet_mask == nxctl_inet_traffic_rule_masks[i]) {
700 			mask = tdi->inet_mask;
701 			break;
702 		}
703 	}
704 	if (mask == 0) {
705 		SK_ERR("invalid inet mask: 0x%x", tdi->inet_mask);
706 		return EINVAL;
707 	}
708 	ipver = tdi->inet_ipver;
709 	if (ipver != IPVERSION && ipver != IPV6_VERSION) {
710 		SK_ERR("invalid inet ipver: 0x%x", ipver);
711 		return EINVAL;
712 	}
713 	proto = tdi->inet_proto;
714 	if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
715 		SK_ERR("invalid inet proto: %d", proto);
716 		return EINVAL;
717 	}
718 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0) {
719 		if (ipver == IPVERSION) {
720 			if (tdi->inet_laddr.iia_v4addr == INADDR_ANY) {
721 				SK_ERR("inet laddr v4 cannot be unspecified");
722 				return EINVAL;
723 			}
724 		} else {
725 			if (inet_v6addr_cmp(&tdi->inet_laddr, &v6_zeros_addr)) {
726 				SK_ERR("inet laddr v4 cannot be unspecified");
727 				return EINVAL;
728 			}
729 		}
730 	}
731 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0) {
732 		if (ipver == IPVERSION) {
733 			if (tdi->inet_raddr.iia_v4addr == INADDR_ANY) {
734 				SK_ERR("inet raddr v6 cannot be unspecified");
735 				return EINVAL;
736 			}
737 		} else {
738 			if (inet_v6addr_cmp(&tdi->inet_raddr, &v6_zeros_addr)) {
739 				SK_ERR("inet raddr v6 cannot be unspecified");
740 				return EINVAL;
741 			}
742 		}
743 	}
744 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0) {
745 		if (tdi->inet_lport == 0) {
746 			SK_ERR("inet lport cannot be unspecified");
747 			return EINVAL;
748 		}
749 	}
750 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0) {
751 		if (tdi->inet_rport == 0) {
752 			SK_ERR("inet rport cannot be unspecified");
753 			return EINVAL;
754 		}
755 	}
756 	if (ra->ra_len != sizeof(struct ifnet_traffic_rule_action_steer)) {
757 		SK_ERR("invalid ra len: expected %d, actual %d",
758 		    sizeof(struct ifnet_traffic_rule_action_steer), ra->ra_len);
759 		return EINVAL;
760 	}
761 	return 0;
762 }
763 
764 SK_NO_INLINE_ATTRIBUTE
765 static struct nxctl_traffic_rule_inet_storage *
inet_traffic_rule_storage_create(void)766 inet_traffic_rule_storage_create(void)
767 {
768 	struct nxctl_traffic_rule_inet_storage *rs;
769 
770 	rs = sk_alloc_type(struct nxctl_traffic_rule_inet_storage,
771 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
772 	SLIST_INIT(&rs->ris_if_list);
773 	rs->ris_count = 0;
774 	return rs;
775 }
776 
777 SK_NO_INLINE_ATTRIBUTE
778 static void
inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage * rs)779 inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage *rs)
780 {
781 	ASSERT(rs->ris_count == 0);
782 	ASSERT(SLIST_EMPTY(&rs->ris_if_list));
783 	sk_free_type(struct nxctl_traffic_rule_inet_storage, rs);
784 }
785 
786 SK_NO_INLINE_ATTRIBUTE
787 static struct nxctl_traffic_rule_inet_if *
inet_traffic_rule_if_create(const char * ifname)788 inet_traffic_rule_if_create(const char *ifname)
789 {
790 	struct nxctl_traffic_rule_inet_if *rif;
791 	int i;
792 
793 	rif = sk_alloc_type(struct nxctl_traffic_rule_inet_if,
794 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
795 	for (i = 0; i < NINETRULEMASKS; i++) {
796 		SLIST_INIT(&rif->rii_lists[i]);
797 	}
798 	strlcpy(rif->rii_ifname, ifname, sizeof(rif->rii_ifname));
799 	rif->rii_count = 0;
800 	return rif;
801 }
802 
803 SK_NO_INLINE_ATTRIBUTE
804 static void
inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if * rif)805 inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if *rif)
806 {
807 	int i;
808 
809 	for (i = 0; i < NINETRULEMASKS; i++) {
810 		ASSERT(SLIST_EMPTY(&rif->rii_lists[i]));
811 	}
812 	ASSERT(rif->rii_count == 0);
813 	sk_free_type(struct nxctl_traffic_rule_inet_if, rif);
814 }
815 
816 SK_NO_INLINE_ATTRIBUTE
817 static boolean_t
inet_traffic_rule_match(struct nxctl_traffic_rule_inet * ntri,const char * ifname,uint32_t flags,struct ifnet_traffic_descriptor_inet * tdi)818 inet_traffic_rule_match(struct nxctl_traffic_rule_inet *ntri, const char *ifname,
819     uint32_t flags, struct ifnet_traffic_descriptor_inet *tdi)
820 {
821 	struct nxctl_traffic_rule *ntr = (struct nxctl_traffic_rule *)ntri;
822 	struct ifnet_traffic_descriptor_inet *tdi0;
823 	uint8_t mask;
824 	boolean_t exact;
825 
826 	VERIFY(strlcmp(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname)) == 0);
827 	tdi0 = &ntri->ntri_td;
828 
829 	exact = ((flags & NTR_FIND_FLAG_EXACT) != 0);
830 	mask = tdi0->inet_mask & tdi->inet_mask;
831 	if (exact) {
832 		ASSERT(tdi0->inet_mask == tdi->inet_mask);
833 	}
834 	ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER) != 0);
835 	if (tdi0->inet_ipver != tdi->inet_ipver) {
836 		DTRACE_SKYWALK2(ipver__mismatch,
837 		    uint8_t, tdi0->inet_ipver, uint8_t, tdi->inet_ipver);
838 		return FALSE;
839 	}
840 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) != 0 &&
841 	    tdi0->inet_proto != tdi->inet_proto) {
842 		DTRACE_SKYWALK2(proto__mismatch,
843 		    uint8_t, tdi0->inet_proto, uint8_t, tdi->inet_proto);
844 		return FALSE;
845 	}
846 	if (tdi0->inet_ipver == IPVERSION) {
847 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
848 		    tdi0->inet_laddr.iia_v4addr != tdi->inet_laddr.iia_v4addr) {
849 			DTRACE_SKYWALK2(v4laddr__mismatch,
850 			    in_addr_t, tdi0->inet_laddr.iia_v4addr,
851 			    in_addr_t, tdi->inet_laddr.iia_v4addr);
852 			return FALSE;
853 		}
854 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
855 		    tdi0->inet_raddr.iia_v4addr != tdi->inet_raddr.iia_v4addr) {
856 			DTRACE_SKYWALK2(v4raddr__mismatch,
857 			    in_addr_t, tdi0->inet_raddr.iia_v4addr,
858 			    in_addr_t, tdi->inet_raddr.iia_v4addr);
859 			return FALSE;
860 		}
861 	} else {
862 		ASSERT(tdi0->inet_ipver == IPV6_VERSION);
863 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
864 		    !inet_v6addr_cmp(&tdi0->inet_laddr, &tdi->inet_laddr)) {
865 			DTRACE_SKYWALK2(v6laddr__mismatch,
866 			    struct in6_addr *, &tdi0->inet_laddr,
867 			    struct in6_addr *, &tdi->inet_laddr);
868 			return FALSE;
869 		}
870 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
871 		    !inet_v6addr_cmp(&tdi0->inet_raddr, &tdi->inet_raddr)) {
872 			DTRACE_SKYWALK2(v6raddr__mismatch,
873 			    struct in6_addr *, &tdi0->inet_raddr,
874 			    struct in6_addr *, &tdi->inet_raddr);
875 			return FALSE;
876 		}
877 	}
878 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0 &&
879 	    tdi0->inet_lport != tdi->inet_lport) {
880 		DTRACE_SKYWALK2(lport__mismatch,
881 		    uint8_t, tdi0->inet_lport, uint8_t, tdi->inet_lport);
882 		return FALSE;
883 	}
884 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0 &&
885 	    tdi0->inet_rport != tdi->inet_rport) {
886 		DTRACE_SKYWALK2(rport__mismatch,
887 		    uint8_t, tdi0->inet_rport, uint8_t, tdi->inet_rport);
888 		return FALSE;
889 	}
890 	return TRUE;
891 }
892 
893 static int
inet_traffic_rule_find(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,uint32_t flags,struct nxctl_traffic_rule ** ntrp)894 inet_traffic_rule_find(struct nxctl_traffic_rule_type *type, const char *ifname,
895     struct ifnet_traffic_descriptor_common *td, uint32_t flags,
896     struct nxctl_traffic_rule **ntrp)
897 {
898 	struct nxctl_traffic_rule_inet *ntri = NULL;
899 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
900 	struct nxctl_traffic_rule_inet_if *rif;
901 	struct ifnet_traffic_descriptor_inet *tdi =
902 	    (struct ifnet_traffic_descriptor_inet *)td;
903 	int i;
904 
905 	if (rs == NULL) {
906 		return ENOENT;
907 	}
908 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
909 		if (strlcmp(rif->rii_ifname, ifname, sizeof(rif->rii_ifname)) != 0) {
910 			continue;
911 		}
912 		for (i = 0; i < NINETRULEMASKS; i++) {
913 			if ((flags & NTR_FIND_FLAG_EXACT) != 0 &&
914 			    tdi->inet_mask != nxctl_inet_traffic_rule_masks[i]) {
915 				continue;
916 			}
917 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
918 				if (inet_traffic_rule_match(ntri, ifname, flags, tdi)) {
919 					*ntrp = (struct nxctl_traffic_rule *)ntri;
920 					return 0;
921 				}
922 			}
923 		}
924 	}
925 	return ENOENT;
926 }
927 
928 static int
inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type * type,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)929 inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type *type,
930     uuid_t uuid, struct nxctl_traffic_rule **ntrp)
931 {
932 	struct nxctl_traffic_rule_inet *ntri;
933 	struct nxctl_traffic_rule *ntr;
934 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
935 	struct nxctl_traffic_rule_inet_if *rif;
936 	int i;
937 
938 	if (rs == NULL) {
939 		return ENOENT;
940 	}
941 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
942 		for (i = 0; i < NINETRULEMASKS; i++) {
943 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
944 				ntr = &ntri->ntri_common;
945 				if (uuid_compare(ntr->ntr_uuid, uuid) == 0) {
946 					*ntrp = ntr;
947 					return 0;
948 				}
949 			}
950 		}
951 	}
952 	return ENOENT;
953 }
954 
955 static void
inet_update_ifnet_traffic_rule_count(const char * ifname,uint32_t count)956 inet_update_ifnet_traffic_rule_count(const char *ifname, uint32_t count)
957 {
958 	struct ifnet *ifp;
959 
960 	ifp = ifunit_ref(ifname);
961 	if (ifp == NULL) {
962 		DTRACE_SKYWALK1(ifname__not__found, char *, ifname);
963 		return;
964 	}
965 	ifnet_update_traffic_rule_count(ifp, count);
966 	ifnet_decr_iorefcnt(ifp);
967 }
968 
969 static void
inet_traffic_rule_link(struct nxctl_traffic_rule * ntr)970 inet_traffic_rule_link(struct nxctl_traffic_rule *ntr)
971 {
972 	struct nxctl_traffic_rule_type *type = ntr->ntr_type;
973 	struct nxctl_traffic_rule_inet_storage *rs;
974 	struct nxctl_traffic_rule_inet_if *rif;
975 	struct nxctl_traffic_rule_inet *ntri =
976 	    (struct nxctl_traffic_rule_inet *)ntr;
977 	struct nxctl_traffic_rule_inet_head *list = NULL;
978 	int i;
979 	char *__null_terminated ntr_ifname = NULL;
980 	char *__null_terminated rii_ifname = NULL;
981 
982 	if ((rs = type->ntrt_storage) == NULL) {
983 		rs = inet_traffic_rule_storage_create();
984 		type->ntrt_storage = rs;
985 	}
986 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
987 		if (strbufcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
988 			break;
989 		}
990 	}
991 	if (rif == NULL) {
992 		ntr_ifname = __unsafe_null_terminated_from_indexable(ntr->ntr_ifname);
993 		rif = inet_traffic_rule_if_create(ntr_ifname);
994 		SLIST_INSERT_HEAD(&rs->ris_if_list, rif, rii_link);
995 	}
996 	for (i = 0; i < NINETRULEMASKS; i++) {
997 		if (ntri->ntri_td.inet_mask ==
998 		    nxctl_inet_traffic_rule_masks[i]) {
999 			list = &rif->rii_lists[i];
1000 			break;
1001 		}
1002 	}
1003 	retain_traffic_rule(ntr);
1004 	ASSERT(list != NULL);
1005 	SLIST_INSERT_HEAD(list, ntri, ntri_storage_link);
1006 	/* per-interface count */
1007 	rif->rii_count++;
1008 	rii_ifname = __unsafe_null_terminated_from_indexable(rif->rii_ifname);
1009 	inet_update_ifnet_traffic_rule_count(rii_ifname, rif->rii_count);
1010 
1011 	/* global count */
1012 	rs->ris_count++;
1013 }
1014 
1015 static void
inet_traffic_rule_unlink(struct nxctl_traffic_rule * ntr)1016 inet_traffic_rule_unlink(struct nxctl_traffic_rule *ntr)
1017 {
1018 	struct nxctl_traffic_rule_inet_storage *rs;
1019 	struct nxctl_traffic_rule_inet_if *rif;
1020 	struct nxctl_traffic_rule_inet *ntri =
1021 	    (struct nxctl_traffic_rule_inet *)ntr;
1022 	struct nxctl_traffic_rule_inet_head *list = NULL;
1023 	struct nxctl_traffic_rule_type *type;
1024 	int i;
1025 	char *__null_terminated rii_ifname = NULL;
1026 
1027 	type = ntr->ntr_type;
1028 	rs = type->ntrt_storage;
1029 	ASSERT(rs != NULL);
1030 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1031 		if (strbufcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
1032 			break;
1033 		}
1034 	}
1035 	ASSERT(rif != NULL);
1036 	for (i = 0; i < NINETRULEMASKS; i++) {
1037 		if (ntri->ntri_td.inet_mask ==
1038 		    nxctl_inet_traffic_rule_masks[i]) {
1039 			list = &rif->rii_lists[i];
1040 			break;
1041 		}
1042 	}
1043 	ASSERT(list != NULL);
1044 	SLIST_REMOVE(list, ntri, nxctl_traffic_rule_inet, ntri_storage_link);
1045 	rif->rii_count--;
1046 	rii_ifname = __unsafe_null_terminated_from_indexable(rif->rii_ifname);
1047 	inet_update_ifnet_traffic_rule_count(rii_ifname, rif->rii_count);
1048 
1049 	rs->ris_count--;
1050 	release_traffic_rule(ntr);
1051 
1052 	if (rif->rii_count == 0) {
1053 		SLIST_REMOVE(&rs->ris_if_list, rif, nxctl_traffic_rule_inet_if, rii_link);
1054 		inet_traffic_rule_if_destroy(rif);
1055 	}
1056 	if (rs->ris_count == 0) {
1057 		type->ntrt_storage = NULL;
1058 		inet_traffic_rule_storage_destroy(rs);
1059 	}
1060 }
1061 
1062 /*
1063  * XXX
1064  * This may need additional changes to ensure safety against detach/attach.
1065  * This is not an issue for the first consumer of llink interfaces, cellular,
1066  * which does not detach.
1067  */
1068 static int
inet_traffic_rule_notify(struct nxctl_traffic_rule * ntr,uint32_t flags)1069 inet_traffic_rule_notify(struct nxctl_traffic_rule *ntr, uint32_t flags)
1070 {
1071 	struct ifnet *ifp;
1072 	struct nx_netif *nif;
1073 	struct netif_qset *__single qset = NULL;
1074 	struct nxctl_traffic_rule_inet *ntri;
1075 	int err = 0;
1076 	char *__null_terminated ntr_ifname = NULL;
1077 
1078 	ntr_ifname = __unsafe_null_terminated_from_indexable(ntr->ntr_ifname);
1079 	ifp = ifunit_ref(ntr_ifname);
1080 	if (ifp == NULL) {
1081 		DTRACE_SKYWALK1(ifname__not__found, char *, ntr->ntr_ifname);
1082 		err = ENXIO;
1083 		goto done;
1084 	}
1085 	nif = NA(ifp)->nifna_netif;
1086 	if (!NX_LLINK_PROV(nif->nif_nx)) {
1087 		DTRACE_SKYWALK1(llink__not__enabled, struct ifnet *, ifp);
1088 		err = ENOTSUP;
1089 		goto done;
1090 	}
1091 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
1092 	qset = nx_netif_find_qset(nif, ntri->ntri_ra.ras_qset_id);
1093 	if (qset == NULL || (qset->nqs_flags & NETIF_QSET_FLAG_EXT_INITED) == 0) {
1094 		DTRACE_SKYWALK1(qset__not__initialized, struct netif_qset *, qset);
1095 		err = ENXIO;
1096 		goto done;
1097 	}
1098 	err = nx_netif_notify_steering_info(nif, qset,
1099 	    (struct ifnet_traffic_descriptor_common *)&ntri->ntri_td,
1100 	    ((flags & NTR_NOTIFY_FLAG_ADD) != 0));
1101 done:
1102 	if (qset != NULL) {
1103 		nx_netif_qset_release(&qset);
1104 	}
1105 	if (ifp != NULL) {
1106 		ifnet_decr_iorefcnt(ifp);
1107 	}
1108 	return err;
1109 }
1110 
1111 static int
inet_traffic_rule_create(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1112 inet_traffic_rule_create(struct nxctl_traffic_rule_type *type,
1113     const char *ifname, struct ifnet_traffic_descriptor_common *td,
1114     struct ifnet_traffic_rule_action *ra, uint32_t flags,
1115     struct nxctl_traffic_rule **ntrp)
1116 {
1117 	struct nxctl_traffic_rule_inet *ntri;
1118 	struct nxctl_traffic_rule *ntr;
1119 	struct ifnet_traffic_descriptor_inet *tdi;
1120 	struct ifnet_traffic_rule_action_steer *ras;
1121 
1122 	ntri = sk_alloc_type(struct nxctl_traffic_rule_inet,
1123 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
1124 	ntr = &ntri->ntri_common;
1125 
1126 	ntr->ntr_type = type;
1127 	ntr->ntr_flags = flags;
1128 	uuid_generate(ntr->ntr_uuid);
1129 	os_ref_init(&ntr->ntr_refcnt, NULL);
1130 
1131 	strlcpy(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname));
1132 	proc_selfname(ntr->ntr_procname, sizeof(ntr->ntr_procname));
1133 
1134 	tdi = __container_of(td, struct ifnet_traffic_descriptor_inet, inet_common);
1135 	ras = __container_of(ra, struct ifnet_traffic_rule_action_steer, ras_common);
1136 	bcopy(tdi, &ntri->ntri_td, sizeof(ntri->ntri_td));
1137 	bcopy(ras, &ntri->ntri_ra, sizeof(ntri->ntri_ra));
1138 
1139 	*ntrp = ntr;
1140 	return 0;
1141 }
1142 
1143 static void
inet_traffic_rule_destroy(struct nxctl_traffic_rule * ntr)1144 inet_traffic_rule_destroy(struct nxctl_traffic_rule *ntr)
1145 {
1146 	struct nxctl_traffic_rule_inet *ntri;
1147 
1148 	ASSERT(os_ref_get_count(&ntr->ntr_refcnt) == 0);
1149 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
1150 	sk_free_type(struct nxctl_traffic_rule_inet, ntri);
1151 }
1152 
1153 static void
convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet * ntri,struct nxctl_traffic_rule_inet_iocinfo * info)1154 convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet *ntri,
1155     struct nxctl_traffic_rule_inet_iocinfo *info)
1156 {
1157 	struct nxctl_traffic_rule *ntr;
1158 	struct nxctl_traffic_rule_generic_iocinfo *ginfo;
1159 
1160 	bzero(info, sizeof(*info));
1161 	ntr = &ntri->ntri_common;
1162 	ginfo = &info->tri_common;
1163 	_CASSERT(sizeof(ntr->ntr_procname) == sizeof(ginfo->trg_procname));
1164 	_CASSERT(sizeof(ntr->ntr_ifname) == sizeof(ginfo->trg_ifname));
1165 	uuid_copy(ginfo->trg_uuid, ntr->ntr_uuid);
1166 	strbufcpy(ginfo->trg_procname, ntr->ntr_procname);
1167 	strbufcpy(ginfo->trg_ifname, ntr->ntr_ifname);
1168 	bcopy(&ntri->ntri_td, &info->tri_td, sizeof(info->tri_td));
1169 	bcopy(&ntri->ntri_ra, &info->tri_ra, sizeof(info->tri_ra));
1170 }
1171 
1172 static int
inet_traffic_rule_get_all(struct nxctl_traffic_rule_type * type,uint32_t size,uint32_t * count,user_addr_t uaddr)1173 inet_traffic_rule_get_all(struct nxctl_traffic_rule_type *type, uint32_t size,
1174     uint32_t *count, user_addr_t uaddr)
1175 {
1176 	struct nxctl_traffic_rule_inet *ntri = NULL;
1177 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
1178 	struct nxctl_traffic_rule_inet_if *rif;
1179 	struct nxctl_traffic_rule_inet_iocinfo info;
1180 	int i, err;
1181 
1182 	if (size != sizeof(info)) {
1183 		SK_ERR("size: actual %d, expected %d", size, sizeof(info));
1184 		return EINVAL;
1185 	}
1186 	if (rs == NULL) {
1187 		*count = 0;
1188 		return 0;
1189 	}
1190 	if (*count < rs->ris_count) {
1191 		SK_ERR("count: given %d, require: %d", *count, rs->ris_count);
1192 		return ENOBUFS;
1193 	}
1194 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1195 		for (i = 0; i < NINETRULEMASKS; i++) {
1196 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
1197 				convert_ntri_to_iocinfo(ntri, &info);
1198 				err = copyout(&info, uaddr, sizeof(info));
1199 				if (err != 0) {
1200 					SK_ERR("copyout failed: %d", err);
1201 					return err;
1202 				}
1203 				uaddr += sizeof(info);
1204 			}
1205 		}
1206 	}
1207 	*count = rs->ris_count;
1208 	return 0;
1209 }
1210 
1211 SK_NO_INLINE_ATTRIBUTE
1212 static void
retain_traffic_rule(struct nxctl_traffic_rule * ntr)1213 retain_traffic_rule(struct nxctl_traffic_rule *ntr)
1214 {
1215 #if (DEVELOPMENT || DEBUG)
1216 	os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1217 	DTRACE_SKYWALK2(ntr__retain, struct nxctl_traffic_rule *, ntr,
1218 	    os_ref_count_t, count);
1219 #endif
1220 	os_ref_retain(&ntr->ntr_refcnt);
1221 }
1222 
1223 SK_NO_INLINE_ATTRIBUTE
1224 static void
release_traffic_rule(struct nxctl_traffic_rule * ntr)1225 release_traffic_rule(struct nxctl_traffic_rule *ntr)
1226 {
1227 #if (DEVELOPMENT || DEBUG)
1228 	os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1229 	DTRACE_SKYWALK2(ntr__release, struct nxctl_traffic_rule *, ntr,
1230 	    os_ref_count_t, count);
1231 #endif
1232 	if (os_ref_release(&ntr->ntr_refcnt) == 0) {
1233 		ntr->ntr_type->ntrt_destroy(ntr);
1234 	}
1235 }
1236 
1237 SK_NO_INLINE_ATTRIBUTE
1238 static int
notify_traffic_rule(struct nxctl_traffic_rule * ntr,uint32_t flags)1239 notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags)
1240 {
1241 	return ntr->ntr_type->ntrt_notify(ntr, flags);
1242 }
1243 
1244 static void
link_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1245 link_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1246 {
1247 	/*
1248 	 * The persist flag means: do not clean up rule upon nxctl fd close.
1249 	 * This means we only add the rule to the nxctl list if persist
1250 	 * is not set.
1251 	 */
1252 	if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1253 		add_traffic_rule_to_nxctl(nxctl, ntr);
1254 	}
1255 	ntr->ntr_type->ntrt_link(ntr);
1256 }
1257 
1258 static void
unlink_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1259 unlink_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1260 {
1261 	if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1262 		remove_traffic_rule_from_nxctl(nxctl, ntr);
1263 	}
1264 	ntr->ntr_type->ntrt_unlink(ntr);
1265 }
1266 
1267 static int
find_traffic_rule_by_uuid(uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1268 find_traffic_rule_by_uuid(uuid_t uuid, struct nxctl_traffic_rule **ntrp)
1269 {
1270 	int i, err;
1271 	struct nxctl_traffic_rule_type *ntrt;
1272 	struct nxctl_traffic_rule *__single ntr = NULL;
1273 
1274 	for (i = 0; i < NRULETYPES; i++) {
1275 		ntrt = &nxctl_rule_types[i];
1276 		err = ntrt->ntrt_find_by_uuid(ntrt, uuid, &ntr);
1277 		if (err == 0) {
1278 			ASSERT(ntr != NULL);
1279 			*ntrp = ntr;
1280 			return 0;
1281 		}
1282 	}
1283 	return ENOENT;
1284 }
1285 
1286 static struct nxctl_traffic_rule_type *
find_traffic_rule_type(uint8_t type)1287 find_traffic_rule_type(uint8_t type)
1288 {
1289 	int i;
1290 	struct nxctl_traffic_rule_type *ntrt;
1291 
1292 	for (i = 0; i < NRULETYPES; i++) {
1293 		ntrt = &nxctl_rule_types[i];
1294 		if (ntrt->ntrt_type == type) {
1295 			return ntrt;
1296 		}
1297 	}
1298 	return NULL;
1299 }
1300 
1301 SK_NO_INLINE_ATTRIBUTE
1302 static int
add_traffic_rule(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1303 add_traffic_rule(struct nxctl *nxctl, const char *ifname,
1304     struct ifnet_traffic_descriptor_common *td,
1305     struct ifnet_traffic_rule_action *ra,
1306     uint32_t flags,
1307     struct nxctl_traffic_rule **ntrp)
1308 {
1309 	struct nxctl_traffic_rule_type *type = NULL;
1310 	struct nxctl_traffic_rule *__single ntr = NULL;
1311 	int err;
1312 
1313 	NXTR_WLOCK();
1314 	type = find_traffic_rule_type(td->itd_type);
1315 	if (type == NULL) {
1316 		SK_ERR("rule type %x not found", td->itd_type);
1317 		err = EINVAL;
1318 		goto fail;
1319 	}
1320 	err = type->ntrt_validate(type, ifname, td, ra);
1321 	if (err != 0) {
1322 		SK_ERR("rule validate failed: %d", err);
1323 		goto fail;
1324 	}
1325 	err = type->ntrt_find(type, ifname, td, NTR_FIND_FLAG_EXACT, &ntr);
1326 	if (err == 0) {
1327 		SK_ERR("rule already exists");
1328 		ASSERT(ntr != NULL);
1329 		err = EEXIST;
1330 		goto fail;
1331 	} else if (err != ENOENT) {
1332 		SK_ERR("rule find failed: %d", err);
1333 		goto fail;
1334 	}
1335 	err = type->ntrt_create(type, ifname, td, ra, flags, &ntr);
1336 	if (err != 0) {
1337 		SK_ERR("rule create failed: %d", err);
1338 		goto fail;
1339 	}
1340 	link_traffic_rule(nxctl, ntr);
1341 	if (ntrp != NULL) {
1342 		retain_traffic_rule(ntr);
1343 		*ntrp = ntr;
1344 	}
1345 	NXTR_WUNLOCK();
1346 	return 0;
1347 fail:
1348 	NXTR_WUNLOCK();
1349 	return err;
1350 }
1351 
1352 
1353 SK_NO_INLINE_ATTRIBUTE
1354 static int
remove_traffic_rule(struct nxctl * nxctl,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1355 remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
1356     struct nxctl_traffic_rule **ntrp)
1357 {
1358 	struct nxctl_traffic_rule *__single ntr;
1359 	int err;
1360 
1361 	NXTR_WLOCK();
1362 	err = find_traffic_rule_by_uuid(uuid, &ntr);
1363 	if (err != 0) {
1364 		SK_ERR("traffic rule not found");
1365 		NXTR_WUNLOCK();
1366 		return err;
1367 	}
1368 	if (ntrp != NULL) {
1369 		retain_traffic_rule(ntr);
1370 		*ntrp = ntr;
1371 	}
1372 	unlink_traffic_rule(nxctl, ntr);
1373 	/* release initial reference */
1374 	release_traffic_rule(ntr);
1375 	NXTR_WUNLOCK();
1376 	return 0;
1377 }
1378 
1379 static uint32_t
convert_traffic_rule_ioc_flags(uint32_t flags)1380 convert_traffic_rule_ioc_flags(uint32_t flags)
1381 {
1382 	uint32_t f = 0;
1383 
1384 	if ((flags & NXIOC_ADD_TRAFFIC_RULE_FLAG_PERSIST) != 0) {
1385 		f |= NTR_FLAG_PERSIST;
1386 	}
1387 	return f;
1388 }
1389 
1390 SK_NO_INLINE_ATTRIBUTE
1391 static int
add_traffic_rule_generic(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,uuid_t * uuid)1392 add_traffic_rule_generic(struct nxctl *nxctl, const char *ifname,
1393     struct ifnet_traffic_descriptor_common *td,
1394     struct ifnet_traffic_rule_action *ra, uint32_t flags, uuid_t *uuid)
1395 {
1396 	struct nxctl_traffic_rule *__single ntr;
1397 	int err;
1398 
1399 	err = add_traffic_rule(nxctl, ifname, td, ra, flags, &ntr);
1400 	if (err != 0) {
1401 		return err;
1402 	}
1403 	(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_ADD);
1404 	uuid_copy(*uuid, ntr->ntr_uuid);
1405 	release_traffic_rule(ntr);
1406 	return 0;
1407 }
1408 
1409 int
nxioctl_add_traffic_rule_inet(struct nxctl * nxctl,caddr_t data,proc_t procp)1410 nxioctl_add_traffic_rule_inet(struct nxctl *nxctl, caddr_t data, proc_t procp)
1411 {
1412 #pragma unused(procp)
1413 	struct nxctl_add_traffic_rule_inet_iocargs *args =
1414 	    (struct nxctl_add_traffic_rule_inet_iocargs *)(void *)data;
1415 	char *__null_terminated atri_ifname = NULL;
1416 
1417 	atri_ifname = __unsafe_null_terminated_from_indexable(args->atri_ifname);
1418 
1419 	return add_traffic_rule_generic(nxctl, atri_ifname,
1420 	           (struct ifnet_traffic_descriptor_common *)&args->atri_td,
1421 	           (struct ifnet_traffic_rule_action *)&args->atri_ra,
1422 	           convert_traffic_rule_ioc_flags(args->atri_flags),
1423 	           &args->atri_uuid);
1424 }
1425 
1426 int
nxioctl_remove_traffic_rule(struct nxctl * nxctl,caddr_t data,proc_t procp)1427 nxioctl_remove_traffic_rule(struct nxctl *nxctl, caddr_t data, proc_t procp)
1428 {
1429 #pragma unused(procp)
1430 	struct nxctl_remove_traffic_rule_iocargs *args =
1431 	    (struct nxctl_remove_traffic_rule_iocargs *)(void *)data;
1432 	struct nxctl_traffic_rule *__single ntr;
1433 	int err;
1434 
1435 	err = remove_traffic_rule(nxctl, args->rtr_uuid, &ntr);
1436 	if (err != 0) {
1437 		return err;
1438 	}
1439 	(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
1440 	release_traffic_rule(ntr);
1441 	return 0;
1442 }
1443 
1444 int
nxioctl_get_traffic_rules(struct nxctl * nxctl,caddr_t data,proc_t procp)1445 nxioctl_get_traffic_rules(struct nxctl *nxctl, caddr_t data, proc_t procp)
1446 {
1447 #pragma unused(nxctl)
1448 	struct nxctl_get_traffic_rules_iocargs *args =
1449 	    (struct nxctl_get_traffic_rules_iocargs *)(void *)data;
1450 	struct nxctl_traffic_rule_type *type;
1451 	user_addr_t uaddr;
1452 	int err;
1453 
1454 	NXTR_RLOCK();
1455 	type = find_traffic_rule_type(args->gtr_type);
1456 	if (type == NULL) {
1457 		SK_ERR("rule type %x not found", args->gtr_type);
1458 		err = EINVAL;
1459 		goto fail;
1460 	}
1461 	uaddr = proc_is64bit(procp) ? args->gtr_buf64 :
1462 	    CAST_USER_ADDR_T(args->gtr_buf);
1463 	err = type->ntrt_get_all(type, args->gtr_size, &args->gtr_count, uaddr);
1464 	if (err != 0) {
1465 		goto fail;
1466 	}
1467 	NXTR_RUNLOCK();
1468 	return 0;
1469 fail:
1470 	NXTR_RUNLOCK();
1471 	return err;
1472 }
1473