1 /*
2 * Copyright (c) 2022 Apple Inc. All rights reserved.
3 *
4 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. The rights granted to you under the License
10 * may not be used to create, or enable the creation or redistribution of,
11 * unlawful or unlicensed copies of an Apple operating system, or to
12 * circumvent, violate, or enable the circumvention or violation of, any
13 * terms of an Apple operating system software license agreement.
14 *
15 * Please obtain a copy of the License at
16 * http://www.opensource.apple.com/apsl/ and read it before using this file.
17 *
18 * The Original Code and all software distributed under the License are
19 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23 * Please see the License for the specific language governing rights and
24 * limitations under the License.
25 *
26 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27 */
28 #include <skywalk/os_skywalk_private.h>
29 #include <skywalk/nexus/netif/nx_netif.h>
30 #include <netinet/ip.h>
31 #include <netinet/ip6.h>
32 #include <netinet/tcp.h>
33 #include <netinet/udp.h>
34 #include <sys/sdt.h>
35
36 /*
37 * Implementation of nexus traffic rules APIs.
38 */
39
/* Forward declarations; both structures are defined later in this file. */
struct nxctl_traffic_rule_type;
struct nxctl_traffic_rule;

/*
 * These callbacks need to be implemented for each rule type.
 */

/*
 * Validate user provided parameters (interface name, traffic descriptor
 * and rule action). The inet implementation returns 0 when everything is
 * acceptable and EINVAL otherwise.
 */
typedef int (nxctl_traffic_rule_validate_cb_t)(
	struct nxctl_traffic_rule_type *type,
	const char *ifname,
	struct ifnet_traffic_descriptor_common *td,
	struct ifnet_traffic_rule_action *ra);
/*
 * Each rule type has its own global structure for storing rules.
 * These callbacks access this global structure.
 */

/*
 * Find flag: only consider rules whose descriptor mask is identical to the
 * mask of the descriptor being looked up.
 */
#define NTR_FIND_FLAG_EXACT 0x0001
/*
 * Look up a rule on `ifname` that matches `td`. The inet implementation
 * stores the rule through `ntrp` and returns 0 on a match, and returns
 * ENOENT when nothing matches.
 */
typedef int (nxctl_traffic_rule_find_cb_t)(
	struct nxctl_traffic_rule_type *type,
	const char *ifname,
	struct ifnet_traffic_descriptor_common *td,
	uint32_t flags,
	struct nxctl_traffic_rule **ntrp);

/*
 * Look up a rule by UUID across the type's storage. The inet
 * implementation returns 0 and sets `*ntrp`, or returns ENOENT.
 */
typedef int (nxctl_traffic_rule_find_by_uuid_cb_t)(
	struct nxctl_traffic_rule_type *type,
	uuid_t uuid,
	struct nxctl_traffic_rule **ntrp);

/* Insert a rule into the type's global storage. */
typedef void (nxctl_traffic_rule_link_cb_t)(
	struct nxctl_traffic_rule *ntr);

/*
 * Presumably the inverse of the link callback (removes the rule from the
 * type's global storage); no implementation is visible in this part of
 * the file -- confirm against the inet unlink implementation.
 */
typedef void (nxctl_traffic_rule_unlink_cb_t)(
	struct nxctl_traffic_rule *ntr);

/*
 * Notifies lower layers of the addition/removal of a rule.
 * This is called outside of nxctl_traffic_rule_lock to avoid potential
 * locking issues.
 * `flags` carries one of the NTR_NOTIFY_FLAG_* values below.
 */
#define NTR_NOTIFY_FLAG_ADD 0x0001
#define NTR_NOTIFY_FLAG_REMOVE 0x0002
typedef int (nxctl_traffic_rule_notify_cb_t)(
	struct nxctl_traffic_rule *ntr,
	uint32_t flags);

/*
 * Create/Destroy callbacks for a rule type.
 * NOTE(review): the create callback presumably returns the new rule
 * through `ntrp`; the implementations are not visible in this part of the
 * file, so the meaning of `flags` and the ownership rules should be
 * confirmed there.
 */
typedef int (nxctl_traffic_rule_create_cb_t)(
	struct nxctl_traffic_rule_type *type,
	const char *ifname,
	struct ifnet_traffic_descriptor_common *td,
	struct ifnet_traffic_rule_action *ra,
	uint32_t flags,
	struct nxctl_traffic_rule **ntrp);

typedef void (nxctl_traffic_rule_destroy_cb_t)(
	struct nxctl_traffic_rule *ntr);

/*
 * This is used for copying all rules for a type (including generic
 * and type-specific info) to userspace.
 * NOTE(review): `uaddr` is the user-space destination; the exact meaning
 * of `size` and of the in/out `count` is defined by the implementation,
 * which is not visible in this part of the file.
 */
typedef int (nxctl_traffic_rule_get_all_cb_t)(
	struct nxctl_traffic_rule_type *type,
	uint32_t size,
	uint32_t *count,
	user_addr_t uaddr);
110
/*
 * Descriptor of one traffic rule type: its type identifier, the set of
 * callbacks implementing it, and a pointer to the type's global rule
 * storage.
 */
struct nxctl_traffic_rule_type {
	/* Traffic descriptor type this entry handles (IFNET_TRAFFIC_DESCRIPTOR_TYPE_*). */
	uint8_t ntrt_type;
	nxctl_traffic_rule_validate_cb_t *ntrt_validate;
	nxctl_traffic_rule_find_cb_t *ntrt_find;
	nxctl_traffic_rule_find_by_uuid_cb_t *ntrt_find_by_uuid;
	nxctl_traffic_rule_link_cb_t *ntrt_link;
	nxctl_traffic_rule_unlink_cb_t *ntrt_unlink;
	nxctl_traffic_rule_notify_cb_t *ntrt_notify;
	nxctl_traffic_rule_create_cb_t *ntrt_create;
	nxctl_traffic_rule_destroy_cb_t *ntrt_destroy;
	nxctl_traffic_rule_get_all_cb_t *ntrt_get_all;
	/*
	 * Global rule storage for this type; NULL until the first rule is
	 * linked (the link callback creates it on demand).
	 *
	 * -fbounds-safety: typed as the inet storage rather than a generic
	 * void *, because every use of this field is for
	 * struct nxctl_traffic_rule_inet_storage *.
	 */
	struct nxctl_traffic_rule_inet_storage *ntrt_storage;
};
128
/* Prototypes of the inet rule type callbacks, declared via their typedefs. */
static nxctl_traffic_rule_validate_cb_t inet_traffic_rule_validate;
static nxctl_traffic_rule_find_cb_t inet_traffic_rule_find;
static nxctl_traffic_rule_find_by_uuid_cb_t inet_traffic_rule_find_by_uuid;
static nxctl_traffic_rule_link_cb_t inet_traffic_rule_link;
static nxctl_traffic_rule_unlink_cb_t inet_traffic_rule_unlink;
static nxctl_traffic_rule_notify_cb_t inet_traffic_rule_notify;
static nxctl_traffic_rule_create_cb_t inet_traffic_rule_create;
static nxctl_traffic_rule_destroy_cb_t inet_traffic_rule_destroy;
static nxctl_traffic_rule_get_all_cb_t inet_traffic_rule_get_all;

/*
 * Table of supported rule types; currently only the inet type.
 * ntrt_storage is not listed in the initializer and therefore starts out
 * NULL.
 */
static struct nxctl_traffic_rule_type nxctl_rule_types[] = {
	{
		.ntrt_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET,
		.ntrt_validate = inet_traffic_rule_validate,
		.ntrt_find = inet_traffic_rule_find,
		.ntrt_find_by_uuid = inet_traffic_rule_find_by_uuid,
		.ntrt_link = inet_traffic_rule_link,
		.ntrt_unlink = inet_traffic_rule_unlink,
		.ntrt_notify = inet_traffic_rule_notify,
		.ntrt_create = inet_traffic_rule_create,
		.ntrt_destroy = inet_traffic_rule_destroy,
		.ntrt_get_all = inet_traffic_rule_get_all,
	},
};
/* Number of entries in nxctl_rule_types[]. */
#define NRULETYPES \
	(sizeof(nxctl_rule_types)/sizeof(struct nxctl_traffic_rule_type))
155
/*
 * Generic traffic rule.
 * Contains fields common to all traffic rules.
 */
/*
 * Rule is persistent; per the comment on nxctl_traffic_rule_clean(),
 * persistent rules are not added to the per-nxctl list.
 */
#define NTR_FLAG_PERSIST 0x0001
/* Rule is currently linked on its nxctl's per-fd list (ntr_storage_link). */
#define NTR_FLAG_ON_NXCTL_LIST 0x0002
struct nxctl_traffic_rule {
	struct nxctl_traffic_rule_type *ntr_type;       /* owning rule type */
	uint32_t ntr_flags;                             /* NTR_FLAG_* */
	os_refcnt_t ntr_refcnt;                         /* reference count */
	uuid_t ntr_uuid;                                /* rule identifier */
	char ntr_procname[NTR_PROCNAME_SZ];             /* process name (presumably of the rule's creator -- confirm) */
	char ntr_ifname[IFNAMSIZ];                      /* interface the rule applies to */
	SLIST_ENTRY(nxctl_traffic_rule) ntr_storage_link; /* linkage on the per-nxctl list */
};
171
/*
 * Inet-specific traffic rule.
 *
 * ntri_common must remain the first member: code in this file casts
 * between struct nxctl_traffic_rule_inet * and struct nxctl_traffic_rule *.
 */
struct nxctl_traffic_rule_inet {
	struct nxctl_traffic_rule ntri_common;                  /* generic part */
	SLIST_ENTRY(nxctl_traffic_rule_inet) ntri_storage_link; /* linkage on a per-interface, per-mask list */
	struct ifnet_traffic_descriptor_inet ntri_td;           /* traffic descriptor to match */
	struct ifnet_traffic_rule_action_steer ntri_ra;         /* steering action (holds the qset id) */
};
181
/*
 * Currently supported tuple types.
 *
 * ITDBIT(set, bit) yields `bit` when `set` is non-zero and 0 otherwise.
 * ITRM(proto, laddr, raddr, lport, rport) builds a descriptor mask that
 * always contains the IP version bit plus the bit of every field whose
 * argument is non-zero.
 */
#define ITDBIT(set, bit) (((set) != 0) ? (bit) : 0)
#define ITRM(proto, laddr, raddr, lport, rport) \
	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER | \
	ITDBIT(proto, IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) | \
	ITDBIT(laddr, IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) | \
	ITDBIT(raddr, IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) | \
	ITDBIT(lport, IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) | \
	ITDBIT(rport, IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT))
193
/*
 * Every descriptor mask accepted for inet rules. The index of a mask in
 * this table is also the index of the per-interface rule list holding
 * rules with that mask (rii_lists[]).
 *
 * The order is significant: inet_traffic_rule_find() walks the lists in
 * index order and returns the first match, so masks listed earlier take
 * precedence over later ones.
 */
static uint8_t nxctl_inet_traffic_rule_masks[] = {
	ITRM(1, 1, 1, 1, 1),
	ITRM(1, 1, 1, 1, 0),
	ITRM(1, 1, 1, 0, 1),
	ITRM(1, 1, 1, 0, 0),
	ITRM(1, 1, 0, 1, 1),
	ITRM(1, 1, 0, 1, 0),
	ITRM(1, 1, 0, 0, 1),
	ITRM(1, 1, 0, 0, 0),
	ITRM(1, 0, 1, 1, 1),
	ITRM(1, 0, 1, 1, 0),
	ITRM(1, 0, 1, 0, 1),
	ITRM(1, 0, 1, 0, 0),
	ITRM(1, 0, 0, 1, 1),
	ITRM(1, 0, 0, 1, 0),
	ITRM(1, 0, 0, 0, 1),
	// ITRM(1, 0, 0, 0, 0), not supported: an address or a port is required
	ITRM(0, 1, 1, 1, 1),
	ITRM(0, 1, 1, 1, 0),
	ITRM(0, 1, 1, 0, 1),
	ITRM(0, 1, 1, 0, 0),
	ITRM(0, 1, 0, 1, 1),
	ITRM(0, 1, 0, 1, 0),
	ITRM(0, 1, 0, 0, 1),
	ITRM(0, 1, 0, 0, 0),
	ITRM(0, 0, 1, 1, 1),
	ITRM(0, 0, 1, 1, 0),
	ITRM(0, 0, 1, 0, 1),
	ITRM(0, 0, 1, 0, 0),
	ITRM(0, 0, 0, 1, 1),
	ITRM(0, 0, 0, 1, 0),
	ITRM(0, 0, 0, 0, 1),
	// ITRM(0, 0, 0, 0, 0), not supported: IP version alone is not a valid tuple
};
/* Number of supported masks, i.e. number of rule lists per interface. */
#define NINETRULEMASKS \
	(sizeof(nxctl_inet_traffic_rule_masks)/sizeof(uint8_t))
230
/*
 * Per-interface lists of traffic rules.
 * Each interface entry keeps one rule list per supported descriptor mask,
 * indexed like nxctl_inet_traffic_rule_masks[].
 */
SLIST_HEAD(nxctl_traffic_rule_inet_head, nxctl_traffic_rule_inet);
struct nxctl_traffic_rule_inet_if {
	char rii_ifname[IFNAMSIZ];                                      /* interface name */
	struct nxctl_traffic_rule_inet_head rii_lists[NINETRULEMASKS];  /* one list per mask */
	uint32_t rii_count;                                             /* rules on this interface */
	SLIST_ENTRY(nxctl_traffic_rule_inet_if) rii_link;               /* linkage on ris_if_list */
};

/* List of per-interface lists; this is the inet type's global storage. */
SLIST_HEAD(nxctl_traffic_rule_inet_if_head, nxctl_traffic_rule_inet_if);
struct nxctl_traffic_rule_inet_storage {
	struct nxctl_traffic_rule_inet_if_head ris_if_list;     /* all interface entries */
	uint32_t ris_count;                                     /* must be 0 when the storage is destroyed */
};
246
/*
 * Per-fd list kept at the nxctl.
 * Created when the first rule is added to an nxctl and destroyed when the
 * last one is removed or when the nxctl's rules are cleaned up.
 */
SLIST_HEAD(nxctl_traffic_rule_head, nxctl_traffic_rule);
struct nxctl_traffic_rule_storage {
	struct nxctl_traffic_rule_head rs_list; /* rules linked via ntr_storage_link */
	uint32_t rs_count;                      /* number of rules on rs_list */
};
253
/*
 * Reader/writer lock for the traffic rule code. The lookup entry points
 * nxctl_inet_traffic_rule_get_count() and
 * nxctl_inet_traffic_rule_find_qset_id() take it shared while reading the
 * inet rule storage.
 */
static LCK_RW_DECLARE_ATTR(nxctl_traffic_rule_lock, &sk_lock_group, &sk_lock_attr);
/* Exclusive (writer) acquire/release. */
#define NXTR_WLOCK() \
	lck_rw_lock_exclusive(&nxctl_traffic_rule_lock)
#define NXTR_WUNLOCK() \
	lck_rw_unlock_exclusive(&nxctl_traffic_rule_lock)
/* Shared (reader) acquire/release. */
#define NXTR_RLOCK() \
	lck_rw_lock_shared(&nxctl_traffic_rule_lock)
#define NXTR_RUNLOCK() \
	lck_rw_unlock_shared(&nxctl_traffic_rule_lock)
263
/* Prototypes of helpers that are used before their definitions. */
static struct nxctl_traffic_rule_type *find_traffic_rule_type(uint8_t type);
static void retain_traffic_rule(struct nxctl_traffic_rule *ntr);
static void release_traffic_rule(struct nxctl_traffic_rule *ntr);
static int remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
    struct nxctl_traffic_rule **ntrp);
static boolean_t inet_v6addr_cmp(struct ifnet_ip_addr *a1,
    struct ifnet_ip_addr *a2);
static int notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags);

/* Allocation tag used for all traffic rule allocations in this file. */
#define NXCTL_TRAFFIC_RULE_TAG "com.apple.skywalk.nexus.traffic_rule"
static kern_allocation_name_t nxctl_traffic_rule_tag;
/*
 * Cached pointer to the inet entry of nxctl_rule_types[]; set by
 * nxctl_traffic_rule_init() and cleared by nxctl_traffic_rule_fini().
 */
static struct nxctl_traffic_rule_type *inet_traffic_rule_type = NULL;
276
277 /*
278 * If an interface attaches after rule(s) are added, this function is used
279 * retrieve the current rule count for that interface.
280 */
281 int
nxctl_inet_traffic_rule_get_count(const char * ifname,uint32_t * count)282 nxctl_inet_traffic_rule_get_count(const char *ifname, uint32_t *count)
283 {
284 struct nxctl_traffic_rule_inet_storage *rs;
285 struct nxctl_traffic_rule_inet_if *rif;
286 int err;
287
288 NXTR_RLOCK();
289 rs = inet_traffic_rule_type->ntrt_storage;
290 if (rs == NULL) {
291 err = ENOENT;
292 goto fail;
293 }
294 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
295 if (strlcmp(rif->rii_ifname, ifname, sizeof(rif->rii_ifname)) == 0) {
296 break;
297 }
298 }
299 if (rif == NULL) {
300 err = ENOENT;
301 goto fail;
302 }
303 *count = rif->rii_count;
304 NXTR_RUNLOCK();
305 return 0;
306 fail:
307 NXTR_RUNLOCK();
308 return err;
309 }
310
311 /*
312 * Used for finding the qset id associated with a traffic descriptor.
313 */
314 int
nxctl_inet_traffic_rule_find_qset_id(const char * ifname,struct ifnet_traffic_descriptor_inet * td,uint64_t * qset_id)315 nxctl_inet_traffic_rule_find_qset_id(const char *ifname,
316 struct ifnet_traffic_descriptor_inet *td, uint64_t *qset_id)
317 {
318 struct nxctl_traffic_rule_inet *__single ntri = NULL;
319 struct nxctl_traffic_rule *__single ntr = NULL;
320 int err;
321
322 NXTR_RLOCK();
323 ASSERT(inet_traffic_rule_type != NULL);
324 err = inet_traffic_rule_type->ntrt_find(inet_traffic_rule_type, ifname,
325 (struct ifnet_traffic_descriptor_common *)td, 0, &ntr);
326 if (err != 0) {
327 SK_ERR("rule find failed: %d", err);
328 goto fail;
329 }
330 ntri = __container_of(ntr, struct nxctl_traffic_rule_inet, ntri_common);
331 *qset_id = ntri->ntri_ra.ras_qset_id;
332 NXTR_RUNLOCK();
333 return 0;
334 fail:
335 NXTR_RUNLOCK();
336 return err;
337 }
338
339 /*
340 * Based on flow_pkt_classify().
341 * This function populates struct ifnet_traffic_descriptor_inet instead of struct __flow.
342 */
343 static int
fill_inet_td(struct __kern_packet * pkt,struct ifnet_traffic_descriptor_inet * td)344 fill_inet_td(struct __kern_packet *pkt, struct ifnet_traffic_descriptor_inet *td)
345 {
346 union {
347 volatile struct ip *__indexable _iph;
348 volatile struct ip6_hdr *__indexable _ip6;
349 } _l3;
350 #define iph _l3._iph
351 #define ip6 _l3._ip6
352 union {
353 volatile struct tcphdr *_tcph;
354 volatile struct udphdr *_udph;
355 } _l4;
356 #define tcph _l4._tcph
357 #define udph _l4._udph
358 uint8_t *pkt_buf, *l3_hdr;
359 uint32_t bdlen, bdlim, bdoff, cls_len;
360 size_t pkt_len;
361 uint8_t ipv, l3hlen = 0; /* IP header length */
362 uint16_t l3tlen = 0; /* total length of IP packet */
363 uint8_t l4hlen = 0; /* TCP/UDP header length */
364 uint16_t ulen = 0; /* user data length */
365 int err;
366
367 ASSERT(pkt->pkt_l2_len <= pkt->pkt_length);
368 pkt_len = pkt->pkt_length - pkt->pkt_l2_len;
369
370 MD_BUFLET_ADDR_ABS_DLEN(pkt, pkt_buf, bdlen, bdlim, bdoff);
371 cls_len = bdlim - bdoff;
372 cls_len -= pkt->pkt_l2_len;
373 cls_len = (uint32_t)MIN(cls_len, pkt_len);
374 VERIFY(pkt_len >= cls_len);
375 if (cls_len == 0) {
376 SK_ERR("cls_len == 0");
377 err = EINVAL;
378 goto fail;
379 }
380 l3_hdr = pkt_buf + pkt->pkt_headroom + pkt->pkt_l2_len;
381 iph = (volatile struct ip *)(void *)l3_hdr;
382 ipv = iph->ip_v;
383
384 switch (ipv) {
385 case 4:
386 if (cls_len < sizeof(struct ip)) {
387 SK_ERR("cls_len < sizeof(struct ip) (%d < %d)",
388 cls_len, sizeof(struct ip));
389 err = EINVAL;
390 goto fail;
391 }
392 l3hlen = (uint8_t)(iph->ip_hl << 2);
393 if (l3hlen < sizeof(struct ip)) {
394 SK_ERR("l3hlen < sizeof(struct ip) (%d < %d)",
395 l3hlen, sizeof(struct ip));
396 err = EINVAL;
397 goto fail;
398 }
399 if (cls_len < l3hlen) {
400 SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
401 err = EINVAL;
402 goto fail;
403 }
404 l3tlen = ntohs(iph->ip_len);
405 if (l3tlen < l3hlen) {
406 SK_ERR("l3tlen < l3hlen (%d < %d)", l3tlen, l3hlen);
407 err = EINVAL;
408 goto fail;
409 }
410 if (pkt_len < l3tlen) {
411 SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
412 err = EINVAL;
413 goto fail;
414 }
415 td->inet_ipver = IPVERSION;
416 td->inet_proto = iph->ip_p;
417 bcopy(__DECONST(void *, &iph->ip_src), &td->inet_laddr.iia_v4addr,
418 sizeof(iph->ip_src));
419 bcopy(__DECONST(void *, &iph->ip_dst), &td->inet_raddr.iia_v4addr,
420 sizeof(iph->ip_dst));
421 break;
422 case 6:
423 l3hlen = sizeof(struct ip6_hdr);
424 if (cls_len < l3hlen) {
425 SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
426 err = EINVAL;
427 goto fail;
428 }
429 l3tlen = l3hlen + ntohs(ip6->ip6_plen);
430 if (pkt_len < l3tlen) {
431 SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
432 err = EINVAL;
433 goto fail;
434 }
435 td->inet_ipver = IPV6_VERSION;
436 td->inet_proto = ip6->ip6_nxt;
437 bcopy(__DECONST(void *, &ip6->ip6_src), &td->inet_laddr,
438 sizeof(ip6->ip6_src));
439 bcopy(__DECONST(void *, &ip6->ip6_dst), &td->inet_raddr,
440 sizeof(ip6->ip6_dst));
441 break;
442 default:
443 SK_ERR("ipv == %d", ipv);
444 err = EINVAL;
445 goto fail;
446 }
447 tcph = __DECONST(volatile struct tcphdr *, (volatile uint8_t *)iph + l3hlen);
448 ulen = (l3tlen - l3hlen);
449 if (td->inet_proto == IPPROTO_TCP) {
450 if (cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph)) {
451 SK_ERR("cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph) "
452 "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*tcph),
453 ulen, sizeof(*tcph));
454 err = EINVAL;
455 goto fail;
456 }
457 l4hlen = (uint8_t)(tcph->th_off << 2);
458 if (l4hlen < sizeof(*tcph)) {
459 SK_ERR("l4hlen < sizeof(*tcph) (%d < %d)", l4hlen, sizeof(*tcph));
460 err = EINVAL;
461 goto fail;
462 }
463 if (l4hlen > ulen) {
464 SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
465 err = EINVAL;
466 goto fail;
467 }
468 bcopy(__DECONST(void *, &tcph->th_sport), &td->inet_lport,
469 sizeof(td->inet_lport));
470 bcopy(__DECONST(void *, &tcph->th_dport), &td->inet_rport,
471 sizeof(td->inet_rport));
472 } else if (td->inet_proto == IPPROTO_UDP) {
473 if (cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph)) {
474 SK_ERR("cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph) "
475 "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*udph),
476 ulen, sizeof(*udph));
477 err = EINVAL;
478 goto fail;
479 }
480 l4hlen = sizeof(*udph);
481 if (l4hlen > ulen) {
482 SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
483 err = EINVAL;
484 goto fail;
485 }
486 bcopy(__DECONST(void *, &udph->uh_sport), &td->inet_lport,
487 sizeof(td->inet_lport));
488 bcopy(__DECONST(void *, &udph->uh_dport), &td->inet_rport,
489 sizeof(td->inet_rport));
490 } else {
491 err = ENOTSUP;
492 goto fail;
493 }
494
495 td->inet_common.itd_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET;
496 td->inet_common.itd_len = sizeof(*td);
497 td->inet_common.itd_flags = IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
498 IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND;
499 td->inet_mask |= (IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
500 IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
501 IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
502 IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
503 IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
504 IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT);
505 return 0;
506 fail:
507 DTRACE_SKYWALK5(classify__failed, struct ip *, iph, size_t, pkt_len,
508 uint8_t, pkt->pkt_l2_len, struct ifnet_traffic_descriptor_inet *, td,
509 int, err);
510 bzero(td, sizeof(*td));
511 return err;
512 #undef iph
513 #undef ip6
514 #undef tcph
515 #undef udph
516 }
517
518 int
nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char * ifname,struct __kern_packet * pkt,uint64_t * qset_id)519 nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char *ifname,
520 struct __kern_packet *pkt, uint64_t *qset_id)
521 {
522 struct ifnet_traffic_descriptor_inet td;
523 int err;
524
525 err = fill_inet_td(pkt, &td);
526 if (err != 0) {
527 return err;
528 }
529 return nxctl_inet_traffic_rule_find_qset_id(ifname, &td, qset_id);
530 }
531
532 void
nxctl_traffic_rule_init(void)533 nxctl_traffic_rule_init(void)
534 {
535 ASSERT(nxctl_traffic_rule_tag == NULL);
536 nxctl_traffic_rule_tag =
537 kern_allocation_name_allocate(NXCTL_TRAFFIC_RULE_TAG, 0);
538 ASSERT(nxctl_traffic_rule_tag != NULL);
539
540 ASSERT(inet_traffic_rule_type == NULL);
541 inet_traffic_rule_type =
542 find_traffic_rule_type(IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET);
543 ASSERT(inet_traffic_rule_type != NULL);
544 }
545
546 void
nxctl_traffic_rule_fini(void)547 nxctl_traffic_rule_fini(void)
548 {
549 if (nxctl_traffic_rule_tag != NULL) {
550 kern_allocation_name_release(nxctl_traffic_rule_tag);
551 nxctl_traffic_rule_tag = NULL;
552 }
553 inet_traffic_rule_type = NULL;
554 }
555
556 static struct ifnet_ip_addr v6_zeros_addr = {0};
557 static boolean_t
inet_v6addr_cmp(struct ifnet_ip_addr * a1,struct ifnet_ip_addr * a2)558 inet_v6addr_cmp(struct ifnet_ip_addr *a1, struct ifnet_ip_addr *a2)
559 {
560 return memcmp(a1, a2, sizeof(*a1)) == 0;
561 }
562
563 SK_NO_INLINE_ATTRIBUTE
564 static struct nxctl_traffic_rule_storage *
nxctl_traffic_rule_storage_create(void)565 nxctl_traffic_rule_storage_create(void)
566 {
567 struct nxctl_traffic_rule_storage *rs;
568
569 rs = sk_alloc_type(struct nxctl_traffic_rule_storage,
570 Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
571 SLIST_INIT(&rs->rs_list);
572 rs->rs_count = 0;
573 return rs;
574 }
575
576 SK_NO_INLINE_ATTRIBUTE
577 static void
nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage * rs)578 nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage *rs)
579 {
580 ASSERT(rs->rs_count == 0);
581 ASSERT(SLIST_EMPTY(&rs->rs_list));
582 sk_free_type(struct nxctl_traffic_rule_storage, rs);
583 }
584
585 /*
586 * This is meant to be called during closure of the nxctl's fd.
587 * This will cleanup all rules linked to this nxctl. Rules that
588 * are marked persistent won't be added to the nxctl list.
589 */
590 void
nxctl_traffic_rule_clean(struct nxctl * nxctl)591 nxctl_traffic_rule_clean(struct nxctl *nxctl)
592 {
593 struct nxctl_traffic_rule_storage *rs;
594 struct nxctl_traffic_rule *ntr, *next;
595 int err;
596
597 lck_mtx_lock(&nxctl->nxctl_lock);
598 if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
599 lck_mtx_unlock(&nxctl->nxctl_lock);
600 return;
601 }
602 ntr = SLIST_FIRST(&rs->rs_list);
603 SLIST_INIT(&rs->rs_list);
604 rs->rs_count = 0;
605 nxctl_traffic_rule_storage_destroy(rs);
606 nxctl->nxctl_traffic_rule_storage = NULL;
607 lck_mtx_unlock(&nxctl->nxctl_lock);
608
609 while (ntr != NULL) {
610 next = SLIST_NEXT(ntr, ntr_storage_link);
611 /*
612 * Clearing the flag to tell remove_traffic_rule() not to
613 * remove from the nxctl list again.
614 */
615 ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
616
617 /* Passing NULL because we already hold a reference */
618 err = remove_traffic_rule(nxctl, ntr->ntr_uuid, NULL);
619 if (err == 0) {
620 (void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
621 }
622 release_traffic_rule(ntr);
623 ntr = next;
624 }
625 }
626
627 SK_NO_INLINE_ATTRIBUTE
628 static void
add_traffic_rule_to_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)629 add_traffic_rule_to_nxctl(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
630 {
631 struct nxctl_traffic_rule_storage *rs;
632
633 lck_mtx_lock(&nxctl->nxctl_lock);
634 if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
635 rs = nxctl_traffic_rule_storage_create();
636 nxctl->nxctl_traffic_rule_storage = rs;
637 }
638 ntr->ntr_flags |= NTR_FLAG_ON_NXCTL_LIST;
639 retain_traffic_rule(ntr);
640 SLIST_INSERT_HEAD(&rs->rs_list, ntr, ntr_storage_link);
641 rs->rs_count++;
642 lck_mtx_unlock(&nxctl->nxctl_lock);
643 }
644
645 SK_NO_INLINE_ATTRIBUTE
646 static void
remove_traffic_rule_from_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)647 remove_traffic_rule_from_nxctl(struct nxctl *nxctl,
648 struct nxctl_traffic_rule *ntr)
649 {
650 struct nxctl_traffic_rule_storage *rs;
651
652 lck_mtx_lock(&nxctl->nxctl_lock);
653 if ((ntr->ntr_flags & NTR_FLAG_ON_NXCTL_LIST) == 0) {
654 lck_mtx_unlock(&nxctl->nxctl_lock);
655 return;
656 }
657 rs = nxctl->nxctl_traffic_rule_storage;
658 SLIST_REMOVE(&rs->rs_list, ntr, nxctl_traffic_rule, ntr_storage_link);
659 rs->rs_count--;
660 ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
661 release_traffic_rule(ntr);
662 if (rs->rs_count == 0) {
663 nxctl_traffic_rule_storage_destroy(rs);
664 nxctl->nxctl_traffic_rule_storage = NULL;
665 }
666 lck_mtx_unlock(&nxctl->nxctl_lock);
667 }
668
669 static int
inet_traffic_rule_validate(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra)670 inet_traffic_rule_validate(struct nxctl_traffic_rule_type *type,
671 const char *ifname,
672 struct ifnet_traffic_descriptor_common *td,
673 struct ifnet_traffic_rule_action *ra)
674 {
675 #pragma unused(type)
676 char buf[IFNAMSIZ];
677 int unit, i;
678 struct ifnet_traffic_descriptor_inet *tdi;
679 uint8_t mask = 0, ipver, proto;
680
681 if (ifunit_extract(ifname, buf, sizeof(buf), &unit) < 0) {
682 SK_ERR("invalid ifname: %s", ifname);
683 return EINVAL;
684 }
685 if (td->itd_len != sizeof(*tdi)) {
686 SK_ERR("invalid td len: expected %d, actual %d",
687 sizeof(*tdi), td->itd_len);
688 return EINVAL;
689 }
690 if (td->itd_flags == 0 ||
691 (td->itd_flags &
692 ~(IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
693 IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND)) != 0) {
694 SK_ERR("invalid td flags: 0x%x", td->itd_flags);
695 return EINVAL;
696 }
697 tdi = (struct ifnet_traffic_descriptor_inet *)td;
698 for (i = 0; i < NINETRULEMASKS; i++) {
699 if (tdi->inet_mask == nxctl_inet_traffic_rule_masks[i]) {
700 mask = tdi->inet_mask;
701 break;
702 }
703 }
704 if (mask == 0) {
705 SK_ERR("invalid inet mask: 0x%x", tdi->inet_mask);
706 return EINVAL;
707 }
708 ipver = tdi->inet_ipver;
709 if (ipver != IPVERSION && ipver != IPV6_VERSION) {
710 SK_ERR("invalid inet ipver: 0x%x", ipver);
711 return EINVAL;
712 }
713 proto = tdi->inet_proto;
714 if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
715 SK_ERR("invalid inet proto: %d", proto);
716 return EINVAL;
717 }
718 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0) {
719 if (ipver == IPVERSION) {
720 if (tdi->inet_laddr.iia_v4addr == INADDR_ANY) {
721 SK_ERR("inet laddr v4 cannot be unspecified");
722 return EINVAL;
723 }
724 } else {
725 if (inet_v6addr_cmp(&tdi->inet_laddr, &v6_zeros_addr)) {
726 SK_ERR("inet laddr v4 cannot be unspecified");
727 return EINVAL;
728 }
729 }
730 }
731 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0) {
732 if (ipver == IPVERSION) {
733 if (tdi->inet_raddr.iia_v4addr == INADDR_ANY) {
734 SK_ERR("inet raddr v6 cannot be unspecified");
735 return EINVAL;
736 }
737 } else {
738 if (inet_v6addr_cmp(&tdi->inet_raddr, &v6_zeros_addr)) {
739 SK_ERR("inet raddr v6 cannot be unspecified");
740 return EINVAL;
741 }
742 }
743 }
744 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0) {
745 if (tdi->inet_lport == 0) {
746 SK_ERR("inet lport cannot be unspecified");
747 return EINVAL;
748 }
749 }
750 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0) {
751 if (tdi->inet_rport == 0) {
752 SK_ERR("inet rport cannot be unspecified");
753 return EINVAL;
754 }
755 }
756 if (ra->ra_len != sizeof(struct ifnet_traffic_rule_action_steer)) {
757 SK_ERR("invalid ra len: expected %d, actual %d",
758 sizeof(struct ifnet_traffic_rule_action_steer), ra->ra_len);
759 return EINVAL;
760 }
761 return 0;
762 }
763
764 SK_NO_INLINE_ATTRIBUTE
765 static struct nxctl_traffic_rule_inet_storage *
inet_traffic_rule_storage_create(void)766 inet_traffic_rule_storage_create(void)
767 {
768 struct nxctl_traffic_rule_inet_storage *rs;
769
770 rs = sk_alloc_type(struct nxctl_traffic_rule_inet_storage,
771 Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
772 SLIST_INIT(&rs->ris_if_list);
773 rs->ris_count = 0;
774 return rs;
775 }
776
777 SK_NO_INLINE_ATTRIBUTE
778 static void
inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage * rs)779 inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage *rs)
780 {
781 ASSERT(rs->ris_count == 0);
782 ASSERT(SLIST_EMPTY(&rs->ris_if_list));
783 sk_free_type(struct nxctl_traffic_rule_inet_storage, rs);
784 }
785
786 SK_NO_INLINE_ATTRIBUTE
787 static struct nxctl_traffic_rule_inet_if *
inet_traffic_rule_if_create(const char * ifname)788 inet_traffic_rule_if_create(const char *ifname)
789 {
790 struct nxctl_traffic_rule_inet_if *rif;
791 int i;
792
793 rif = sk_alloc_type(struct nxctl_traffic_rule_inet_if,
794 Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
795 for (i = 0; i < NINETRULEMASKS; i++) {
796 SLIST_INIT(&rif->rii_lists[i]);
797 }
798 strlcpy(rif->rii_ifname, ifname, sizeof(rif->rii_ifname));
799 rif->rii_count = 0;
800 return rif;
801 }
802
803 SK_NO_INLINE_ATTRIBUTE
804 static void
inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if * rif)805 inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if *rif)
806 {
807 int i;
808
809 for (i = 0; i < NINETRULEMASKS; i++) {
810 ASSERT(SLIST_EMPTY(&rif->rii_lists[i]));
811 }
812 ASSERT(rif->rii_count == 0);
813 sk_free_type(struct nxctl_traffic_rule_inet_if, rif);
814 }
815
816 SK_NO_INLINE_ATTRIBUTE
817 static boolean_t
inet_traffic_rule_match(struct nxctl_traffic_rule_inet * ntri,const char * ifname,uint32_t flags,struct ifnet_traffic_descriptor_inet * tdi)818 inet_traffic_rule_match(struct nxctl_traffic_rule_inet *ntri, const char *ifname,
819 uint32_t flags, struct ifnet_traffic_descriptor_inet *tdi)
820 {
821 struct nxctl_traffic_rule *ntr = (struct nxctl_traffic_rule *)ntri;
822 struct ifnet_traffic_descriptor_inet *tdi0;
823 uint8_t mask;
824 boolean_t exact;
825
826 VERIFY(strlcmp(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname)) == 0);
827 tdi0 = &ntri->ntri_td;
828
829 exact = ((flags & NTR_FIND_FLAG_EXACT) != 0);
830 mask = tdi0->inet_mask & tdi->inet_mask;
831 if (exact) {
832 ASSERT(tdi0->inet_mask == tdi->inet_mask);
833 }
834 ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER) != 0);
835 if (tdi0->inet_ipver != tdi->inet_ipver) {
836 DTRACE_SKYWALK2(ipver__mismatch,
837 uint8_t, tdi0->inet_ipver, uint8_t, tdi->inet_ipver);
838 return FALSE;
839 }
840 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) != 0 &&
841 tdi0->inet_proto != tdi->inet_proto) {
842 DTRACE_SKYWALK2(proto__mismatch,
843 uint8_t, tdi0->inet_proto, uint8_t, tdi->inet_proto);
844 return FALSE;
845 }
846 if (tdi0->inet_ipver == IPVERSION) {
847 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
848 tdi0->inet_laddr.iia_v4addr != tdi->inet_laddr.iia_v4addr) {
849 DTRACE_SKYWALK2(v4laddr__mismatch,
850 in_addr_t, tdi0->inet_laddr.iia_v4addr,
851 in_addr_t, tdi->inet_laddr.iia_v4addr);
852 return FALSE;
853 }
854 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
855 tdi0->inet_raddr.iia_v4addr != tdi->inet_raddr.iia_v4addr) {
856 DTRACE_SKYWALK2(v4raddr__mismatch,
857 in_addr_t, tdi0->inet_raddr.iia_v4addr,
858 in_addr_t, tdi->inet_raddr.iia_v4addr);
859 return FALSE;
860 }
861 } else {
862 ASSERT(tdi0->inet_ipver == IPV6_VERSION);
863 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
864 !inet_v6addr_cmp(&tdi0->inet_laddr, &tdi->inet_laddr)) {
865 DTRACE_SKYWALK2(v6laddr__mismatch,
866 struct in6_addr *, &tdi0->inet_laddr,
867 struct in6_addr *, &tdi->inet_laddr);
868 return FALSE;
869 }
870 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
871 !inet_v6addr_cmp(&tdi0->inet_raddr, &tdi->inet_raddr)) {
872 DTRACE_SKYWALK2(v6raddr__mismatch,
873 struct in6_addr *, &tdi0->inet_raddr,
874 struct in6_addr *, &tdi->inet_raddr);
875 return FALSE;
876 }
877 }
878 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0 &&
879 tdi0->inet_lport != tdi->inet_lport) {
880 DTRACE_SKYWALK2(lport__mismatch,
881 uint8_t, tdi0->inet_lport, uint8_t, tdi->inet_lport);
882 return FALSE;
883 }
884 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0 &&
885 tdi0->inet_rport != tdi->inet_rport) {
886 DTRACE_SKYWALK2(rport__mismatch,
887 uint8_t, tdi0->inet_rport, uint8_t, tdi->inet_rport);
888 return FALSE;
889 }
890 return TRUE;
891 }
892
893 static int
inet_traffic_rule_find(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,uint32_t flags,struct nxctl_traffic_rule ** ntrp)894 inet_traffic_rule_find(struct nxctl_traffic_rule_type *type, const char *ifname,
895 struct ifnet_traffic_descriptor_common *td, uint32_t flags,
896 struct nxctl_traffic_rule **ntrp)
897 {
898 struct nxctl_traffic_rule_inet *ntri = NULL;
899 struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
900 struct nxctl_traffic_rule_inet_if *rif;
901 struct ifnet_traffic_descriptor_inet *tdi =
902 (struct ifnet_traffic_descriptor_inet *)td;
903 int i;
904
905 if (rs == NULL) {
906 return ENOENT;
907 }
908 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
909 if (strlcmp(rif->rii_ifname, ifname, sizeof(rif->rii_ifname)) != 0) {
910 continue;
911 }
912 for (i = 0; i < NINETRULEMASKS; i++) {
913 if ((flags & NTR_FIND_FLAG_EXACT) != 0 &&
914 tdi->inet_mask != nxctl_inet_traffic_rule_masks[i]) {
915 continue;
916 }
917 SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
918 if (inet_traffic_rule_match(ntri, ifname, flags, tdi)) {
919 *ntrp = (struct nxctl_traffic_rule *)ntri;
920 return 0;
921 }
922 }
923 }
924 }
925 return ENOENT;
926 }
927
928 static int
inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type * type,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)929 inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type *type,
930 uuid_t uuid, struct nxctl_traffic_rule **ntrp)
931 {
932 struct nxctl_traffic_rule_inet *ntri;
933 struct nxctl_traffic_rule *ntr;
934 struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
935 struct nxctl_traffic_rule_inet_if *rif;
936 int i;
937
938 if (rs == NULL) {
939 return ENOENT;
940 }
941 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
942 for (i = 0; i < NINETRULEMASKS; i++) {
943 SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
944 ntr = &ntri->ntri_common;
945 if (uuid_compare(ntr->ntr_uuid, uuid) == 0) {
946 *ntrp = ntr;
947 return 0;
948 }
949 }
950 }
951 }
952 return ENOENT;
953 }
954
955 static void
inet_update_ifnet_traffic_rule_count(const char * ifname,uint32_t count)956 inet_update_ifnet_traffic_rule_count(const char *ifname, uint32_t count)
957 {
958 struct ifnet *ifp;
959
960 ifp = ifunit_ref(ifname);
961 if (ifp == NULL) {
962 DTRACE_SKYWALK1(ifname__not__found, char *, ifname);
963 return;
964 }
965 ifnet_update_traffic_rule_count(ifp, count);
966 ifnet_decr_iorefcnt(ifp);
967 }
968
969 static void
inet_traffic_rule_link(struct nxctl_traffic_rule * ntr)970 inet_traffic_rule_link(struct nxctl_traffic_rule *ntr)
971 {
972 struct nxctl_traffic_rule_type *type = ntr->ntr_type;
973 struct nxctl_traffic_rule_inet_storage *rs;
974 struct nxctl_traffic_rule_inet_if *rif;
975 struct nxctl_traffic_rule_inet *ntri =
976 (struct nxctl_traffic_rule_inet *)ntr;
977 struct nxctl_traffic_rule_inet_head *list = NULL;
978 int i;
979 char *__null_terminated ntr_ifname = NULL;
980 char *__null_terminated rii_ifname = NULL;
981
982 if ((rs = type->ntrt_storage) == NULL) {
983 rs = inet_traffic_rule_storage_create();
984 type->ntrt_storage = rs;
985 }
986 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
987 if (strbufcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
988 break;
989 }
990 }
991 if (rif == NULL) {
992 ntr_ifname = __unsafe_null_terminated_from_indexable(ntr->ntr_ifname);
993 rif = inet_traffic_rule_if_create(ntr_ifname);
994 SLIST_INSERT_HEAD(&rs->ris_if_list, rif, rii_link);
995 }
996 for (i = 0; i < NINETRULEMASKS; i++) {
997 if (ntri->ntri_td.inet_mask ==
998 nxctl_inet_traffic_rule_masks[i]) {
999 list = &rif->rii_lists[i];
1000 break;
1001 }
1002 }
1003 retain_traffic_rule(ntr);
1004 ASSERT(list != NULL);
1005 SLIST_INSERT_HEAD(list, ntri, ntri_storage_link);
1006 /* per-interface count */
1007 rif->rii_count++;
1008 rii_ifname = __unsafe_null_terminated_from_indexable(rif->rii_ifname);
1009 inet_update_ifnet_traffic_rule_count(rii_ifname, rif->rii_count);
1010
1011 /* global count */
1012 rs->ris_count++;
1013 }
1014
1015 static void
inet_traffic_rule_unlink(struct nxctl_traffic_rule * ntr)1016 inet_traffic_rule_unlink(struct nxctl_traffic_rule *ntr)
1017 {
1018 struct nxctl_traffic_rule_inet_storage *rs;
1019 struct nxctl_traffic_rule_inet_if *rif;
1020 struct nxctl_traffic_rule_inet *ntri =
1021 (struct nxctl_traffic_rule_inet *)ntr;
1022 struct nxctl_traffic_rule_inet_head *list = NULL;
1023 struct nxctl_traffic_rule_type *type;
1024 int i;
1025 char *__null_terminated rii_ifname = NULL;
1026
1027 type = ntr->ntr_type;
1028 rs = type->ntrt_storage;
1029 ASSERT(rs != NULL);
1030 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1031 if (strbufcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
1032 break;
1033 }
1034 }
1035 ASSERT(rif != NULL);
1036 for (i = 0; i < NINETRULEMASKS; i++) {
1037 if (ntri->ntri_td.inet_mask ==
1038 nxctl_inet_traffic_rule_masks[i]) {
1039 list = &rif->rii_lists[i];
1040 break;
1041 }
1042 }
1043 ASSERT(list != NULL);
1044 SLIST_REMOVE(list, ntri, nxctl_traffic_rule_inet, ntri_storage_link);
1045 rif->rii_count--;
1046 rii_ifname = __unsafe_null_terminated_from_indexable(rif->rii_ifname);
1047 inet_update_ifnet_traffic_rule_count(rii_ifname, rif->rii_count);
1048
1049 rs->ris_count--;
1050 release_traffic_rule(ntr);
1051
1052 if (rif->rii_count == 0) {
1053 SLIST_REMOVE(&rs->ris_if_list, rif, nxctl_traffic_rule_inet_if, rii_link);
1054 inet_traffic_rule_if_destroy(rif);
1055 }
1056 if (rs->ris_count == 0) {
1057 type->ntrt_storage = NULL;
1058 inet_traffic_rule_storage_destroy(rs);
1059 }
1060 }
1061
1062 /*
1063 * XXX
1064 * This may need additional changes to ensure safety against detach/attach.
1065 * This is not an issue for the first consumer of llink interfaces, cellular,
1066 * which does not detach.
1067 */
1068 static int
inet_traffic_rule_notify(struct nxctl_traffic_rule * ntr,uint32_t flags)1069 inet_traffic_rule_notify(struct nxctl_traffic_rule *ntr, uint32_t flags)
1070 {
1071 struct ifnet *ifp;
1072 struct nx_netif *nif;
1073 struct netif_qset *__single qset = NULL;
1074 struct nxctl_traffic_rule_inet *ntri;
1075 int err = 0;
1076 char *__null_terminated ntr_ifname = NULL;
1077
1078 ntr_ifname = __unsafe_null_terminated_from_indexable(ntr->ntr_ifname);
1079 ifp = ifunit_ref(ntr_ifname);
1080 if (ifp == NULL) {
1081 DTRACE_SKYWALK1(ifname__not__found, char *, ntr->ntr_ifname);
1082 err = ENXIO;
1083 goto done;
1084 }
1085 nif = NA(ifp)->nifna_netif;
1086 if (!NX_LLINK_PROV(nif->nif_nx)) {
1087 DTRACE_SKYWALK1(llink__not__enabled, struct ifnet *, ifp);
1088 err = ENOTSUP;
1089 goto done;
1090 }
1091 ntri = (struct nxctl_traffic_rule_inet *)ntr;
1092 qset = nx_netif_find_qset(nif, ntri->ntri_ra.ras_qset_id);
1093 err = nx_netif_notify_steering_info(nif, qset,
1094 (struct ifnet_traffic_descriptor_common *)&ntri->ntri_td,
1095 ((flags & NTR_NOTIFY_FLAG_ADD) != 0));
1096 done:
1097 if (qset != NULL) {
1098 nx_netif_qset_release(&qset);
1099 }
1100 if (ifp != NULL) {
1101 ifnet_decr_iorefcnt(ifp);
1102 }
1103 return err;
1104 }
1105
1106 static int
inet_traffic_rule_create(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1107 inet_traffic_rule_create(struct nxctl_traffic_rule_type *type,
1108 const char *ifname, struct ifnet_traffic_descriptor_common *td,
1109 struct ifnet_traffic_rule_action *ra, uint32_t flags,
1110 struct nxctl_traffic_rule **ntrp)
1111 {
1112 struct nxctl_traffic_rule_inet *ntri;
1113 struct nxctl_traffic_rule *ntr;
1114 struct ifnet_traffic_descriptor_inet *tdi;
1115 struct ifnet_traffic_rule_action_steer *ras;
1116
1117 ntri = sk_alloc_type(struct nxctl_traffic_rule_inet,
1118 Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
1119 ntr = &ntri->ntri_common;
1120
1121 ntr->ntr_type = type;
1122 ntr->ntr_flags = flags;
1123 uuid_generate(ntr->ntr_uuid);
1124 os_ref_init(&ntr->ntr_refcnt, NULL);
1125
1126 strlcpy(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname));
1127 proc_selfname(ntr->ntr_procname, sizeof(ntr->ntr_procname));
1128
1129 tdi = __container_of(td, struct ifnet_traffic_descriptor_inet, inet_common);
1130 ras = __container_of(ra, struct ifnet_traffic_rule_action_steer, ras_common);
1131 bcopy(tdi, &ntri->ntri_td, sizeof(ntri->ntri_td));
1132 bcopy(ras, &ntri->ntri_ra, sizeof(ntri->ntri_ra));
1133
1134 *ntrp = ntr;
1135 return 0;
1136 }
1137
1138 static void
inet_traffic_rule_destroy(struct nxctl_traffic_rule * ntr)1139 inet_traffic_rule_destroy(struct nxctl_traffic_rule *ntr)
1140 {
1141 struct nxctl_traffic_rule_inet *ntri;
1142
1143 ASSERT(os_ref_get_count(&ntr->ntr_refcnt) == 0);
1144 ntri = (struct nxctl_traffic_rule_inet *)ntr;
1145 sk_free_type(struct nxctl_traffic_rule_inet, ntri);
1146 }
1147
1148 static void
convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet * ntri,struct nxctl_traffic_rule_inet_iocinfo * info)1149 convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet *ntri,
1150 struct nxctl_traffic_rule_inet_iocinfo *info)
1151 {
1152 struct nxctl_traffic_rule *ntr;
1153 struct nxctl_traffic_rule_generic_iocinfo *ginfo;
1154
1155 bzero(info, sizeof(*info));
1156 ntr = &ntri->ntri_common;
1157 ginfo = &info->tri_common;
1158 _CASSERT(sizeof(ntr->ntr_procname) == sizeof(ginfo->trg_procname));
1159 _CASSERT(sizeof(ntr->ntr_ifname) == sizeof(ginfo->trg_ifname));
1160 uuid_copy(ginfo->trg_uuid, ntr->ntr_uuid);
1161 strbufcpy(ginfo->trg_procname, ntr->ntr_procname);
1162 strbufcpy(ginfo->trg_ifname, ntr->ntr_ifname);
1163 bcopy(&ntri->ntri_td, &info->tri_td, sizeof(info->tri_td));
1164 bcopy(&ntri->ntri_ra, &info->tri_ra, sizeof(info->tri_ra));
1165 }
1166
1167 static int
inet_traffic_rule_get_all(struct nxctl_traffic_rule_type * type,uint32_t size,uint32_t * count,user_addr_t uaddr)1168 inet_traffic_rule_get_all(struct nxctl_traffic_rule_type *type, uint32_t size,
1169 uint32_t *count, user_addr_t uaddr)
1170 {
1171 struct nxctl_traffic_rule_inet *ntri = NULL;
1172 struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
1173 struct nxctl_traffic_rule_inet_if *rif;
1174 struct nxctl_traffic_rule_inet_iocinfo info;
1175 int i, err;
1176
1177 if (size != sizeof(info)) {
1178 SK_ERR("size: actual %d, expected %d", size, sizeof(info));
1179 return EINVAL;
1180 }
1181 if (rs == NULL) {
1182 *count = 0;
1183 return 0;
1184 }
1185 if (*count < rs->ris_count) {
1186 SK_ERR("count: given %d, require: %d", *count, rs->ris_count);
1187 return ENOBUFS;
1188 }
1189 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1190 for (i = 0; i < NINETRULEMASKS; i++) {
1191 SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
1192 convert_ntri_to_iocinfo(ntri, &info);
1193 err = copyout(&info, uaddr, sizeof(info));
1194 if (err != 0) {
1195 SK_ERR("copyout failed: %d", err);
1196 return err;
1197 }
1198 uaddr += sizeof(info);
1199 }
1200 }
1201 }
1202 *count = rs->ris_count;
1203 return 0;
1204 }
1205
1206 SK_NO_INLINE_ATTRIBUTE
1207 static void
retain_traffic_rule(struct nxctl_traffic_rule * ntr)1208 retain_traffic_rule(struct nxctl_traffic_rule *ntr)
1209 {
1210 #if (DEVELOPMENT || DEBUG)
1211 os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1212 DTRACE_SKYWALK2(ntr__retain, struct nxctl_traffic_rule *, ntr,
1213 os_ref_count_t, count);
1214 #endif
1215 os_ref_retain(&ntr->ntr_refcnt);
1216 }
1217
1218 SK_NO_INLINE_ATTRIBUTE
1219 static void
release_traffic_rule(struct nxctl_traffic_rule * ntr)1220 release_traffic_rule(struct nxctl_traffic_rule *ntr)
1221 {
1222 #if (DEVELOPMENT || DEBUG)
1223 os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1224 DTRACE_SKYWALK2(ntr__release, struct nxctl_traffic_rule *, ntr,
1225 os_ref_count_t, count);
1226 #endif
1227 if (os_ref_release(&ntr->ntr_refcnt) == 0) {
1228 ntr->ntr_type->ntrt_destroy(ntr);
1229 }
1230 }
1231
1232 SK_NO_INLINE_ATTRIBUTE
1233 static int
notify_traffic_rule(struct nxctl_traffic_rule * ntr,uint32_t flags)1234 notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags)
1235 {
1236 return ntr->ntr_type->ntrt_notify(ntr, flags);
1237 }
1238
1239 static void
link_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1240 link_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1241 {
1242 /*
1243 * The persist flag means: do not clean up rule upon nxctl fd close.
1244 * This means we only add the rule to the nxctl list if persist
1245 * is not set.
1246 */
1247 if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1248 add_traffic_rule_to_nxctl(nxctl, ntr);
1249 }
1250 ntr->ntr_type->ntrt_link(ntr);
1251 }
1252
1253 static void
unlink_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1254 unlink_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1255 {
1256 if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1257 remove_traffic_rule_from_nxctl(nxctl, ntr);
1258 }
1259 ntr->ntr_type->ntrt_unlink(ntr);
1260 }
1261
1262 static int
find_traffic_rule_by_uuid(uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1263 find_traffic_rule_by_uuid(uuid_t uuid, struct nxctl_traffic_rule **ntrp)
1264 {
1265 int i, err;
1266 struct nxctl_traffic_rule_type *ntrt;
1267 struct nxctl_traffic_rule *__single ntr = NULL;
1268
1269 for (i = 0; i < NRULETYPES; i++) {
1270 ntrt = &nxctl_rule_types[i];
1271 err = ntrt->ntrt_find_by_uuid(ntrt, uuid, &ntr);
1272 if (err == 0) {
1273 ASSERT(ntr != NULL);
1274 *ntrp = ntr;
1275 return 0;
1276 }
1277 }
1278 return ENOENT;
1279 }
1280
1281 static struct nxctl_traffic_rule_type *
find_traffic_rule_type(uint8_t type)1282 find_traffic_rule_type(uint8_t type)
1283 {
1284 int i;
1285 struct nxctl_traffic_rule_type *ntrt;
1286
1287 for (i = 0; i < NRULETYPES; i++) {
1288 ntrt = &nxctl_rule_types[i];
1289 if (ntrt->ntrt_type == type) {
1290 return ntrt;
1291 }
1292 }
1293 return NULL;
1294 }
1295
1296 SK_NO_INLINE_ATTRIBUTE
1297 static int
add_traffic_rule(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1298 add_traffic_rule(struct nxctl *nxctl, const char *ifname,
1299 struct ifnet_traffic_descriptor_common *td,
1300 struct ifnet_traffic_rule_action *ra,
1301 uint32_t flags,
1302 struct nxctl_traffic_rule **ntrp)
1303 {
1304 struct nxctl_traffic_rule_type *type = NULL;
1305 struct nxctl_traffic_rule *__single ntr = NULL;
1306 int err;
1307
1308 NXTR_WLOCK();
1309 type = find_traffic_rule_type(td->itd_type);
1310 if (type == NULL) {
1311 SK_ERR("rule type %x not found", td->itd_type);
1312 err = EINVAL;
1313 goto fail;
1314 }
1315 err = type->ntrt_validate(type, ifname, td, ra);
1316 if (err != 0) {
1317 SK_ERR("rule validate failed: %d", err);
1318 goto fail;
1319 }
1320 err = type->ntrt_find(type, ifname, td, NTR_FIND_FLAG_EXACT, &ntr);
1321 if (err == 0) {
1322 SK_ERR("rule already exists");
1323 ASSERT(ntr != NULL);
1324 err = EEXIST;
1325 goto fail;
1326 } else if (err != ENOENT) {
1327 SK_ERR("rule find failed: %d", err);
1328 goto fail;
1329 }
1330 err = type->ntrt_create(type, ifname, td, ra, flags, &ntr);
1331 if (err != 0) {
1332 SK_ERR("rule create failed: %d", err);
1333 goto fail;
1334 }
1335 link_traffic_rule(nxctl, ntr);
1336 if (ntrp != NULL) {
1337 retain_traffic_rule(ntr);
1338 *ntrp = ntr;
1339 }
1340 NXTR_WUNLOCK();
1341 return 0;
1342 fail:
1343 NXTR_WUNLOCK();
1344 return err;
1345 }
1346
1347
1348 SK_NO_INLINE_ATTRIBUTE
1349 static int
remove_traffic_rule(struct nxctl * nxctl,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1350 remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
1351 struct nxctl_traffic_rule **ntrp)
1352 {
1353 struct nxctl_traffic_rule *__single ntr;
1354 int err;
1355
1356 NXTR_WLOCK();
1357 err = find_traffic_rule_by_uuid(uuid, &ntr);
1358 if (err != 0) {
1359 SK_ERR("traffic rule not found");
1360 NXTR_WUNLOCK();
1361 return err;
1362 }
1363 if (ntrp != NULL) {
1364 retain_traffic_rule(ntr);
1365 *ntrp = ntr;
1366 }
1367 unlink_traffic_rule(nxctl, ntr);
1368 /* release initial reference */
1369 release_traffic_rule(ntr);
1370 NXTR_WUNLOCK();
1371 return 0;
1372 }
1373
1374 static uint32_t
convert_traffic_rule_ioc_flags(uint32_t flags)1375 convert_traffic_rule_ioc_flags(uint32_t flags)
1376 {
1377 uint32_t f = 0;
1378
1379 if ((flags & NXIOC_ADD_TRAFFIC_RULE_FLAG_PERSIST) != 0) {
1380 f |= NTR_FLAG_PERSIST;
1381 }
1382 return f;
1383 }
1384
1385 SK_NO_INLINE_ATTRIBUTE
1386 static int
add_traffic_rule_generic(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,uuid_t * uuid)1387 add_traffic_rule_generic(struct nxctl *nxctl, const char *ifname,
1388 struct ifnet_traffic_descriptor_common *td,
1389 struct ifnet_traffic_rule_action *ra, uint32_t flags, uuid_t *uuid)
1390 {
1391 struct nxctl_traffic_rule *__single ntr;
1392 int err;
1393
1394 err = add_traffic_rule(nxctl, ifname, td, ra, flags, &ntr);
1395 if (err != 0) {
1396 return err;
1397 }
1398 (void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_ADD);
1399 uuid_copy(*uuid, ntr->ntr_uuid);
1400 release_traffic_rule(ntr);
1401 return 0;
1402 }
1403
1404 int
nxioctl_add_traffic_rule_inet(struct nxctl * nxctl,caddr_t data,proc_t procp)1405 nxioctl_add_traffic_rule_inet(struct nxctl *nxctl, caddr_t data, proc_t procp)
1406 {
1407 #pragma unused(procp)
1408 struct nxctl_add_traffic_rule_inet_iocargs *args =
1409 (struct nxctl_add_traffic_rule_inet_iocargs *)(void *)data;
1410 char *__null_terminated atri_ifname = NULL;
1411
1412 atri_ifname = __unsafe_null_terminated_from_indexable(args->atri_ifname);
1413
1414 return add_traffic_rule_generic(nxctl, atri_ifname,
1415 (struct ifnet_traffic_descriptor_common *)&args->atri_td,
1416 (struct ifnet_traffic_rule_action *)&args->atri_ra,
1417 convert_traffic_rule_ioc_flags(args->atri_flags),
1418 &args->atri_uuid);
1419 }
1420
1421 int
nxioctl_remove_traffic_rule(struct nxctl * nxctl,caddr_t data,proc_t procp)1422 nxioctl_remove_traffic_rule(struct nxctl *nxctl, caddr_t data, proc_t procp)
1423 {
1424 #pragma unused(procp)
1425 struct nxctl_remove_traffic_rule_iocargs *args =
1426 (struct nxctl_remove_traffic_rule_iocargs *)(void *)data;
1427 struct nxctl_traffic_rule *__single ntr;
1428 int err;
1429
1430 err = remove_traffic_rule(nxctl, args->rtr_uuid, &ntr);
1431 if (err != 0) {
1432 return err;
1433 }
1434 (void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
1435 release_traffic_rule(ntr);
1436 return 0;
1437 }
1438
1439 int
nxioctl_get_traffic_rules(struct nxctl * nxctl,caddr_t data,proc_t procp)1440 nxioctl_get_traffic_rules(struct nxctl *nxctl, caddr_t data, proc_t procp)
1441 {
1442 #pragma unused(nxctl)
1443 struct nxctl_get_traffic_rules_iocargs *args =
1444 (struct nxctl_get_traffic_rules_iocargs *)(void *)data;
1445 struct nxctl_traffic_rule_type *type;
1446 user_addr_t uaddr;
1447 int err;
1448
1449 NXTR_RLOCK();
1450 type = find_traffic_rule_type(args->gtr_type);
1451 if (type == NULL) {
1452 SK_ERR("rule type %x not found", args->gtr_type);
1453 err = EINVAL;
1454 goto fail;
1455 }
1456 uaddr = proc_is64bit(procp) ? args->gtr_buf64 :
1457 CAST_USER_ADDR_T(args->gtr_buf);
1458 err = type->ntrt_get_all(type, args->gtr_size, &args->gtr_count, uaddr);
1459 if (err != 0) {
1460 goto fail;
1461 }
1462 NXTR_RUNLOCK();
1463 return 0;
1464 fail:
1465 NXTR_RUNLOCK();
1466 return err;
1467 }
1468