xref: /xnu-10063.101.15/bsd/skywalk/nexus/nexus_traffic_rule.c (revision 94d3b452840153a99b38a3a9659680b2a006908e)
1 /*
2  * Copyright (c) 2022 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 #include <skywalk/os_skywalk_private.h>
29 #include <skywalk/nexus/netif/nx_netif.h>
30 #include <netinet/ip.h>
31 #include <netinet/ip6.h>
32 #include <netinet/tcp.h>
33 #include <netinet/udp.h>
34 #include <sys/sdt.h>
35 
36 /*
37  * Implementation of nexus traffic rules APIs.
38  */
39 
40 struct nxctl_traffic_rule_type;
41 struct nxctl_traffic_rule;
42 
43 /*
44  * These callbacks need to be implemented for each rule type.
45  */
46 
47 /* Validate user provided parameters. */
48 typedef int (nxctl_traffic_rule_validate_cb_t)(
49 	struct nxctl_traffic_rule_type *type,
50 	const char *ifname,
51 	struct ifnet_traffic_descriptor_common *td,
52 	struct ifnet_traffic_rule_action *ra);
53 /*
54  * Each rule type has its own global structure for storing rules.
55  * These callbacks access this global structure.
56  */
57 #define NTR_FIND_FLAG_EXACT 0x0001
58 typedef int (nxctl_traffic_rule_find_cb_t)(
59 	struct nxctl_traffic_rule_type *type,
60 	const char *ifname,
61 	struct ifnet_traffic_descriptor_common *td,
62 	uint32_t flags,
63 	struct nxctl_traffic_rule **ntrp);
64 
65 typedef int (nxctl_traffic_rule_find_by_uuid_cb_t)(
66 	struct nxctl_traffic_rule_type *type,
67 	uuid_t uuid,
68 	struct nxctl_traffic_rule **ntrp);
69 
70 typedef void (nxctl_traffic_rule_link_cb_t)(
71 	struct nxctl_traffic_rule *ntr);
72 
73 typedef void (nxctl_traffic_rule_unlink_cb_t)(
74 	struct nxctl_traffic_rule *ntr);
75 
76 /*
77  * Notifies lower layers of the addition/removal of a rule.
78  * This is called outside of nxctl_traffic_rule_lock to avoid potential
79  * locking issues.
80  */
81 #define NTR_NOTIFY_FLAG_ADD 0x0001
82 #define NTR_NOTIFY_FLAG_REMOVE 0x0002
83 typedef int (nxctl_traffic_rule_notify_cb_t)(
84 	struct nxctl_traffic_rule *ntr,
85 	uint32_t flags);
86 
87 /*
88  * Create/Destroy callbacks for a rule type.
89  */
90 typedef int (nxctl_traffic_rule_create_cb_t)(
91 	struct nxctl_traffic_rule_type *type,
92 	const char *ifname,
93 	struct ifnet_traffic_descriptor_common *td,
94 	struct ifnet_traffic_rule_action *ra,
95 	uint32_t flags,
96 	struct nxctl_traffic_rule **ntrp);
97 
98 typedef void (nxctl_traffic_rule_destroy_cb_t)(
99 	struct nxctl_traffic_rule *ntr);
100 
101 /*
102  * This is used for copying all rules for a type (including generic
103  * and type-specific info) to userspace.
104  */
105 typedef int (nxctl_traffic_rule_get_all_cb_t)(
106 	struct nxctl_traffic_rule_type *type,
107 	uint32_t size,
108 	uint32_t *count,
109 	user_addr_t uaddr);
110 
111 struct nxctl_traffic_rule_type {
112 	uint8_t ntrt_type;
113 	nxctl_traffic_rule_validate_cb_t *ntrt_validate;
114 	nxctl_traffic_rule_find_cb_t *ntrt_find;
115 	nxctl_traffic_rule_find_by_uuid_cb_t *ntrt_find_by_uuid;
116 	nxctl_traffic_rule_link_cb_t *ntrt_link;
117 	nxctl_traffic_rule_unlink_cb_t *ntrt_unlink;
118 	nxctl_traffic_rule_notify_cb_t *ntrt_notify;
119 	nxctl_traffic_rule_create_cb_t *ntrt_create;
120 	nxctl_traffic_rule_destroy_cb_t *ntrt_destroy;
121 	nxctl_traffic_rule_get_all_cb_t *ntrt_get_all;
122 	void *ntrt_storage;
123 };
124 
125 static nxctl_traffic_rule_validate_cb_t inet_traffic_rule_validate;
126 static nxctl_traffic_rule_find_cb_t inet_traffic_rule_find;
127 static nxctl_traffic_rule_find_by_uuid_cb_t inet_traffic_rule_find_by_uuid;
128 static nxctl_traffic_rule_link_cb_t inet_traffic_rule_link;
129 static nxctl_traffic_rule_unlink_cb_t inet_traffic_rule_unlink;
130 static nxctl_traffic_rule_notify_cb_t inet_traffic_rule_notify;
131 static nxctl_traffic_rule_create_cb_t inet_traffic_rule_create;
132 static nxctl_traffic_rule_destroy_cb_t inet_traffic_rule_destroy;
133 static nxctl_traffic_rule_get_all_cb_t inet_traffic_rule_get_all;
134 
135 static struct nxctl_traffic_rule_type nxctl_rule_types[] = {
136 	{
137 		.ntrt_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET,
138 		.ntrt_validate = inet_traffic_rule_validate,
139 		.ntrt_find = inet_traffic_rule_find,
140 		.ntrt_find_by_uuid = inet_traffic_rule_find_by_uuid,
141 		.ntrt_link = inet_traffic_rule_link,
142 		.ntrt_unlink = inet_traffic_rule_unlink,
143 		.ntrt_notify = inet_traffic_rule_notify,
144 		.ntrt_create = inet_traffic_rule_create,
145 		.ntrt_destroy = inet_traffic_rule_destroy,
146 		.ntrt_get_all = inet_traffic_rule_get_all,
147 	},
148 };
149 #define NRULETYPES \
150     (sizeof(nxctl_rule_types)/sizeof(struct nxctl_traffic_rule_type))
151 
152 /*
153  * Generic traffic rule.
154  * Contains fields common to all traffic rules.
155  */
156 #define NTR_FLAG_PERSIST 0x0001
157 #define NTR_FLAG_ON_NXCTL_LIST 0x0002
158 struct nxctl_traffic_rule {
159 	struct nxctl_traffic_rule_type *ntr_type;
160 	uint32_t ntr_flags;
161 	os_refcnt_t ntr_refcnt;
162 	uuid_t ntr_uuid;
163 	char ntr_procname[NTR_PROCNAME_SZ];
164 	char ntr_ifname[IFNAMSIZ];
165 	SLIST_ENTRY(nxctl_traffic_rule) ntr_storage_link;
166 };
167 
168 /*
169  * Inet-specific traffic rule.
170  */
171 struct nxctl_traffic_rule_inet {
172 	struct nxctl_traffic_rule ntri_common;
173 	SLIST_ENTRY(nxctl_traffic_rule_inet) ntri_storage_link;
174 	struct ifnet_traffic_descriptor_inet ntri_td;
175 	struct ifnet_traffic_rule_action_steer ntri_ra;
176 };
177 
178 /*
179  * Currently supported tuple types.
180  */
181 #define ITDBIT(set, bit) (((set) != 0) ? (bit) : 0)
182 #define ITRM(proto, laddr, raddr, lport, rport) \
183 	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER | \
184 	ITDBIT(proto, IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) | \
185 	ITDBIT(laddr, IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) | \
186 	ITDBIT(raddr, IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) | \
187 	ITDBIT(lport, IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) | \
188 	ITDBIT(rport, IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT))
189 
190 static uint8_t nxctl_inet_traffic_rule_masks[] = {
191 	ITRM(1, 1, 1, 1, 1),
192 	ITRM(1, 1, 1, 1, 0),
193 	ITRM(1, 1, 1, 0, 1),
194 	ITRM(1, 1, 1, 0, 0),
195 	ITRM(1, 1, 0, 1, 1),
196 	ITRM(1, 1, 0, 1, 0),
197 	ITRM(1, 1, 0, 0, 1),
198 	ITRM(1, 1, 0, 0, 0),
199 	ITRM(1, 0, 1, 1, 1),
200 	ITRM(1, 0, 1, 1, 0),
201 	ITRM(1, 0, 1, 0, 1),
202 	ITRM(1, 0, 1, 0, 0),
203 	ITRM(1, 0, 0, 1, 1),
204 	ITRM(1, 0, 0, 1, 0),
205 	ITRM(1, 0, 0, 0, 1),
206 	// ITRM(1, 0, 0, 0, 0), addr or port is required
207 	ITRM(0, 1, 1, 1, 1),
208 	ITRM(0, 1, 1, 1, 0),
209 	ITRM(0, 1, 1, 0, 1),
210 	ITRM(0, 1, 1, 0, 0),
211 	ITRM(0, 1, 0, 1, 1),
212 	ITRM(0, 1, 0, 1, 0),
213 	ITRM(0, 1, 0, 0, 1),
214 	ITRM(0, 1, 0, 0, 0),
215 	ITRM(0, 0, 1, 1, 1),
216 	ITRM(0, 0, 1, 1, 0),
217 	ITRM(0, 0, 1, 0, 1),
218 	ITRM(0, 0, 1, 0, 0),
219 	ITRM(0, 0, 0, 1, 1),
220 	ITRM(0, 0, 0, 1, 0),
221 	ITRM(0, 0, 0, 0, 1),
222 	// ITRM(0, 0, 0, 0, 0),
223 };
224 #define NINETRULEMASKS \
225     (sizeof(nxctl_inet_traffic_rule_masks)/sizeof(uint8_t))
226 
227 /* Per-interface lists of traffic rules */
228 SLIST_HEAD(nxctl_traffic_rule_inet_head, nxctl_traffic_rule_inet);
229 struct nxctl_traffic_rule_inet_if {
230 	char rii_ifname[IFNAMSIZ];
231 	struct nxctl_traffic_rule_inet_head rii_lists[NINETRULEMASKS];
232 	uint32_t rii_count;
233 	SLIST_ENTRY(nxctl_traffic_rule_inet_if) rii_link;
234 };
235 
236 /* List of per-interface lists */
237 SLIST_HEAD(nxctl_traffic_rule_inet_if_head, nxctl_traffic_rule_inet_if);
238 struct nxctl_traffic_rule_inet_storage {
239 	struct nxctl_traffic_rule_inet_if_head ris_if_list;
240 	uint32_t ris_count;
241 };
242 
243 /* Per-fd list kept at the nxctl */
244 SLIST_HEAD(nxctl_traffic_rule_head, nxctl_traffic_rule);
245 struct nxctl_traffic_rule_storage {
246 	struct nxctl_traffic_rule_head rs_list;
247 	uint32_t rs_count;
248 };
249 
250 static LCK_RW_DECLARE_ATTR(nxctl_traffic_rule_lock, &sk_lock_group, &sk_lock_attr);
251 #define NXTR_WLOCK() \
252     lck_rw_lock_exclusive(&nxctl_traffic_rule_lock)
253 #define NXTR_WUNLOCK() \
254     lck_rw_unlock_exclusive(&nxctl_traffic_rule_lock)
255 #define NXTR_RLOCK() \
256     lck_rw_lock_shared(&nxctl_traffic_rule_lock)
257 #define NXTR_RUNLOCK() \
258     lck_rw_unlock_shared(&nxctl_traffic_rule_lock)
259 
260 static struct nxctl_traffic_rule_type *find_traffic_rule_type(uint8_t type);
261 static void retain_traffic_rule(struct nxctl_traffic_rule *ntr);
262 static void release_traffic_rule(struct nxctl_traffic_rule *ntr);
263 static int remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
264     struct nxctl_traffic_rule **ntrp);
265 static boolean_t inet_v6addr_cmp(struct ifnet_ip_addr *a1,
266     struct ifnet_ip_addr *a2);
267 static int notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags);
268 
269 #define NXCTL_TRAFFIC_RULE_TAG "com.apple.skywalk.nexus.traffic_rule"
270 static kern_allocation_name_t nxctl_traffic_rule_tag;
271 static struct nxctl_traffic_rule_type *inet_traffic_rule_type = NULL;
272 
273 /*
274  * If a interface attaches after rule(s) are added, this function is used
275  * retrieve the current rule count for that interface.
276  */
277 int
nxctl_inet_traffic_rule_get_count(const char * ifname,uint32_t * count)278 nxctl_inet_traffic_rule_get_count(const char *ifname, uint32_t *count)
279 {
280 	struct nxctl_traffic_rule_inet_storage *rs;
281 	struct nxctl_traffic_rule_inet_if *rif;
282 	int err;
283 
284 	NXTR_RLOCK();
285 	rs = inet_traffic_rule_type->ntrt_storage;
286 	if (rs == NULL) {
287 		err = ENOENT;
288 		goto fail;
289 	}
290 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
291 		if (strcmp(rif->rii_ifname, ifname) == 0) {
292 			break;
293 		}
294 	}
295 	if (rif == NULL) {
296 		err = ENOENT;
297 		goto fail;
298 	}
299 	*count = rif->rii_count;
300 	NXTR_RUNLOCK();
301 	return 0;
302 fail:
303 	NXTR_RUNLOCK();
304 	return err;
305 }
306 
307 /*
308  * Used for finding the qset id associated with a traffic descriptor.
309  */
310 int
nxctl_inet_traffic_rule_find_qset_id(const char * ifname,struct ifnet_traffic_descriptor_inet * td,uint64_t * qset_id)311 nxctl_inet_traffic_rule_find_qset_id(const char *ifname,
312     struct ifnet_traffic_descriptor_inet *td, uint64_t *qset_id)
313 {
314 	struct nxctl_traffic_rule_inet *ntri = NULL;
315 	int err;
316 
317 	NXTR_RLOCK();
318 	ASSERT(inet_traffic_rule_type != NULL);
319 	err = inet_traffic_rule_type->ntrt_find(inet_traffic_rule_type, ifname,
320 	    (struct ifnet_traffic_descriptor_common *)td, 0,
321 	    (struct nxctl_traffic_rule **)&ntri);
322 	if (err != 0) {
323 		SK_ERR("rule find failed: %d", err);
324 		goto fail;
325 	}
326 	*qset_id = ntri->ntri_ra.ras_qset_id;
327 	NXTR_RUNLOCK();
328 	return 0;
329 fail:
330 	NXTR_RUNLOCK();
331 	return err;
332 }
333 
334 /*
335  * Based on flow_pkt_classify().
336  * This function populates struct ifnet_traffic_descriptor_inet instead of struct __flow.
337  */
338 static int
fill_inet_td(struct __kern_packet * pkt,struct ifnet_traffic_descriptor_inet * td)339 fill_inet_td(struct __kern_packet *pkt, struct ifnet_traffic_descriptor_inet *td)
340 {
341 	union {
342 		volatile struct ip *_iph;
343 		volatile struct ip6_hdr *_ip6;
344 	} _l3;
345 	#define iph _l3._iph
346 	#define ip6 _l3._ip6
347 	union {
348 		volatile struct tcphdr *_tcph;
349 		volatile struct udphdr *_udph;
350 	} _l4;
351 	#define tcph _l4._tcph
352 	#define udph _l4._udph
353 	uint8_t *pkt_buf, *l3_hdr;
354 	uint32_t bdlen, bdlim, bdoff, cls_len;
355 	size_t pkt_len;
356 	uint8_t ipv, l3hlen = 0; /* IP header length */
357 	uint16_t l3tlen = 0;     /* total length of IP packet */
358 	uint8_t l4hlen = 0;      /* TCP/UDP header length */
359 	uint16_t ulen = 0;       /* user data length */
360 	int err;
361 
362 	ASSERT(pkt->pkt_l2_len <= pkt->pkt_length);
363 	pkt_len = pkt->pkt_length - pkt->pkt_l2_len;
364 
365 	MD_BUFLET_ADDR_ABS_DLEN(pkt, pkt_buf, bdlen, bdlim, bdoff);
366 	cls_len = bdlim - bdoff;
367 	cls_len -= pkt->pkt_l2_len;
368 	cls_len = (uint32_t)MIN(cls_len, pkt_len);
369 	VERIFY(pkt_len >= cls_len);
370 	if (cls_len == 0) {
371 		SK_ERR("cls_len == 0");
372 		err = EINVAL;
373 		goto fail;
374 	}
375 	l3_hdr = pkt_buf + pkt->pkt_headroom + pkt->pkt_l2_len;
376 	iph = (volatile struct ip *)(void *)l3_hdr;
377 	ipv = iph->ip_v;
378 
379 	switch (ipv) {
380 	case 4:
381 		if (cls_len < sizeof(struct ip)) {
382 			SK_ERR("cls_len < sizeof(struct ip) (%d < %d)",
383 			    cls_len, sizeof(struct ip));
384 			err = EINVAL;
385 			goto fail;
386 		}
387 		l3hlen = (uint8_t)(iph->ip_hl << 2);
388 		if (l3hlen < sizeof(struct ip)) {
389 			SK_ERR("l3hlen < sizeof(struct ip) (%d < %d)",
390 			    l3hlen, sizeof(struct ip));
391 			err = EINVAL;
392 			goto fail;
393 		}
394 		if (cls_len < l3hlen) {
395 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
396 			err = EINVAL;
397 			goto fail;
398 		}
399 		l3tlen = ntohs(iph->ip_len);
400 		if (l3tlen < l3hlen) {
401 			SK_ERR("l3tlen < l3hlen (%d < %d)", l3tlen, l3hlen);
402 			err = EINVAL;
403 			goto fail;
404 		}
405 		if (pkt_len < l3tlen) {
406 			SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
407 			err = EINVAL;
408 			goto fail;
409 		}
410 		td->inet_ipver = IPVERSION;
411 		td->inet_proto = iph->ip_p;
412 		bcopy(__DECONST(void *, &iph->ip_src), &td->inet_laddr.iia_v4addr,
413 		    sizeof(iph->ip_src));
414 		bcopy(__DECONST(void *, &iph->ip_dst), &td->inet_raddr.iia_v4addr,
415 		    sizeof(iph->ip_dst));
416 		break;
417 	case 6:
418 		l3hlen = sizeof(struct ip6_hdr);
419 		if (cls_len < l3hlen) {
420 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
421 			err = EINVAL;
422 			goto fail;
423 		}
424 		l3tlen = l3hlen + ntohs(ip6->ip6_plen);
425 		if (pkt_len < l3tlen) {
426 			SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
427 			err = EINVAL;
428 			goto fail;
429 		}
430 		td->inet_ipver = IPV6_VERSION;
431 		td->inet_proto = ip6->ip6_nxt;
432 		bcopy(__DECONST(void *, &ip6->ip6_src), &td->inet_laddr,
433 		    sizeof(ip6->ip6_src));
434 		bcopy(__DECONST(void *, &ip6->ip6_dst), &td->inet_raddr,
435 		    sizeof(ip6->ip6_dst));
436 		break;
437 	default:
438 		SK_ERR("ipv == %d", ipv);
439 		err = EINVAL;
440 		goto fail;
441 	}
442 	tcph = __DECONST(volatile struct tcphdr *, (volatile uint8_t *)iph + l3hlen);
443 	ulen = (l3tlen - l3hlen);
444 	if (td->inet_proto == IPPROTO_TCP) {
445 		if (cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph)) {
446 			SK_ERR("cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph) "
447 			    "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*tcph),
448 			    ulen, sizeof(*tcph));
449 			err = EINVAL;
450 			goto fail;
451 		}
452 		l4hlen = (uint8_t)(tcph->th_off << 2);
453 		if (l4hlen < sizeof(*tcph)) {
454 			SK_ERR("l4hlen < sizeof(*tcph) (%d < %d)", l4hlen, sizeof(*tcph));
455 			err = EINVAL;
456 			goto fail;
457 		}
458 		if (l4hlen > ulen) {
459 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
460 			err = EINVAL;
461 			goto fail;
462 		}
463 		bcopy(__DECONST(void *, &tcph->th_sport), &td->inet_lport,
464 		    sizeof(td->inet_lport));
465 		bcopy(__DECONST(void *, &tcph->th_dport), &td->inet_rport,
466 		    sizeof(td->inet_rport));
467 	} else if (td->inet_proto == IPPROTO_UDP) {
468 		if (cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph)) {
469 			SK_ERR("cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph) "
470 			    "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*udph),
471 			    ulen, sizeof(*udph));
472 			err = EINVAL;
473 			goto fail;
474 		}
475 		l4hlen = sizeof(*udph);
476 		if (l4hlen > ulen) {
477 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
478 			err = EINVAL;
479 			goto fail;
480 		}
481 		bcopy(__DECONST(void *, &udph->uh_sport), &td->inet_lport,
482 		    sizeof(td->inet_lport));
483 		bcopy(__DECONST(void *, &udph->uh_dport), &td->inet_rport,
484 		    sizeof(td->inet_rport));
485 	} else {
486 		err = ENOTSUP;
487 		goto fail;
488 	}
489 
490 	td->inet_common.itd_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET;
491 	td->inet_common.itd_len = sizeof(*td);
492 	td->inet_common.itd_flags = IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
493 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND;
494 	td->inet_mask |= (IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
495 	    IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
496 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
497 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
498 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
499 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT);
500 	return 0;
501 fail:
502 	DTRACE_SKYWALK5(classify__failed, struct ip *, iph, size_t, pkt_len,
503 	    uint8_t, pkt->pkt_l2_len, struct ifnet_traffic_descriptor_inet *, td,
504 	    int, err);
505 	bzero(td, sizeof(*td));
506 	return err;
507 	#undef iph
508 	#undef ip6
509 	#undef tcph
510 	#undef udph
511 }
512 
513 int
nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char * ifname,struct __kern_packet * pkt,uint64_t * qset_id)514 nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char *ifname,
515     struct __kern_packet *pkt, uint64_t *qset_id)
516 {
517 	struct ifnet_traffic_descriptor_inet td;
518 	int err;
519 
520 	err = fill_inet_td(pkt, &td);
521 	if (err != 0) {
522 		return err;
523 	}
524 	return nxctl_inet_traffic_rule_find_qset_id(ifname, &td, qset_id);
525 }
526 
527 void
nxctl_traffic_rule_init(void)528 nxctl_traffic_rule_init(void)
529 {
530 	ASSERT(nxctl_traffic_rule_tag == NULL);
531 	nxctl_traffic_rule_tag =
532 	    kern_allocation_name_allocate(NXCTL_TRAFFIC_RULE_TAG, 0);
533 	ASSERT(nxctl_traffic_rule_tag != NULL);
534 
535 	ASSERT(inet_traffic_rule_type == NULL);
536 	inet_traffic_rule_type =
537 	    find_traffic_rule_type(IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET);
538 	ASSERT(inet_traffic_rule_type != NULL);
539 }
540 
541 void
nxctl_traffic_rule_fini(void)542 nxctl_traffic_rule_fini(void)
543 {
544 	if (nxctl_traffic_rule_tag != NULL) {
545 		kern_allocation_name_release(nxctl_traffic_rule_tag);
546 		nxctl_traffic_rule_tag = NULL;
547 	}
548 	inet_traffic_rule_type = NULL;
549 }
550 
551 static struct ifnet_ip_addr v6_zeros_addr = {0};
552 static boolean_t
inet_v6addr_cmp(struct ifnet_ip_addr * a1,struct ifnet_ip_addr * a2)553 inet_v6addr_cmp(struct ifnet_ip_addr *a1, struct ifnet_ip_addr *a2)
554 {
555 	return memcmp(a1, a2, sizeof(*a1)) == 0;
556 }
557 
558 SK_NO_INLINE_ATTRIBUTE
559 static struct nxctl_traffic_rule_storage *
nxctl_traffic_rule_storage_create(void)560 nxctl_traffic_rule_storage_create(void)
561 {
562 	struct nxctl_traffic_rule_storage *rs;
563 
564 	rs = sk_alloc_type(struct nxctl_traffic_rule_storage,
565 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
566 	SLIST_INIT(&rs->rs_list);
567 	rs->rs_count = 0;
568 	return rs;
569 }
570 
571 SK_NO_INLINE_ATTRIBUTE
572 static void
nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage * rs)573 nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage *rs)
574 {
575 	ASSERT(rs->rs_count == 0);
576 	ASSERT(SLIST_EMPTY(&rs->rs_list));
577 	sk_free_type(struct nxctl_traffic_rule_storage, rs);
578 }
579 
580 /*
581  * This is meant to be called during closure of the nxctl's fd.
582  * This will cleanup all rules linked to this nxctl. Rules that
583  * are marked persistent won't be added to the nxctl list.
584  */
585 void
nxctl_traffic_rule_clean(struct nxctl * nxctl)586 nxctl_traffic_rule_clean(struct nxctl *nxctl)
587 {
588 	struct nxctl_traffic_rule_storage *rs;
589 	struct nxctl_traffic_rule *ntr, *next;
590 	int err;
591 
592 	lck_mtx_lock(&nxctl->nxctl_lock);
593 	if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
594 		lck_mtx_unlock(&nxctl->nxctl_lock);
595 		return;
596 	}
597 	ntr = SLIST_FIRST(&rs->rs_list);
598 	SLIST_INIT(&rs->rs_list);
599 	rs->rs_count = 0;
600 	nxctl_traffic_rule_storage_destroy(rs);
601 	nxctl->nxctl_traffic_rule_storage = NULL;
602 	lck_mtx_unlock(&nxctl->nxctl_lock);
603 
604 	while (ntr != NULL) {
605 		next = SLIST_NEXT(ntr, ntr_storage_link);
606 		/*
607 		 * Clearing the flag to tell remove_traffic_rule() not to
608 		 * remove from the nxctl list again.
609 		 */
610 		ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
611 
612 		/* Passing NULL because we already hold a reference */
613 		err = remove_traffic_rule(nxctl, ntr->ntr_uuid, NULL);
614 		if (err == 0) {
615 			(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
616 		}
617 		release_traffic_rule(ntr);
618 		ntr = next;
619 	}
620 }
621 
622 SK_NO_INLINE_ATTRIBUTE
623 static void
add_traffic_rule_to_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)624 add_traffic_rule_to_nxctl(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
625 {
626 	struct nxctl_traffic_rule_storage *rs;
627 
628 	lck_mtx_lock(&nxctl->nxctl_lock);
629 	if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
630 		rs = nxctl_traffic_rule_storage_create();
631 		nxctl->nxctl_traffic_rule_storage = rs;
632 	}
633 	ntr->ntr_flags |= NTR_FLAG_ON_NXCTL_LIST;
634 	retain_traffic_rule(ntr);
635 	SLIST_INSERT_HEAD(&rs->rs_list, ntr, ntr_storage_link);
636 	rs->rs_count++;
637 	lck_mtx_unlock(&nxctl->nxctl_lock);
638 }
639 
640 SK_NO_INLINE_ATTRIBUTE
641 static void
remove_traffic_rule_from_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)642 remove_traffic_rule_from_nxctl(struct nxctl *nxctl,
643     struct nxctl_traffic_rule *ntr)
644 {
645 	struct nxctl_traffic_rule_storage *rs;
646 
647 	lck_mtx_lock(&nxctl->nxctl_lock);
648 	if ((ntr->ntr_flags & NTR_FLAG_ON_NXCTL_LIST) == 0) {
649 		lck_mtx_unlock(&nxctl->nxctl_lock);
650 		return;
651 	}
652 	rs = nxctl->nxctl_traffic_rule_storage;
653 	SLIST_REMOVE(&rs->rs_list, ntr, nxctl_traffic_rule, ntr_storage_link);
654 	rs->rs_count--;
655 	ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
656 	release_traffic_rule(ntr);
657 	if (rs->rs_count == 0) {
658 		nxctl_traffic_rule_storage_destroy(rs);
659 		nxctl->nxctl_traffic_rule_storage = NULL;
660 	}
661 	lck_mtx_unlock(&nxctl->nxctl_lock);
662 }
663 
664 static int
inet_traffic_rule_validate(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra)665 inet_traffic_rule_validate(struct nxctl_traffic_rule_type *type,
666     const char *ifname,
667     struct ifnet_traffic_descriptor_common *td,
668     struct ifnet_traffic_rule_action *ra)
669 {
670 #pragma unused(type)
671 	char buf[IFNAMSIZ];
672 	int unit, i;
673 	struct ifnet_traffic_descriptor_inet *tdi;
674 	uint8_t mask = 0, ipver, proto;
675 
676 	if (ifunit_extract(ifname, buf, sizeof(buf), &unit) < 0) {
677 		SK_ERR("invalid ifname: %s", ifname);
678 		return EINVAL;
679 	}
680 	if (td->itd_len != sizeof(*tdi)) {
681 		SK_ERR("invalid td len: expected %d, actual %d",
682 		    sizeof(*tdi), td->itd_len);
683 		return EINVAL;
684 	}
685 	if (td->itd_flags == 0 ||
686 	    (td->itd_flags &
687 	    ~(IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
688 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND)) != 0) {
689 		SK_ERR("invalid td flags: 0x%x", td->itd_flags);
690 		return EINVAL;
691 	}
692 	tdi = (struct ifnet_traffic_descriptor_inet *)td;
693 	for (i = 0; i < NINETRULEMASKS; i++) {
694 		if (tdi->inet_mask == nxctl_inet_traffic_rule_masks[i]) {
695 			mask = tdi->inet_mask;
696 			break;
697 		}
698 	}
699 	if (mask == 0) {
700 		SK_ERR("invalid inet mask: 0x%x", tdi->inet_mask);
701 		return EINVAL;
702 	}
703 	ipver = tdi->inet_ipver;
704 	if (ipver != IPVERSION && ipver != IPV6_VERSION) {
705 		SK_ERR("invalid inet ipver: 0x%x", ipver);
706 		return EINVAL;
707 	}
708 	proto = tdi->inet_proto;
709 	if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
710 		SK_ERR("invalid inet proto: %d", proto);
711 		return EINVAL;
712 	}
713 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0) {
714 		if (ipver == IPVERSION) {
715 			if (tdi->inet_laddr.iia_v4addr == INADDR_ANY) {
716 				SK_ERR("inet laddr v4 cannot be unspecified");
717 				return EINVAL;
718 			}
719 		} else {
720 			if (inet_v6addr_cmp(&tdi->inet_laddr, &v6_zeros_addr)) {
721 				SK_ERR("inet laddr v4 cannot be unspecified");
722 				return EINVAL;
723 			}
724 		}
725 	}
726 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0) {
727 		if (ipver == IPVERSION) {
728 			if (tdi->inet_raddr.iia_v4addr == INADDR_ANY) {
729 				SK_ERR("inet raddr v6 cannot be unspecified");
730 				return EINVAL;
731 			}
732 		} else {
733 			if (inet_v6addr_cmp(&tdi->inet_raddr, &v6_zeros_addr)) {
734 				SK_ERR("inet raddr v6 cannot be unspecified");
735 				return EINVAL;
736 			}
737 		}
738 	}
739 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0) {
740 		if (tdi->inet_lport == 0) {
741 			SK_ERR("inet lport cannot be unspecified");
742 			return EINVAL;
743 		}
744 	}
745 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0) {
746 		if (tdi->inet_rport == 0) {
747 			SK_ERR("inet rport cannot be unspecified");
748 			return EINVAL;
749 		}
750 	}
751 	if (ra->ra_len != sizeof(struct ifnet_traffic_rule_action_steer)) {
752 		SK_ERR("invalid ra len: expected %d, actual %d",
753 		    sizeof(struct ifnet_traffic_rule_action_steer), ra->ra_len);
754 		return EINVAL;
755 	}
756 	return 0;
757 }
758 
759 SK_NO_INLINE_ATTRIBUTE
760 static struct nxctl_traffic_rule_inet_storage *
inet_traffic_rule_storage_create(void)761 inet_traffic_rule_storage_create(void)
762 {
763 	struct nxctl_traffic_rule_inet_storage *rs;
764 
765 	rs = sk_alloc_type(struct nxctl_traffic_rule_inet_storage,
766 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
767 	SLIST_INIT(&rs->ris_if_list);
768 	rs->ris_count = 0;
769 	return rs;
770 }
771 
772 SK_NO_INLINE_ATTRIBUTE
773 static void
inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage * rs)774 inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage *rs)
775 {
776 	ASSERT(rs->ris_count == 0);
777 	ASSERT(SLIST_EMPTY(&rs->ris_if_list));
778 	sk_free_type(struct nxctl_traffic_rule_inet_storage, rs);
779 }
780 
781 SK_NO_INLINE_ATTRIBUTE
782 static struct nxctl_traffic_rule_inet_if *
inet_traffic_rule_if_create(const char * ifname)783 inet_traffic_rule_if_create(const char *ifname)
784 {
785 	struct nxctl_traffic_rule_inet_if *rif;
786 	int i;
787 
788 	rif = sk_alloc_type(struct nxctl_traffic_rule_inet_if,
789 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
790 	for (i = 0; i < NINETRULEMASKS; i++) {
791 		SLIST_INIT(&rif->rii_lists[i]);
792 	}
793 	strlcpy(rif->rii_ifname, ifname, sizeof(rif->rii_ifname));
794 	rif->rii_count = 0;
795 	return rif;
796 }
797 
798 SK_NO_INLINE_ATTRIBUTE
799 static void
inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if * rif)800 inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if *rif)
801 {
802 	int i;
803 
804 	for (i = 0; i < NINETRULEMASKS; i++) {
805 		ASSERT(SLIST_EMPTY(&rif->rii_lists[i]));
806 	}
807 	ASSERT(rif->rii_count == 0);
808 	sk_free_type(struct nxctl_traffic_rule_inet_if, rif);
809 }
810 
811 SK_NO_INLINE_ATTRIBUTE
812 static boolean_t
inet_traffic_rule_match(struct nxctl_traffic_rule_inet * ntri,const char * ifname,uint32_t flags,struct ifnet_traffic_descriptor_inet * tdi)813 inet_traffic_rule_match(struct nxctl_traffic_rule_inet *ntri, const char *ifname,
814     uint32_t flags, struct ifnet_traffic_descriptor_inet *tdi)
815 {
816 	struct nxctl_traffic_rule *ntr = (struct nxctl_traffic_rule *)ntri;
817 	struct ifnet_traffic_descriptor_inet *tdi0;
818 	uint8_t mask;
819 	boolean_t exact;
820 
821 	VERIFY(strcmp(ntr->ntr_ifname, ifname) == 0);
822 	tdi0 = &ntri->ntri_td;
823 
824 	exact = ((flags & NTR_FIND_FLAG_EXACT) != 0);
825 	mask = tdi0->inet_mask & tdi->inet_mask;
826 	if (exact) {
827 		ASSERT(tdi0->inet_mask == tdi->inet_mask);
828 	}
829 	ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER) != 0);
830 	if (tdi0->inet_ipver != tdi->inet_ipver) {
831 		DTRACE_SKYWALK2(ipver__mismatch,
832 		    uint8_t, tdi0->inet_ipver, uint8_t, tdi->inet_ipver);
833 		return FALSE;
834 	}
835 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) != 0 &&
836 	    tdi0->inet_proto != tdi->inet_proto) {
837 		DTRACE_SKYWALK2(proto__mismatch,
838 		    uint8_t, tdi0->inet_proto, uint8_t, tdi->inet_proto);
839 		return FALSE;
840 	}
841 	if (tdi0->inet_ipver == IPVERSION) {
842 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
843 		    tdi0->inet_laddr.iia_v4addr != tdi->inet_laddr.iia_v4addr) {
844 			DTRACE_SKYWALK2(v4laddr__mismatch,
845 			    in_addr_t, tdi0->inet_laddr.iia_v4addr,
846 			    in_addr_t, tdi->inet_laddr.iia_v4addr);
847 			return FALSE;
848 		}
849 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
850 		    tdi0->inet_raddr.iia_v4addr != tdi->inet_raddr.iia_v4addr) {
851 			DTRACE_SKYWALK2(v4raddr__mismatch,
852 			    in_addr_t, tdi0->inet_raddr.iia_v4addr,
853 			    in_addr_t, tdi->inet_raddr.iia_v4addr);
854 			return FALSE;
855 		}
856 	} else {
857 		ASSERT(tdi0->inet_ipver == IPV6_VERSION);
858 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
859 		    !inet_v6addr_cmp(&tdi0->inet_laddr, &tdi->inet_laddr)) {
860 			DTRACE_SKYWALK2(v6laddr__mismatch,
861 			    struct in6_addr *, &tdi0->inet_laddr,
862 			    struct in6_addr *, &tdi->inet_laddr);
863 			return FALSE;
864 		}
865 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
866 		    !inet_v6addr_cmp(&tdi0->inet_raddr, &tdi->inet_raddr)) {
867 			DTRACE_SKYWALK2(v6raddr__mismatch,
868 			    struct in6_addr *, &tdi0->inet_raddr,
869 			    struct in6_addr *, &tdi->inet_raddr);
870 			return FALSE;
871 		}
872 	}
873 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0 &&
874 	    tdi0->inet_lport != tdi->inet_lport) {
875 		DTRACE_SKYWALK2(lport__mismatch,
876 		    uint8_t, tdi0->inet_lport, uint8_t, tdi->inet_lport);
877 		return FALSE;
878 	}
879 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0 &&
880 	    tdi0->inet_rport != tdi->inet_rport) {
881 		DTRACE_SKYWALK2(rport__mismatch,
882 		    uint8_t, tdi0->inet_rport, uint8_t, tdi->inet_rport);
883 		return FALSE;
884 	}
885 	return TRUE;
886 }
887 
888 static int
inet_traffic_rule_find(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,uint32_t flags,struct nxctl_traffic_rule ** ntrp)889 inet_traffic_rule_find(struct nxctl_traffic_rule_type *type, const char *ifname,
890     struct ifnet_traffic_descriptor_common *td, uint32_t flags,
891     struct nxctl_traffic_rule **ntrp)
892 {
893 	struct nxctl_traffic_rule_inet *ntri = NULL;
894 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
895 	struct nxctl_traffic_rule_inet_if *rif;
896 	struct ifnet_traffic_descriptor_inet *tdi =
897 	    (struct ifnet_traffic_descriptor_inet *)td;
898 	int i;
899 
900 	if (rs == NULL) {
901 		return ENOENT;
902 	}
903 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
904 		if (strcmp(rif->rii_ifname, ifname) != 0) {
905 			continue;
906 		}
907 		for (i = 0; i < NINETRULEMASKS; i++) {
908 			if ((flags & NTR_FIND_FLAG_EXACT) != 0 &&
909 			    tdi->inet_mask != nxctl_inet_traffic_rule_masks[i]) {
910 				continue;
911 			}
912 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
913 				if (inet_traffic_rule_match(ntri, ifname, flags, tdi)) {
914 					*ntrp = (struct nxctl_traffic_rule *)ntri;
915 					return 0;
916 				}
917 			}
918 		}
919 	}
920 	return ENOENT;
921 }
922 
923 static int
inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type * type,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)924 inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type *type,
925     uuid_t uuid, struct nxctl_traffic_rule **ntrp)
926 {
927 	struct nxctl_traffic_rule_inet *ntri;
928 	struct nxctl_traffic_rule *ntr;
929 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
930 	struct nxctl_traffic_rule_inet_if *rif;
931 	int i;
932 
933 	if (rs == NULL) {
934 		return ENOENT;
935 	}
936 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
937 		for (i = 0; i < NINETRULEMASKS; i++) {
938 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
939 				ntr = &ntri->ntri_common;
940 				if (uuid_compare(ntr->ntr_uuid, uuid) == 0) {
941 					*ntrp = ntr;
942 					return 0;
943 				}
944 			}
945 		}
946 	}
947 	return ENOENT;
948 }
949 
950 static void
inet_update_ifnet_traffic_rule_count(const char * ifname,uint32_t count)951 inet_update_ifnet_traffic_rule_count(const char *ifname, uint32_t count)
952 {
953 	struct ifnet *ifp;
954 
955 	ifp = ifunit_ref(ifname);
956 	if (ifp == NULL) {
957 		DTRACE_SKYWALK1(ifname__not__found, char *, ifname);
958 		return;
959 	}
960 	ifnet_update_traffic_rule_count(ifp, count);
961 	ifnet_decr_iorefcnt(ifp);
962 }
963 
964 static void
inet_traffic_rule_link(struct nxctl_traffic_rule * ntr)965 inet_traffic_rule_link(struct nxctl_traffic_rule *ntr)
966 {
967 	struct nxctl_traffic_rule_type *type = ntr->ntr_type;
968 	struct nxctl_traffic_rule_inet_storage *rs;
969 	struct nxctl_traffic_rule_inet_if *rif;
970 	struct nxctl_traffic_rule_inet *ntri =
971 	    (struct nxctl_traffic_rule_inet *)ntr;
972 	struct nxctl_traffic_rule_inet_head *list = NULL;
973 	int i;
974 
975 	if ((rs = type->ntrt_storage) == NULL) {
976 		rs = inet_traffic_rule_storage_create();
977 		type->ntrt_storage = rs;
978 	}
979 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
980 		if (strcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
981 			break;
982 		}
983 	}
984 	if (rif == NULL) {
985 		rif = inet_traffic_rule_if_create(ntr->ntr_ifname);
986 		SLIST_INSERT_HEAD(&rs->ris_if_list, rif, rii_link);
987 	}
988 	for (i = 0; i < NINETRULEMASKS; i++) {
989 		if (ntri->ntri_td.inet_mask ==
990 		    nxctl_inet_traffic_rule_masks[i]) {
991 			list = &rif->rii_lists[i];
992 			break;
993 		}
994 	}
995 	retain_traffic_rule(ntr);
996 	ASSERT(list != NULL);
997 	SLIST_INSERT_HEAD(list, ntri, ntri_storage_link);
998 	/* per-interface count */
999 	rif->rii_count++;
1000 	inet_update_ifnet_traffic_rule_count(rif->rii_ifname, rif->rii_count);
1001 
1002 	/* global count */
1003 	rs->ris_count++;
1004 }
1005 
1006 static void
inet_traffic_rule_unlink(struct nxctl_traffic_rule * ntr)1007 inet_traffic_rule_unlink(struct nxctl_traffic_rule *ntr)
1008 {
1009 	struct nxctl_traffic_rule_inet_storage *rs;
1010 	struct nxctl_traffic_rule_inet_if *rif;
1011 	struct nxctl_traffic_rule_inet *ntri =
1012 	    (struct nxctl_traffic_rule_inet *)ntr;
1013 	struct nxctl_traffic_rule_inet_head *list = NULL;
1014 	struct nxctl_traffic_rule_type *type;
1015 	int i;
1016 
1017 	type = ntr->ntr_type;
1018 	rs = type->ntrt_storage;
1019 	ASSERT(rs != NULL);
1020 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1021 		if (strcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
1022 			break;
1023 		}
1024 	}
1025 	ASSERT(rif != NULL);
1026 	for (i = 0; i < NINETRULEMASKS; i++) {
1027 		if (ntri->ntri_td.inet_mask ==
1028 		    nxctl_inet_traffic_rule_masks[i]) {
1029 			list = &rif->rii_lists[i];
1030 			break;
1031 		}
1032 	}
1033 	ASSERT(list != NULL);
1034 	SLIST_REMOVE(list, ntri, nxctl_traffic_rule_inet, ntri_storage_link);
1035 	rif->rii_count--;
1036 	inet_update_ifnet_traffic_rule_count(rif->rii_ifname, rif->rii_count);
1037 
1038 	rs->ris_count--;
1039 	release_traffic_rule(ntr);
1040 
1041 	if (rif->rii_count == 0) {
1042 		SLIST_REMOVE(&rs->ris_if_list, rif, nxctl_traffic_rule_inet_if, rii_link);
1043 		inet_traffic_rule_if_destroy(rif);
1044 	}
1045 	if (rs->ris_count == 0) {
1046 		type->ntrt_storage = NULL;
1047 		inet_traffic_rule_storage_destroy(rs);
1048 	}
1049 }
1050 
1051 /*
1052  * XXX
1053  * This may need additional changes to ensure safety against detach/attach.
1054  * This is not an issue for the first consumer of llink interfaces, cellular,
1055  * which does not detach.
1056  */
1057 static int
inet_traffic_rule_notify(struct nxctl_traffic_rule * ntr,uint32_t flags)1058 inet_traffic_rule_notify(struct nxctl_traffic_rule *ntr, uint32_t flags)
1059 {
1060 	struct ifnet *ifp;
1061 	struct nx_netif *nif;
1062 	struct netif_qset *qset = NULL;
1063 	struct nxctl_traffic_rule_inet *ntri;
1064 	int err = 0;
1065 
1066 	ifp = ifunit_ref(ntr->ntr_ifname);
1067 	if (ifp == NULL) {
1068 		DTRACE_SKYWALK1(ifname__not__found, char *, ntr->ntr_ifname);
1069 		err = ENXIO;
1070 		goto done;
1071 	}
1072 	nif = NA(ifp)->nifna_netif;
1073 	if (!NX_LLINK_PROV(nif->nif_nx)) {
1074 		DTRACE_SKYWALK1(llink__not__enabled, struct ifnet *, ifp);
1075 		err = ENOTSUP;
1076 		goto done;
1077 	}
1078 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
1079 	qset = nx_netif_find_qset(nif, ntri->ntri_ra.ras_qset_id);
1080 	err = nx_netif_notify_steering_info(nif, qset,
1081 	    (struct ifnet_traffic_descriptor_common *)&ntri->ntri_td,
1082 	    ((flags & NTR_NOTIFY_FLAG_ADD) != 0));
1083 done:
1084 	if (qset != NULL) {
1085 		nx_netif_qset_release(&qset);
1086 	}
1087 	if (ifp != NULL) {
1088 		ifnet_decr_iorefcnt(ifp);
1089 	}
1090 	return err;
1091 }
1092 
1093 static int
inet_traffic_rule_create(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1094 inet_traffic_rule_create(struct nxctl_traffic_rule_type *type,
1095     const char *ifname, struct ifnet_traffic_descriptor_common *td,
1096     struct ifnet_traffic_rule_action *ra, uint32_t flags,
1097     struct nxctl_traffic_rule **ntrp)
1098 {
1099 	struct nxctl_traffic_rule_inet *ntri;
1100 	struct nxctl_traffic_rule *ntr;
1101 
1102 	ntri = sk_alloc_type(struct nxctl_traffic_rule_inet,
1103 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
1104 	ntr = &ntri->ntri_common;
1105 
1106 	ntr->ntr_type = type;
1107 	ntr->ntr_flags = flags;
1108 	uuid_generate(ntr->ntr_uuid);
1109 	os_ref_init(&ntr->ntr_refcnt, NULL);
1110 
1111 	strlcpy(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname));
1112 	proc_selfname(ntr->ntr_procname, sizeof(ntr->ntr_procname));
1113 	bcopy(td, &ntri->ntri_td, sizeof(ntri->ntri_td));
1114 	bcopy(ra, &ntri->ntri_ra, sizeof(ntri->ntri_ra));
1115 
1116 	*ntrp = ntr;
1117 	return 0;
1118 }
1119 
1120 static void
inet_traffic_rule_destroy(struct nxctl_traffic_rule * ntr)1121 inet_traffic_rule_destroy(struct nxctl_traffic_rule *ntr)
1122 {
1123 	struct nxctl_traffic_rule_inet *ntri;
1124 
1125 	ASSERT(os_ref_get_count(&ntr->ntr_refcnt) == 0);
1126 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
1127 	sk_free_type(struct nxctl_traffic_rule_inet, ntri);
1128 }
1129 
1130 static void
convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet * ntri,struct nxctl_traffic_rule_inet_iocinfo * info)1131 convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet *ntri,
1132     struct nxctl_traffic_rule_inet_iocinfo *info)
1133 {
1134 	struct nxctl_traffic_rule *ntr;
1135 	struct nxctl_traffic_rule_generic_iocinfo *ginfo;
1136 
1137 	bzero(info, sizeof(*info));
1138 	ntr = &ntri->ntri_common;
1139 	ginfo = &info->tri_common;
1140 	_CASSERT(sizeof(ntr->ntr_procname) == sizeof(ginfo->trg_procname));
1141 	_CASSERT(sizeof(ntr->ntr_ifname) == sizeof(ginfo->trg_ifname));
1142 	uuid_copy(ginfo->trg_uuid, ntr->ntr_uuid);
1143 	strlcpy(ginfo->trg_procname, ntr->ntr_procname,
1144 	    sizeof(ginfo->trg_procname));
1145 	strlcpy(ginfo->trg_ifname, ntr->ntr_ifname,
1146 	    sizeof(ginfo->trg_ifname));
1147 	bcopy(&ntri->ntri_td, &info->tri_td, sizeof(info->tri_td));
1148 	bcopy(&ntri->ntri_ra, &info->tri_ra, sizeof(info->tri_ra));
1149 }
1150 
1151 static int
inet_traffic_rule_get_all(struct nxctl_traffic_rule_type * type,uint32_t size,uint32_t * count,user_addr_t uaddr)1152 inet_traffic_rule_get_all(struct nxctl_traffic_rule_type *type, uint32_t size,
1153     uint32_t *count, user_addr_t uaddr)
1154 {
1155 	struct nxctl_traffic_rule_inet *ntri = NULL;
1156 	struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
1157 	struct nxctl_traffic_rule_inet_if *rif;
1158 	struct nxctl_traffic_rule_inet_iocinfo info;
1159 	int i, err;
1160 
1161 	if (size != sizeof(info)) {
1162 		SK_ERR("size: actual %d, expected %d", size, sizeof(info));
1163 		return EINVAL;
1164 	}
1165 	if (rs == NULL) {
1166 		*count = 0;
1167 		return 0;
1168 	}
1169 	if (*count < rs->ris_count) {
1170 		SK_ERR("count: given %d, require: %d", *count, rs->ris_count);
1171 		return ENOBUFS;
1172 	}
1173 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1174 		for (i = 0; i < NINETRULEMASKS; i++) {
1175 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
1176 				convert_ntri_to_iocinfo(ntri, &info);
1177 				err = copyout(&info, uaddr, sizeof(info));
1178 				if (err != 0) {
1179 					SK_ERR("copyout failed: %d", err);
1180 					return err;
1181 				}
1182 				uaddr += sizeof(info);
1183 			}
1184 		}
1185 	}
1186 	*count = rs->ris_count;
1187 	return 0;
1188 }
1189 
1190 SK_NO_INLINE_ATTRIBUTE
1191 static void
retain_traffic_rule(struct nxctl_traffic_rule * ntr)1192 retain_traffic_rule(struct nxctl_traffic_rule *ntr)
1193 {
1194 #if (DEVELOPMENT || DEBUG)
1195 	os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1196 	DTRACE_SKYWALK2(ntr__retain, struct nxctl_traffic_rule *, ntr,
1197 	    os_ref_count_t, count);
1198 #endif
1199 	os_ref_retain(&ntr->ntr_refcnt);
1200 }
1201 
1202 SK_NO_INLINE_ATTRIBUTE
1203 static void
release_traffic_rule(struct nxctl_traffic_rule * ntr)1204 release_traffic_rule(struct nxctl_traffic_rule *ntr)
1205 {
1206 #if (DEVELOPMENT || DEBUG)
1207 	os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1208 	DTRACE_SKYWALK2(ntr__release, struct nxctl_traffic_rule *, ntr,
1209 	    os_ref_count_t, count);
1210 #endif
1211 	if (os_ref_release(&ntr->ntr_refcnt) == 0) {
1212 		ntr->ntr_type->ntrt_destroy(ntr);
1213 	}
1214 }
1215 
1216 SK_NO_INLINE_ATTRIBUTE
1217 static int
notify_traffic_rule(struct nxctl_traffic_rule * ntr,uint32_t flags)1218 notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags)
1219 {
1220 	return ntr->ntr_type->ntrt_notify(ntr, flags);
1221 }
1222 
1223 static void
link_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1224 link_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1225 {
1226 	/*
1227 	 * The persist flag means: do not clean up rule upon nxctl fd close.
1228 	 * This means we only add the rule to the nxctl list if persist
1229 	 * is not set.
1230 	 */
1231 	if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1232 		add_traffic_rule_to_nxctl(nxctl, ntr);
1233 	}
1234 	ntr->ntr_type->ntrt_link(ntr);
1235 }
1236 
1237 static void
unlink_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1238 unlink_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1239 {
1240 	if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1241 		remove_traffic_rule_from_nxctl(nxctl, ntr);
1242 	}
1243 	ntr->ntr_type->ntrt_unlink(ntr);
1244 }
1245 
1246 static int
find_traffic_rule_by_uuid(uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1247 find_traffic_rule_by_uuid(uuid_t uuid, struct nxctl_traffic_rule **ntrp)
1248 {
1249 	int i, err;
1250 	struct nxctl_traffic_rule_type *ntrt;
1251 	struct nxctl_traffic_rule *ntr = NULL;
1252 
1253 	for (i = 0; i < NRULETYPES; i++) {
1254 		ntrt = &nxctl_rule_types[i];
1255 		err = ntrt->ntrt_find_by_uuid(ntrt, uuid, &ntr);
1256 		if (err == 0) {
1257 			ASSERT(ntr != NULL);
1258 			*ntrp = ntr;
1259 			return 0;
1260 		}
1261 	}
1262 	return ENOENT;
1263 }
1264 
1265 static struct nxctl_traffic_rule_type *
find_traffic_rule_type(uint8_t type)1266 find_traffic_rule_type(uint8_t type)
1267 {
1268 	int i;
1269 	struct nxctl_traffic_rule_type *ntrt;
1270 
1271 	for (i = 0; i < NRULETYPES; i++) {
1272 		ntrt = &nxctl_rule_types[i];
1273 		if (ntrt->ntrt_type == type) {
1274 			return ntrt;
1275 		}
1276 	}
1277 	return NULL;
1278 }
1279 
1280 SK_NO_INLINE_ATTRIBUTE
1281 static int
add_traffic_rule(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1282 add_traffic_rule(struct nxctl *nxctl, const char *ifname,
1283     struct ifnet_traffic_descriptor_common *td,
1284     struct ifnet_traffic_rule_action *ra,
1285     uint32_t flags,
1286     struct nxctl_traffic_rule **ntrp)
1287 {
1288 	struct nxctl_traffic_rule_type *type = NULL;
1289 	struct nxctl_traffic_rule *ntr = NULL;
1290 	int err;
1291 
1292 	NXTR_WLOCK();
1293 	type = find_traffic_rule_type(td->itd_type);
1294 	if (type == NULL) {
1295 		SK_ERR("rule type %x not found", td->itd_type);
1296 		err = EINVAL;
1297 		goto fail;
1298 	}
1299 	err = type->ntrt_validate(type, ifname, td, ra);
1300 	if (err != 0) {
1301 		SK_ERR("rule validate failed: %d", err);
1302 		goto fail;
1303 	}
1304 	err = type->ntrt_find(type, ifname, td, NTR_FIND_FLAG_EXACT, &ntr);
1305 	if (err == 0) {
1306 		SK_ERR("rule already exists");
1307 		ASSERT(ntr != NULL);
1308 		err = EEXIST;
1309 		goto fail;
1310 	} else if (err != ENOENT) {
1311 		SK_ERR("rule find failed: %d", err);
1312 		goto fail;
1313 	}
1314 	err = type->ntrt_create(type, ifname, td, ra, flags, &ntr);
1315 	if (err != 0) {
1316 		SK_ERR("rule create failed: %d", err);
1317 		goto fail;
1318 	}
1319 	link_traffic_rule(nxctl, ntr);
1320 	if (ntrp != NULL) {
1321 		retain_traffic_rule(ntr);
1322 		*ntrp = ntr;
1323 	}
1324 	NXTR_WUNLOCK();
1325 	return 0;
1326 fail:
1327 	NXTR_WUNLOCK();
1328 	return err;
1329 }
1330 
1331 
1332 SK_NO_INLINE_ATTRIBUTE
1333 static int
remove_traffic_rule(struct nxctl * nxctl,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1334 remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
1335     struct nxctl_traffic_rule **ntrp)
1336 {
1337 	struct nxctl_traffic_rule *ntr;
1338 	int err;
1339 
1340 	NXTR_WLOCK();
1341 	err = find_traffic_rule_by_uuid(uuid, &ntr);
1342 	if (err != 0) {
1343 		SK_ERR("traffic rule not found");
1344 		NXTR_WUNLOCK();
1345 		return err;
1346 	}
1347 	if (ntrp != NULL) {
1348 		retain_traffic_rule(ntr);
1349 		*ntrp = ntr;
1350 	}
1351 	unlink_traffic_rule(nxctl, ntr);
1352 	/* release initial reference */
1353 	release_traffic_rule(ntr);
1354 	NXTR_WUNLOCK();
1355 	return 0;
1356 }
1357 
1358 static uint32_t
convert_traffic_rule_ioc_flags(uint32_t flags)1359 convert_traffic_rule_ioc_flags(uint32_t flags)
1360 {
1361 	uint32_t f = 0;
1362 
1363 	if ((flags & NXIOC_ADD_TRAFFIC_RULE_FLAG_PERSIST) != 0) {
1364 		f |= NTR_FLAG_PERSIST;
1365 	}
1366 	return f;
1367 }
1368 
1369 SK_NO_INLINE_ATTRIBUTE
1370 static int
add_traffic_rule_generic(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,uuid_t * uuid)1371 add_traffic_rule_generic(struct nxctl *nxctl, const char *ifname,
1372     struct ifnet_traffic_descriptor_common *td,
1373     struct ifnet_traffic_rule_action *ra, uint32_t flags, uuid_t *uuid)
1374 {
1375 	struct nxctl_traffic_rule *ntr;
1376 	int err;
1377 
1378 	err = add_traffic_rule(nxctl, ifname, td, ra, flags, &ntr);
1379 	if (err != 0) {
1380 		return err;
1381 	}
1382 	(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_ADD);
1383 	uuid_copy(*uuid, ntr->ntr_uuid);
1384 	release_traffic_rule(ntr);
1385 	return 0;
1386 }
1387 
1388 int
nxioctl_add_traffic_rule_inet(struct nxctl * nxctl,caddr_t data,proc_t procp)1389 nxioctl_add_traffic_rule_inet(struct nxctl *nxctl, caddr_t data, proc_t procp)
1390 {
1391 #pragma unused(procp)
1392 	struct nxctl_add_traffic_rule_inet_iocargs *args =
1393 	    (struct nxctl_add_traffic_rule_inet_iocargs *)(void *)data;
1394 
1395 	return add_traffic_rule_generic(nxctl, args->atri_ifname,
1396 	           (struct ifnet_traffic_descriptor_common *)&args->atri_td,
1397 	           (struct ifnet_traffic_rule_action *)&args->atri_ra,
1398 	           convert_traffic_rule_ioc_flags(args->atri_flags),
1399 	           &args->atri_uuid);
1400 }
1401 
1402 int
nxioctl_remove_traffic_rule(struct nxctl * nxctl,caddr_t data,proc_t procp)1403 nxioctl_remove_traffic_rule(struct nxctl *nxctl, caddr_t data, proc_t procp)
1404 {
1405 #pragma unused(procp)
1406 	struct nxctl_remove_traffic_rule_iocargs *args =
1407 	    (struct nxctl_remove_traffic_rule_iocargs *)(void *)data;
1408 	struct nxctl_traffic_rule *ntr;
1409 	int err;
1410 
1411 	err = remove_traffic_rule(nxctl, args->rtr_uuid, &ntr);
1412 	if (err != 0) {
1413 		return err;
1414 	}
1415 	(void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
1416 	release_traffic_rule(ntr);
1417 	return 0;
1418 }
1419 
1420 int
nxioctl_get_traffic_rules(struct nxctl * nxctl,caddr_t data,proc_t procp)1421 nxioctl_get_traffic_rules(struct nxctl *nxctl, caddr_t data, proc_t procp)
1422 {
1423 #pragma unused(nxctl)
1424 	struct nxctl_get_traffic_rules_iocargs *args =
1425 	    (struct nxctl_get_traffic_rules_iocargs *)(void *)data;
1426 	struct nxctl_traffic_rule_type *type;
1427 	user_addr_t uaddr;
1428 	int err;
1429 
1430 	NXTR_RLOCK();
1431 	type = find_traffic_rule_type(args->gtr_type);
1432 	if (type == NULL) {
1433 		SK_ERR("rule type %x not found", args->gtr_type);
1434 		err = EINVAL;
1435 		goto fail;
1436 	}
1437 	uaddr = proc_is64bit(procp) ? args->gtr_buf64 :
1438 	    CAST_USER_ADDR_T(args->gtr_buf);
1439 	err = type->ntrt_get_all(type, args->gtr_size, &args->gtr_count, uaddr);
1440 	if (err != 0) {
1441 		goto fail;
1442 	}
1443 	NXTR_RUNLOCK();
1444 	return 0;
1445 fail:
1446 	NXTR_RUNLOCK();
1447 	return err;
1448 }
1449