1 /*
2 * Copyright (c) 2022 Apple Inc. All rights reserved.
3 *
4 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. The rights granted to you under the License
10 * may not be used to create, or enable the creation or redistribution of,
11 * unlawful or unlicensed copies of an Apple operating system, or to
12 * circumvent, violate, or enable the circumvention or violation of, any
13 * terms of an Apple operating system software license agreement.
14 *
15 * Please obtain a copy of the License at
16 * http://www.opensource.apple.com/apsl/ and read it before using this file.
17 *
18 * The Original Code and all software distributed under the License are
19 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23 * Please see the License for the specific language governing rights and
24 * limitations under the License.
25 *
26 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27 */
28 #include <skywalk/os_skywalk_private.h>
29 #include <skywalk/nexus/netif/nx_netif.h>
30 #include <netinet/ip.h>
31 #include <netinet/ip6.h>
32 #include <netinet/tcp.h>
33 #include <netinet/udp.h>
34 #include <sys/sdt.h>
35
36 /*
37 * Implementation of nexus traffic rules APIs.
38 */
39
/* Forward declarations; both structures are defined further below. */
struct nxctl_traffic_rule_type;
struct nxctl_traffic_rule;

/*
 * These callbacks need to be implemented for each rule type.
 */

/*
 * Validate user provided parameters (interface name, traffic descriptor
 * and rule action). Returns 0 when the parameters are acceptable and an
 * errno value otherwise.
 */
typedef int (nxctl_traffic_rule_validate_cb_t)(
	struct nxctl_traffic_rule_type *type,
	const char *ifname,
	struct ifnet_traffic_descriptor_common *td,
	struct ifnet_traffic_rule_action *ra);
/*
 * Each rule type has its own global structure for storing rules.
 * These callbacks access this global structure.
 *
 * find: look up a rule on `ifname' that matches descriptor `td'. On
 * success the rule is stored in `*ntrp' and 0 is returned. With
 * NTR_FIND_FLAG_EXACT the inet implementation only considers rules whose
 * field mask is identical to the descriptor's mask; without it, only the
 * fields present in both masks are compared.
 */
#define NTR_FIND_FLAG_EXACT 0x0001
typedef int (nxctl_traffic_rule_find_cb_t)(
	struct nxctl_traffic_rule_type *type,
	const char *ifname,
	struct ifnet_traffic_descriptor_common *td,
	uint32_t flags,
	struct nxctl_traffic_rule **ntrp);

/* find_by_uuid: look up a rule by its UUID; result is stored in `*ntrp'. */
typedef int (nxctl_traffic_rule_find_by_uuid_cb_t)(
	struct nxctl_traffic_rule_type *type,
	uuid_t uuid,
	struct nxctl_traffic_rule **ntrp);

/* link: insert a rule into the type's storage. */
typedef void (nxctl_traffic_rule_link_cb_t)(
	struct nxctl_traffic_rule *ntr);

/* unlink: remove a rule from the type's storage. */
typedef void (nxctl_traffic_rule_unlink_cb_t)(
	struct nxctl_traffic_rule *ntr);

/*
 * Notifies lower layers of the addition/removal of a rule.
 * This is called outside of nxctl_traffic_rule_lock to avoid potential
 * locking issues.
 */
#define NTR_NOTIFY_FLAG_ADD 0x0001
#define NTR_NOTIFY_FLAG_REMOVE 0x0002
typedef int (nxctl_traffic_rule_notify_cb_t)(
	struct nxctl_traffic_rule *ntr,
	uint32_t flags);

/*
 * Create/Destroy callbacks for a rule type.
 * create: the new rule is handed back through `*ntrp'.
 */
typedef int (nxctl_traffic_rule_create_cb_t)(
	struct nxctl_traffic_rule_type *type,
	const char *ifname,
	struct ifnet_traffic_descriptor_common *td,
	struct ifnet_traffic_rule_action *ra,
	uint32_t flags,
	struct nxctl_traffic_rule **ntrp);

typedef void (nxctl_traffic_rule_destroy_cb_t)(
	struct nxctl_traffic_rule *ntr);

/*
 * This is used for copying all rules for a type (including generic
 * and type-specific info) to userspace.
 */
typedef int (nxctl_traffic_rule_get_all_cb_t)(
	struct nxctl_traffic_rule_type *type,
	uint32_t size,
	uint32_t *count,
	user_addr_t uaddr);
110
/*
 * Descriptor of one traffic rule type: the traffic descriptor type it
 * handles, its callback table, and a pointer to its private rule storage.
 */
struct nxctl_traffic_rule_type {
	/* traffic descriptor type, e.g. IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET */
	uint8_t ntrt_type;
	nxctl_traffic_rule_validate_cb_t *ntrt_validate;
	nxctl_traffic_rule_find_cb_t *ntrt_find;
	nxctl_traffic_rule_find_by_uuid_cb_t *ntrt_find_by_uuid;
	nxctl_traffic_rule_link_cb_t *ntrt_link;
	nxctl_traffic_rule_unlink_cb_t *ntrt_unlink;
	nxctl_traffic_rule_notify_cb_t *ntrt_notify;
	nxctl_traffic_rule_create_cb_t *ntrt_create;
	nxctl_traffic_rule_destroy_cb_t *ntrt_destroy;
	nxctl_traffic_rule_get_all_cb_t *ntrt_get_all;
	/*
	 * Type-private storage. For the inet type this is created on the
	 * first link and reset to NULL once the last rule is unlinked.
	 */
	void *ntrt_storage;
};
124
/* Forward declarations of the inet rule type's callback implementations. */
static nxctl_traffic_rule_validate_cb_t inet_traffic_rule_validate;
static nxctl_traffic_rule_find_cb_t inet_traffic_rule_find;
static nxctl_traffic_rule_find_by_uuid_cb_t inet_traffic_rule_find_by_uuid;
static nxctl_traffic_rule_link_cb_t inet_traffic_rule_link;
static nxctl_traffic_rule_unlink_cb_t inet_traffic_rule_unlink;
static nxctl_traffic_rule_notify_cb_t inet_traffic_rule_notify;
static nxctl_traffic_rule_create_cb_t inet_traffic_rule_create;
static nxctl_traffic_rule_destroy_cb_t inet_traffic_rule_destroy;
static nxctl_traffic_rule_get_all_cb_t inet_traffic_rule_get_all;

/*
 * Table of supported rule types. ntrt_storage is left unset here and is
 * therefore NULL until a rule of that type is linked.
 */
static struct nxctl_traffic_rule_type nxctl_rule_types[] = {
	{
		.ntrt_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET,
		.ntrt_validate = inet_traffic_rule_validate,
		.ntrt_find = inet_traffic_rule_find,
		.ntrt_find_by_uuid = inet_traffic_rule_find_by_uuid,
		.ntrt_link = inet_traffic_rule_link,
		.ntrt_unlink = inet_traffic_rule_unlink,
		.ntrt_notify = inet_traffic_rule_notify,
		.ntrt_create = inet_traffic_rule_create,
		.ntrt_destroy = inet_traffic_rule_destroy,
		.ntrt_get_all = inet_traffic_rule_get_all,
	},
};
/* Number of entries in nxctl_rule_types[]. */
#define NRULETYPES \
	(sizeof(nxctl_rule_types)/sizeof(struct nxctl_traffic_rule_type))
151
/*
 * Generic traffic rule.
 * Contains fields common to all traffic rules.
 *
 * NTR_FLAG_ON_NXCTL_LIST is set while the rule sits on the per-fd list
 * kept at the nxctl (see add/remove_traffic_rule_{to,from}_nxctl()).
 * Rules marked NTR_FLAG_PERSIST are not added to that list.
 */
#define NTR_FLAG_PERSIST 0x0001
#define NTR_FLAG_ON_NXCTL_LIST 0x0002
struct nxctl_traffic_rule {
	struct nxctl_traffic_rule_type *ntr_type;       /* owning rule type */
	uint32_t ntr_flags;                             /* NTR_FLAG_* */
	os_refcnt_t ntr_refcnt;
	uuid_t ntr_uuid;                                /* key for find_by_uuid */
	char ntr_procname[NTR_PROCNAME_SZ];
	char ntr_ifname[IFNAMSIZ];                      /* interface the rule applies to */
	/* linkage on the nxctl's per-fd list (rs_list) */
	SLIST_ENTRY(nxctl_traffic_rule) ntr_storage_link;
};
167
/*
 * Inet-specific traffic rule.
 *
 * ntri_common is the first member; code in this file casts between
 * struct nxctl_traffic_rule * and struct nxctl_traffic_rule_inet *.
 */
struct nxctl_traffic_rule_inet {
	struct nxctl_traffic_rule ntri_common;
	/* linkage on one of the per-interface rii_lists[] */
	SLIST_ENTRY(nxctl_traffic_rule_inet) ntri_storage_link;
	struct ifnet_traffic_descriptor_inet ntri_td;   /* match criteria */
	struct ifnet_traffic_rule_action_steer ntri_ra; /* steering action (qset id) */
};
177
/*
 * Currently supported tuple types.
 *
 * Each entry is one accepted combination of descriptor fields. Validation
 * requires a rule's inet_mask to equal one entry exactly, and every
 * interface keeps one rule list per entry, using the same index.
 */
static uint8_t nxctl_inet_traffic_rule_masks[] = {
	/* full tuple: version, protocol, both addresses, both ports */
	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
	IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
	IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
	IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
	IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
	IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT),

	/* remote-only tuple: version, protocol, remote address and port */
	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
	IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
	IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
	IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT),
};
/* Number of entries in nxctl_inet_traffic_rule_masks[]. */
#define NINETRULEMASKS \
	(sizeof(nxctl_inet_traffic_rule_masks)/sizeof(uint8_t))
196
/* Per-interface lists of traffic rules */
SLIST_HEAD(nxctl_traffic_rule_inet_head, nxctl_traffic_rule_inet);
struct nxctl_traffic_rule_inet_if {
	char rii_ifname[IFNAMSIZ];
	/* one list per entry of nxctl_inet_traffic_rule_masks[] */
	struct nxctl_traffic_rule_inet_head rii_lists[NINETRULEMASKS];
	uint32_t rii_count;     /* rules linked for this interface */
	SLIST_ENTRY(nxctl_traffic_rule_inet_if) rii_link;
};

/* List of per-interface lists */
SLIST_HEAD(nxctl_traffic_rule_inet_if_head, nxctl_traffic_rule_inet_if);
struct nxctl_traffic_rule_inet_storage {
	struct nxctl_traffic_rule_inet_if_head ris_if_list;
	uint32_t ris_count;     /* rules linked across all interfaces */
};

/* Per-fd list kept at the nxctl */
SLIST_HEAD(nxctl_traffic_rule_head, nxctl_traffic_rule);
struct nxctl_traffic_rule_storage {
	struct nxctl_traffic_rule_head rs_list;
	uint32_t rs_count;      /* number of rules on rs_list */
};
219
/*
 * Reader/writer lock taken around accesses to the per-type rule storage
 * (see the NXTR_RLOCK() users below).
 */
static LCK_RW_DECLARE_ATTR(nxctl_traffic_rule_lock, &sk_lock_group, &sk_lock_attr);
#define NXTR_WLOCK() \
	lck_rw_lock_exclusive(&nxctl_traffic_rule_lock)
#define NXTR_WUNLOCK() \
	lck_rw_unlock_exclusive(&nxctl_traffic_rule_lock)
#define NXTR_RLOCK() \
	lck_rw_lock_shared(&nxctl_traffic_rule_lock)
#define NXTR_RUNLOCK() \
	lck_rw_unlock_shared(&nxctl_traffic_rule_lock)

/* Prototypes of helpers that are used before their definitions. */
static struct nxctl_traffic_rule_type *find_traffic_rule_type(uint8_t type);
static void retain_traffic_rule(struct nxctl_traffic_rule *ntr);
static void release_traffic_rule(struct nxctl_traffic_rule *ntr);
static int remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
    struct nxctl_traffic_rule **ntrp);
static boolean_t inet_v6addr_cmp(struct ifnet_ip_addr *a1,
    struct ifnet_ip_addr *a2);
static int notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags);

/* Allocation tag passed to every sk_alloc_type() call in this file. */
#define NXCTL_TRAFFIC_RULE_TAG "com.apple.skywalk.nexus.traffic_rule"
static kern_allocation_name_t nxctl_traffic_rule_tag;
/* Cached inet rule type; set by nxctl_traffic_rule_init(). */
static struct nxctl_traffic_rule_type *inet_traffic_rule_type = NULL;
242
243 /*
244 * If a interface attaches after rule(s) are added, this function is used
245 * retrieve the current rule count for that interface.
246 */
247 int
nxctl_inet_traffic_rule_get_count(const char * ifname,uint32_t * count)248 nxctl_inet_traffic_rule_get_count(const char *ifname, uint32_t *count)
249 {
250 struct nxctl_traffic_rule_inet_storage *rs;
251 struct nxctl_traffic_rule_inet_if *rif;
252 int err;
253
254 NXTR_RLOCK();
255 rs = inet_traffic_rule_type->ntrt_storage;
256 if (rs == NULL) {
257 err = ENOENT;
258 goto fail;
259 }
260 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
261 if (strcmp(rif->rii_ifname, ifname) == 0) {
262 break;
263 }
264 }
265 if (rif == NULL) {
266 err = ENOENT;
267 goto fail;
268 }
269 *count = rif->rii_count;
270 NXTR_RUNLOCK();
271 return 0;
272 fail:
273 NXTR_RUNLOCK();
274 return err;
275 }
276
277 /*
278 * Used for finding the qset id associated with a traffic descriptor.
279 */
280 int
nxctl_inet_traffic_rule_find_qset_id(const char * ifname,struct ifnet_traffic_descriptor_inet * td,uint64_t * qset_id)281 nxctl_inet_traffic_rule_find_qset_id(const char *ifname,
282 struct ifnet_traffic_descriptor_inet *td, uint64_t *qset_id)
283 {
284 struct nxctl_traffic_rule_inet *ntri = NULL;
285 int err;
286
287 NXTR_RLOCK();
288 ASSERT(inet_traffic_rule_type != NULL);
289 err = inet_traffic_rule_type->ntrt_find(inet_traffic_rule_type, ifname,
290 (struct ifnet_traffic_descriptor_common *)td, 0,
291 (struct nxctl_traffic_rule **)&ntri);
292 if (err != 0) {
293 SK_ERR("rule find failed: %d", err);
294 goto fail;
295 }
296 *qset_id = ntri->ntri_ra.ras_qset_id;
297 NXTR_RUNLOCK();
298 return 0;
299 fail:
300 NXTR_RUNLOCK();
301 return err;
302 }
303
304 /*
305 * Based on flow_pkt_classify().
306 * This function populates struct ifnet_traffic_descriptor_inet instead of struct __flow.
307 */
308 static int
fill_inet_td(struct __kern_packet * pkt,struct ifnet_traffic_descriptor_inet * td)309 fill_inet_td(struct __kern_packet *pkt, struct ifnet_traffic_descriptor_inet *td)
310 {
311 union {
312 volatile struct ip *_iph;
313 volatile struct ip6_hdr *_ip6;
314 } _l3;
315 #define iph _l3._iph
316 #define ip6 _l3._ip6
317 union {
318 volatile struct tcphdr *_tcph;
319 volatile struct udphdr *_udph;
320 } _l4;
321 #define tcph _l4._tcph
322 #define udph _l4._udph
323 uint8_t *pkt_buf, *l3_hdr;
324 uint16_t bdlen, bdlim, bdoff, cls_len;
325 size_t pkt_len;
326 uint8_t ipv, l3hlen = 0; /* IP header length */
327 uint16_t l3tlen = 0; /* total length of IP packet */
328 uint8_t l4hlen = 0; /* TCP/UDP header length */
329 uint16_t ulen = 0; /* user data length */
330 int err;
331
332 ASSERT(pkt->pkt_l2_len <= pkt->pkt_length);
333 pkt_len = pkt->pkt_length - pkt->pkt_l2_len;
334
335 MD_BUFLET_ADDR_ABS_DLEN(pkt, pkt_buf, bdlen, bdlim, bdoff);
336 cls_len = bdlim - bdoff;
337 cls_len -= pkt->pkt_l2_len;
338 cls_len = (uint16_t)MIN(cls_len, pkt_len);
339 VERIFY(pkt_len >= cls_len);
340 if (cls_len == 0) {
341 SK_ERR("cls_len == 0");
342 err = EINVAL;
343 goto fail;
344 }
345 l3_hdr = pkt_buf + pkt->pkt_headroom + pkt->pkt_l2_len;
346 iph = (volatile struct ip *)(void *)l3_hdr;
347 ipv = iph->ip_v;
348
349 switch (ipv) {
350 case 4:
351 if (cls_len < sizeof(struct ip)) {
352 SK_ERR("cls_len < sizeof(struct ip) (%d < %d)",
353 cls_len, sizeof(struct ip));
354 err = EINVAL;
355 goto fail;
356 }
357 l3hlen = (uint8_t)(iph->ip_hl << 2);
358 if (l3hlen < sizeof(struct ip)) {
359 SK_ERR("l3hlen < sizeof(struct ip) (%d < %d)",
360 l3hlen, sizeof(struct ip));
361 err = EINVAL;
362 goto fail;
363 }
364 if (cls_len < l3hlen) {
365 SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
366 err = EINVAL;
367 goto fail;
368 }
369 l3tlen = ntohs(iph->ip_len);
370 if (l3tlen < l3hlen) {
371 SK_ERR("l3tlen < l3hlen (%d < %d)", l3tlen, l3hlen);
372 err = EINVAL;
373 goto fail;
374 }
375 if (pkt_len < l3tlen) {
376 SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
377 err = EINVAL;
378 goto fail;
379 }
380 td->inet_ipver = IPVERSION;
381 td->inet_proto = iph->ip_p;
382 bcopy(__DECONST(void *, &iph->ip_src), &td->inet_laddr.iia_v4addr,
383 sizeof(iph->ip_src));
384 bcopy(__DECONST(void *, &iph->ip_dst), &td->inet_raddr.iia_v4addr,
385 sizeof(iph->ip_dst));
386 break;
387 case 6:
388 l3hlen = sizeof(struct ip6_hdr);
389 if (cls_len < l3hlen) {
390 SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
391 err = EINVAL;
392 goto fail;
393 }
394 l3tlen = l3hlen + ntohs(ip6->ip6_plen);
395 if (pkt_len < l3tlen) {
396 SK_ERR("pkt_len < l3tlen (%d < %d)", pkt_len, l3tlen);
397 err = EINVAL;
398 goto fail;
399 }
400 td->inet_ipver = IPV6_VERSION;
401 td->inet_proto = ip6->ip6_nxt;
402 bcopy(__DECONST(void *, &ip6->ip6_src), &td->inet_laddr,
403 sizeof(ip6->ip6_src));
404 bcopy(__DECONST(void *, &ip6->ip6_dst), &td->inet_raddr,
405 sizeof(ip6->ip6_dst));
406 break;
407 default:
408 SK_ERR("ipv == %d", ipv);
409 err = EINVAL;
410 goto fail;
411 }
412 tcph = __DECONST(volatile struct tcphdr *, (volatile uint8_t *)iph + l3hlen);
413 ulen = (l3tlen - l3hlen);
414 if (td->inet_proto == IPPROTO_TCP) {
415 if (cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph)) {
416 SK_ERR("cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph) "
417 "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*tcph),
418 ulen, sizeof(*tcph));
419 err = EINVAL;
420 goto fail;
421 }
422 l4hlen = (uint8_t)(tcph->th_off << 2);
423 if (l4hlen < sizeof(*tcph)) {
424 SK_ERR("l4hlen < sizeof(*tcph) (%d < %d)", l4hlen, sizeof(*tcph));
425 err = EINVAL;
426 goto fail;
427 }
428 if (l4hlen > ulen) {
429 SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
430 err = EINVAL;
431 goto fail;
432 }
433 bcopy(__DECONST(void *, &tcph->th_sport), &td->inet_lport,
434 sizeof(td->inet_lport));
435 bcopy(__DECONST(void *, &tcph->th_dport), &td->inet_rport,
436 sizeof(td->inet_rport));
437 } else if (td->inet_proto == IPPROTO_UDP) {
438 if (cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph)) {
439 SK_ERR("cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph) "
440 "(%d < %d + %d || %d < %d)", cls_len, l3hlen, sizeof(*udph),
441 ulen, sizeof(*udph));
442 err = EINVAL;
443 goto fail;
444 }
445 l4hlen = sizeof(*udph);
446 if (l4hlen > ulen) {
447 SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
448 err = EINVAL;
449 goto fail;
450 }
451 bcopy(__DECONST(void *, &udph->uh_sport), &td->inet_lport,
452 sizeof(td->inet_lport));
453 bcopy(__DECONST(void *, &udph->uh_dport), &td->inet_rport,
454 sizeof(td->inet_rport));
455 } else {
456 err = ENOTSUP;
457 goto fail;
458 }
459
460 td->inet_common.itd_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET;
461 td->inet_common.itd_len = sizeof(*td);
462 td->inet_common.itd_flags = IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
463 IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND;
464 td->inet_mask |= (IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
465 IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
466 IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
467 IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
468 IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
469 IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT);
470 return 0;
471 fail:
472 DTRACE_SKYWALK5(classify__failed, struct ip *, iph, size_t, pkt_len,
473 uint8_t, pkt->pkt_l2_len, struct ifnet_traffic_descriptor_inet *, td,
474 int, err);
475 bzero(td, sizeof(*td));
476 return err;
477 #undef iph
478 #undef ip6
479 #undef tcph
480 #undef udph
481 }
482
483 int
nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char * ifname,struct __kern_packet * pkt,uint64_t * qset_id)484 nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char *ifname,
485 struct __kern_packet *pkt, uint64_t *qset_id)
486 {
487 struct ifnet_traffic_descriptor_inet td;
488 int err;
489
490 err = fill_inet_td(pkt, &td);
491 if (err != 0) {
492 return err;
493 }
494 return nxctl_inet_traffic_rule_find_qset_id(ifname, &td, qset_id);
495 }
496
497 void
nxctl_traffic_rule_init(void)498 nxctl_traffic_rule_init(void)
499 {
500 ASSERT(nxctl_traffic_rule_tag == NULL);
501 nxctl_traffic_rule_tag =
502 kern_allocation_name_allocate(NXCTL_TRAFFIC_RULE_TAG, 0);
503 ASSERT(nxctl_traffic_rule_tag != NULL);
504
505 ASSERT(inet_traffic_rule_type == NULL);
506 inet_traffic_rule_type =
507 find_traffic_rule_type(IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET);
508 ASSERT(inet_traffic_rule_type != NULL);
509 }
510
511 void
nxctl_traffic_rule_fini(void)512 nxctl_traffic_rule_fini(void)
513 {
514 if (nxctl_traffic_rule_tag != NULL) {
515 kern_allocation_name_release(nxctl_traffic_rule_tag);
516 nxctl_traffic_rule_tag = NULL;
517 }
518 inet_traffic_rule_type = NULL;
519 }
520
521 static struct ifnet_ip_addr v6_zeros_addr = {0};
522 static boolean_t
inet_v6addr_cmp(struct ifnet_ip_addr * a1,struct ifnet_ip_addr * a2)523 inet_v6addr_cmp(struct ifnet_ip_addr *a1, struct ifnet_ip_addr *a2)
524 {
525 return memcmp(a1, a2, sizeof(*a1)) == 0;
526 }
527
528 SK_NO_INLINE_ATTRIBUTE
529 static struct nxctl_traffic_rule_storage *
nxctl_traffic_rule_storage_create(void)530 nxctl_traffic_rule_storage_create(void)
531 {
532 struct nxctl_traffic_rule_storage *rs;
533
534 rs = sk_alloc_type(struct nxctl_traffic_rule_storage,
535 Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
536 SLIST_INIT(&rs->rs_list);
537 rs->rs_count = 0;
538 return rs;
539 }
540
541 SK_NO_INLINE_ATTRIBUTE
542 static void
nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage * rs)543 nxctl_traffic_rule_storage_destroy(struct nxctl_traffic_rule_storage *rs)
544 {
545 ASSERT(rs->rs_count == 0);
546 ASSERT(SLIST_EMPTY(&rs->rs_list));
547 sk_free_type(struct nxctl_traffic_rule_storage, rs);
548 }
549
550 /*
551 * This is meant to be called during closure of the nxctl's fd.
552 * This will cleanup all rules linked to this nxctl. Rules that
553 * are marked persistent won't be added to the nxctl list.
554 */
555 void
nxctl_traffic_rule_clean(struct nxctl * nxctl)556 nxctl_traffic_rule_clean(struct nxctl *nxctl)
557 {
558 struct nxctl_traffic_rule_storage *rs;
559 struct nxctl_traffic_rule *ntr, *next;
560 int err;
561
562 lck_mtx_lock(&nxctl->nxctl_lock);
563 if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
564 lck_mtx_unlock(&nxctl->nxctl_lock);
565 return;
566 }
567 ntr = SLIST_FIRST(&rs->rs_list);
568 SLIST_INIT(&rs->rs_list);
569 rs->rs_count = 0;
570 nxctl_traffic_rule_storage_destroy(rs);
571 nxctl->nxctl_traffic_rule_storage = NULL;
572 lck_mtx_unlock(&nxctl->nxctl_lock);
573
574 while (ntr != NULL) {
575 next = SLIST_NEXT(ntr, ntr_storage_link);
576 /*
577 * Clearing the flag to tell remove_traffic_rule() not to
578 * remove from the nxctl list again.
579 */
580 ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
581
582 /* Passing NULL because we already hold a reference */
583 err = remove_traffic_rule(nxctl, ntr->ntr_uuid, NULL);
584 if (err == 0) {
585 (void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
586 }
587 release_traffic_rule(ntr);
588 ntr = next;
589 }
590 }
591
592 SK_NO_INLINE_ATTRIBUTE
593 static void
add_traffic_rule_to_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)594 add_traffic_rule_to_nxctl(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
595 {
596 struct nxctl_traffic_rule_storage *rs;
597
598 lck_mtx_lock(&nxctl->nxctl_lock);
599 if ((rs = nxctl->nxctl_traffic_rule_storage) == NULL) {
600 rs = nxctl_traffic_rule_storage_create();
601 nxctl->nxctl_traffic_rule_storage = rs;
602 }
603 ntr->ntr_flags |= NTR_FLAG_ON_NXCTL_LIST;
604 retain_traffic_rule(ntr);
605 SLIST_INSERT_HEAD(&rs->rs_list, ntr, ntr_storage_link);
606 rs->rs_count++;
607 lck_mtx_unlock(&nxctl->nxctl_lock);
608 }
609
610 SK_NO_INLINE_ATTRIBUTE
611 static void
remove_traffic_rule_from_nxctl(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)612 remove_traffic_rule_from_nxctl(struct nxctl *nxctl,
613 struct nxctl_traffic_rule *ntr)
614 {
615 struct nxctl_traffic_rule_storage *rs;
616
617 lck_mtx_lock(&nxctl->nxctl_lock);
618 if ((ntr->ntr_flags & NTR_FLAG_ON_NXCTL_LIST) == 0) {
619 lck_mtx_unlock(&nxctl->nxctl_lock);
620 return;
621 }
622 rs = nxctl->nxctl_traffic_rule_storage;
623 SLIST_REMOVE(&rs->rs_list, ntr, nxctl_traffic_rule, ntr_storage_link);
624 rs->rs_count--;
625 ntr->ntr_flags &= ~NTR_FLAG_ON_NXCTL_LIST;
626 release_traffic_rule(ntr);
627 if (rs->rs_count == 0) {
628 nxctl_traffic_rule_storage_destroy(rs);
629 nxctl->nxctl_traffic_rule_storage = NULL;
630 }
631 lck_mtx_unlock(&nxctl->nxctl_lock);
632 }
633
634 static int
inet_traffic_rule_validate(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra)635 inet_traffic_rule_validate(struct nxctl_traffic_rule_type *type,
636 const char *ifname,
637 struct ifnet_traffic_descriptor_common *td,
638 struct ifnet_traffic_rule_action *ra)
639 {
640 #pragma unused(type)
641 char buf[IFNAMSIZ];
642 int unit, i;
643 struct ifnet_traffic_descriptor_inet *tdi;
644 uint8_t mask = 0, ipver, proto;
645
646 if (ifunit_extract(ifname, buf, sizeof(buf), &unit) < 0) {
647 SK_ERR("invalid ifname: %s", ifname);
648 return EINVAL;
649 }
650 if (td->itd_len != sizeof(*tdi)) {
651 SK_ERR("invalid td len: expected %d, actual %d",
652 sizeof(*tdi), td->itd_len);
653 return EINVAL;
654 }
655 if (td->itd_flags == 0 ||
656 (td->itd_flags &
657 ~(IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
658 IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND)) != 0) {
659 SK_ERR("invalid td flags: 0x%x", td->itd_flags);
660 return EINVAL;
661 }
662 tdi = (struct ifnet_traffic_descriptor_inet *)td;
663 for (i = 0; i < NINETRULEMASKS; i++) {
664 if (tdi->inet_mask == nxctl_inet_traffic_rule_masks[i]) {
665 mask = tdi->inet_mask;
666 break;
667 }
668 }
669 if (mask == 0) {
670 SK_ERR("invalid inet mask: 0x%x", tdi->inet_mask);
671 return EINVAL;
672 }
673 ipver = tdi->inet_ipver;
674 if (ipver != IPVERSION && ipver != IPV6_VERSION) {
675 SK_ERR("invalid inet ipver: 0x%x", ipver);
676 return EINVAL;
677 }
678 proto = tdi->inet_proto;
679 if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
680 SK_ERR("invalid inet proto: %d", proto);
681 return EINVAL;
682 }
683 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0) {
684 if (ipver == IPVERSION) {
685 if (tdi->inet_laddr.iia_v4addr == INADDR_ANY) {
686 SK_ERR("inet laddr v4 cannot be unspecified");
687 return EINVAL;
688 }
689 } else {
690 if (inet_v6addr_cmp(&tdi->inet_laddr, &v6_zeros_addr)) {
691 SK_ERR("inet laddr v4 cannot be unspecified");
692 return EINVAL;
693 }
694 }
695 }
696 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0) {
697 if (ipver == IPVERSION) {
698 if (tdi->inet_raddr.iia_v4addr == INADDR_ANY) {
699 SK_ERR("inet raddr v6 cannot be unspecified");
700 return EINVAL;
701 }
702 } else {
703 if (inet_v6addr_cmp(&tdi->inet_raddr, &v6_zeros_addr)) {
704 SK_ERR("inet raddr v6 cannot be unspecified");
705 return EINVAL;
706 }
707 }
708 }
709 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0) {
710 if (tdi->inet_lport == 0) {
711 SK_ERR("inet lport cannot be unspecified");
712 return EINVAL;
713 }
714 }
715 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0) {
716 if (tdi->inet_rport == 0) {
717 SK_ERR("inet rport cannot be unspecified");
718 return EINVAL;
719 }
720 }
721 if (ra->ra_len != sizeof(struct ifnet_traffic_rule_action_steer)) {
722 SK_ERR("invalid ra len: expected %d, actual %d",
723 sizeof(struct ifnet_traffic_rule_action_steer), ra->ra_len);
724 return EINVAL;
725 }
726 return 0;
727 }
728
729 SK_NO_INLINE_ATTRIBUTE
730 static struct nxctl_traffic_rule_inet_storage *
inet_traffic_rule_storage_create(void)731 inet_traffic_rule_storage_create(void)
732 {
733 struct nxctl_traffic_rule_inet_storage *rs;
734
735 rs = sk_alloc_type(struct nxctl_traffic_rule_inet_storage,
736 Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
737 SLIST_INIT(&rs->ris_if_list);
738 rs->ris_count = 0;
739 return rs;
740 }
741
742 SK_NO_INLINE_ATTRIBUTE
743 static void
inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage * rs)744 inet_traffic_rule_storage_destroy(struct nxctl_traffic_rule_inet_storage *rs)
745 {
746 ASSERT(rs->ris_count == 0);
747 ASSERT(SLIST_EMPTY(&rs->ris_if_list));
748 sk_free_type(struct nxctl_traffic_rule_inet_storage, rs);
749 }
750
751 SK_NO_INLINE_ATTRIBUTE
752 static struct nxctl_traffic_rule_inet_if *
inet_traffic_rule_if_create(const char * ifname)753 inet_traffic_rule_if_create(const char *ifname)
754 {
755 struct nxctl_traffic_rule_inet_if *rif;
756 int i;
757
758 rif = sk_alloc_type(struct nxctl_traffic_rule_inet_if,
759 Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
760 for (i = 0; i < NINETRULEMASKS; i++) {
761 SLIST_INIT(&rif->rii_lists[i]);
762 }
763 strlcpy(rif->rii_ifname, ifname, sizeof(rif->rii_ifname));
764 rif->rii_count = 0;
765 return rif;
766 }
767
768 SK_NO_INLINE_ATTRIBUTE
769 static void
inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if * rif)770 inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if *rif)
771 {
772 int i;
773
774 for (i = 0; i < NINETRULEMASKS; i++) {
775 ASSERT(SLIST_EMPTY(&rif->rii_lists[i]));
776 }
777 ASSERT(rif->rii_count == 0);
778 sk_free_type(struct nxctl_traffic_rule_inet_if, rif);
779 }
780
781 SK_NO_INLINE_ATTRIBUTE
782 static boolean_t
inet_traffic_rule_match(struct nxctl_traffic_rule_inet * ntri,const char * ifname,uint32_t flags,struct ifnet_traffic_descriptor_inet * tdi)783 inet_traffic_rule_match(struct nxctl_traffic_rule_inet *ntri, const char *ifname,
784 uint32_t flags, struct ifnet_traffic_descriptor_inet *tdi)
785 {
786 struct nxctl_traffic_rule *ntr = (struct nxctl_traffic_rule *)ntri;
787 struct ifnet_traffic_descriptor_inet *tdi0;
788 uint8_t mask;
789 boolean_t exact;
790
791 VERIFY(strcmp(ntr->ntr_ifname, ifname) == 0);
792 tdi0 = &ntri->ntri_td;
793
794 exact = ((flags & NTR_FIND_FLAG_EXACT) != 0);
795 mask = tdi0->inet_mask & tdi->inet_mask;
796 if (exact) {
797 ASSERT(tdi0->inet_mask == tdi->inet_mask);
798 }
799 ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER) != 0);
800 if (tdi0->inet_ipver != tdi->inet_ipver) {
801 DTRACE_SKYWALK2(ipver__mismatch,
802 uint8_t, tdi0->inet_ipver, uint8_t, tdi->inet_ipver);
803 return FALSE;
804 }
805 ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) != 0);
806 if (tdi0->inet_proto != tdi->inet_proto) {
807 DTRACE_SKYWALK2(proto__mismatch,
808 uint8_t, tdi0->inet_proto, uint8_t, tdi->inet_proto);
809 return FALSE;
810 }
811 if (tdi0->inet_ipver == IPVERSION) {
812 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
813 tdi0->inet_laddr.iia_v4addr != tdi->inet_laddr.iia_v4addr) {
814 DTRACE_SKYWALK2(v4laddr__mismatch,
815 in_addr_t, tdi0->inet_laddr.iia_v4addr,
816 in_addr_t, tdi->inet_laddr.iia_v4addr);
817 return FALSE;
818 }
819 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
820 tdi0->inet_raddr.iia_v4addr != tdi->inet_raddr.iia_v4addr) {
821 DTRACE_SKYWALK2(v4raddr__mismatch,
822 in_addr_t, tdi0->inet_raddr.iia_v4addr,
823 in_addr_t, tdi->inet_raddr.iia_v4addr);
824 return FALSE;
825 }
826 } else {
827 ASSERT(tdi0->inet_ipver == IPV6_VERSION);
828 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
829 !inet_v6addr_cmp(&tdi0->inet_laddr, &tdi->inet_laddr)) {
830 DTRACE_SKYWALK2(v6laddr__mismatch,
831 struct in6_addr *, &tdi0->inet_laddr,
832 struct in6_addr *, &tdi->inet_laddr);
833 return FALSE;
834 }
835 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
836 !inet_v6addr_cmp(&tdi0->inet_raddr, &tdi->inet_raddr)) {
837 DTRACE_SKYWALK2(v6raddr__mismatch,
838 struct in6_addr *, &tdi0->inet_raddr,
839 struct in6_addr *, &tdi->inet_raddr);
840 return FALSE;
841 }
842 }
843 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0 &&
844 tdi0->inet_lport != tdi->inet_lport) {
845 DTRACE_SKYWALK2(lport__mismatch,
846 uint8_t, tdi0->inet_lport, uint8_t, tdi->inet_lport);
847 return FALSE;
848 }
849 if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0 &&
850 tdi0->inet_rport != tdi->inet_rport) {
851 DTRACE_SKYWALK2(rport__mismatch,
852 uint8_t, tdi0->inet_rport, uint8_t, tdi->inet_rport);
853 return FALSE;
854 }
855 return TRUE;
856 }
857
858 static int
inet_traffic_rule_find(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,uint32_t flags,struct nxctl_traffic_rule ** ntrp)859 inet_traffic_rule_find(struct nxctl_traffic_rule_type *type, const char *ifname,
860 struct ifnet_traffic_descriptor_common *td, uint32_t flags,
861 struct nxctl_traffic_rule **ntrp)
862 {
863 struct nxctl_traffic_rule_inet *ntri = NULL;
864 struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
865 struct nxctl_traffic_rule_inet_if *rif;
866 struct ifnet_traffic_descriptor_inet *tdi =
867 (struct ifnet_traffic_descriptor_inet *)td;
868 int i;
869
870 if (rs == NULL) {
871 return ENOENT;
872 }
873 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
874 if (strcmp(rif->rii_ifname, ifname) != 0) {
875 continue;
876 }
877 for (i = 0; i < NINETRULEMASKS; i++) {
878 if ((flags & NTR_FIND_FLAG_EXACT) != 0 &&
879 tdi->inet_mask != nxctl_inet_traffic_rule_masks[i]) {
880 continue;
881 }
882 SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
883 if (inet_traffic_rule_match(ntri, ifname, flags, tdi)) {
884 *ntrp = (struct nxctl_traffic_rule *)ntri;
885 return 0;
886 }
887 }
888 }
889 }
890 return ENOENT;
891 }
892
893 static int
inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type * type,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)894 inet_traffic_rule_find_by_uuid(struct nxctl_traffic_rule_type *type,
895 uuid_t uuid, struct nxctl_traffic_rule **ntrp)
896 {
897 struct nxctl_traffic_rule_inet *ntri;
898 struct nxctl_traffic_rule *ntr;
899 struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
900 struct nxctl_traffic_rule_inet_if *rif;
901 int i;
902
903 if (rs == NULL) {
904 return ENOENT;
905 }
906 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
907 for (i = 0; i < NINETRULEMASKS; i++) {
908 SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
909 ntr = &ntri->ntri_common;
910 if (uuid_compare(ntr->ntr_uuid, uuid) == 0) {
911 *ntrp = ntr;
912 return 0;
913 }
914 }
915 }
916 }
917 return ENOENT;
918 }
919
920 static void
inet_update_ifnet_traffic_rule_count(const char * ifname,uint32_t count)921 inet_update_ifnet_traffic_rule_count(const char *ifname, uint32_t count)
922 {
923 struct ifnet *ifp;
924
925 ifp = ifunit_ref(ifname);
926 if (ifp == NULL) {
927 DTRACE_SKYWALK1(ifname__not__found, char *, ifname);
928 return;
929 }
930 ifnet_update_traffic_rule_count(ifp, count);
931 ifnet_decr_iorefcnt(ifp);
932 }
933
934 static void
inet_traffic_rule_link(struct nxctl_traffic_rule * ntr)935 inet_traffic_rule_link(struct nxctl_traffic_rule *ntr)
936 {
937 struct nxctl_traffic_rule_type *type = ntr->ntr_type;
938 struct nxctl_traffic_rule_inet_storage *rs;
939 struct nxctl_traffic_rule_inet_if *rif;
940 struct nxctl_traffic_rule_inet *ntri =
941 (struct nxctl_traffic_rule_inet *)ntr;
942 struct nxctl_traffic_rule_inet_head *list = NULL;
943 int i;
944
945 if ((rs = type->ntrt_storage) == NULL) {
946 rs = inet_traffic_rule_storage_create();
947 type->ntrt_storage = rs;
948 }
949 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
950 if (strcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
951 break;
952 }
953 }
954 if (rif == NULL) {
955 rif = inet_traffic_rule_if_create(ntr->ntr_ifname);
956 SLIST_INSERT_HEAD(&rs->ris_if_list, rif, rii_link);
957 }
958 for (i = 0; i < NINETRULEMASKS; i++) {
959 if (ntri->ntri_td.inet_mask ==
960 nxctl_inet_traffic_rule_masks[i]) {
961 list = &rif->rii_lists[i];
962 break;
963 }
964 }
965 retain_traffic_rule(ntr);
966 ASSERT(list != NULL);
967 SLIST_INSERT_HEAD(list, ntri, ntri_storage_link);
968 /* per-interface count */
969 rif->rii_count++;
970 inet_update_ifnet_traffic_rule_count(rif->rii_ifname, rif->rii_count);
971
972 /* global count */
973 rs->ris_count++;
974 }
975
976 static void
inet_traffic_rule_unlink(struct nxctl_traffic_rule * ntr)977 inet_traffic_rule_unlink(struct nxctl_traffic_rule *ntr)
978 {
979 struct nxctl_traffic_rule_inet_storage *rs;
980 struct nxctl_traffic_rule_inet_if *rif;
981 struct nxctl_traffic_rule_inet *ntri =
982 (struct nxctl_traffic_rule_inet *)ntr;
983 struct nxctl_traffic_rule_inet_head *list = NULL;
984 struct nxctl_traffic_rule_type *type;
985 int i;
986
987 type = ntr->ntr_type;
988 rs = type->ntrt_storage;
989 ASSERT(rs != NULL);
990 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
991 if (strcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
992 break;
993 }
994 }
995 ASSERT(rif != NULL);
996 for (i = 0; i < NINETRULEMASKS; i++) {
997 if (ntri->ntri_td.inet_mask ==
998 nxctl_inet_traffic_rule_masks[i]) {
999 list = &rif->rii_lists[i];
1000 break;
1001 }
1002 }
1003 ASSERT(list != NULL);
1004 SLIST_REMOVE(list, ntri, nxctl_traffic_rule_inet, ntri_storage_link);
1005 rif->rii_count--;
1006 inet_update_ifnet_traffic_rule_count(rif->rii_ifname, rif->rii_count);
1007
1008 rs->ris_count--;
1009 release_traffic_rule(ntr);
1010
1011 if (rif->rii_count == 0) {
1012 SLIST_REMOVE(&rs->ris_if_list, rif, nxctl_traffic_rule_inet_if, rii_link);
1013 inet_traffic_rule_if_destroy(rif);
1014 }
1015 if (rs->ris_count == 0) {
1016 type->ntrt_storage = NULL;
1017 inet_traffic_rule_storage_destroy(rs);
1018 }
1019 }
1020
1021 /*
1022 * XXX
1023 * This may need additional changes to ensure safety against detach/attach.
1024 * This is not an issue for the first consumer of llink interfaces, cellular,
1025 * which does not detach.
1026 */
1027 static int
inet_traffic_rule_notify(struct nxctl_traffic_rule * ntr,uint32_t flags)1028 inet_traffic_rule_notify(struct nxctl_traffic_rule *ntr, uint32_t flags)
1029 {
1030 struct ifnet *ifp;
1031 struct nx_netif *nif;
1032 struct netif_qset *qset = NULL;
1033 struct nxctl_traffic_rule_inet *ntri;
1034 int err = 0;
1035
1036 ifp = ifunit_ref(ntr->ntr_ifname);
1037 if (ifp == NULL) {
1038 DTRACE_SKYWALK1(ifname__not__found, char *, ntr->ntr_ifname);
1039 err = ENXIO;
1040 goto done;
1041 }
1042 nif = NA(ifp)->nifna_netif;
1043 if (!NX_LLINK_PROV(nif->nif_nx)) {
1044 DTRACE_SKYWALK1(llink__not__enabled, struct ifnet *, ifp);
1045 err = ENOTSUP;
1046 goto done;
1047 }
1048 ntri = (struct nxctl_traffic_rule_inet *)ntr;
1049 qset = nx_netif_find_qset(nif, ntri->ntri_ra.ras_qset_id);
1050 err = nx_netif_notify_steering_info(nif, qset,
1051 (struct ifnet_traffic_descriptor_common *)&ntri->ntri_td,
1052 ((flags & NTR_NOTIFY_FLAG_ADD) != 0));
1053 done:
1054 if (qset != NULL) {
1055 nx_netif_qset_release(&qset);
1056 }
1057 if (ifp != NULL) {
1058 ifnet_decr_iorefcnt(ifp);
1059 }
1060 return err;
1061 }
1062
1063 static int
inet_traffic_rule_create(struct nxctl_traffic_rule_type * type,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1064 inet_traffic_rule_create(struct nxctl_traffic_rule_type *type,
1065 const char *ifname, struct ifnet_traffic_descriptor_common *td,
1066 struct ifnet_traffic_rule_action *ra, uint32_t flags,
1067 struct nxctl_traffic_rule **ntrp)
1068 {
1069 struct nxctl_traffic_rule_inet *ntri;
1070 struct nxctl_traffic_rule *ntr;
1071
1072 ntri = sk_alloc_type(struct nxctl_traffic_rule_inet,
1073 Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
1074 ntr = &ntri->ntri_common;
1075
1076 ntr->ntr_type = type;
1077 ntr->ntr_flags = flags;
1078 uuid_generate(ntr->ntr_uuid);
1079 os_ref_init(&ntr->ntr_refcnt, NULL);
1080
1081 strlcpy(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname));
1082 proc_selfname(ntr->ntr_procname, sizeof(ntr->ntr_procname));
1083 bcopy(td, &ntri->ntri_td, sizeof(ntri->ntri_td));
1084 bcopy(ra, &ntri->ntri_ra, sizeof(ntri->ntri_ra));
1085
1086 *ntrp = ntr;
1087 return 0;
1088 }
1089
1090 static void
inet_traffic_rule_destroy(struct nxctl_traffic_rule * ntr)1091 inet_traffic_rule_destroy(struct nxctl_traffic_rule *ntr)
1092 {
1093 struct nxctl_traffic_rule_inet *ntri;
1094
1095 ASSERT(os_ref_get_count(&ntr->ntr_refcnt) == 0);
1096 ntri = (struct nxctl_traffic_rule_inet *)ntr;
1097 sk_free_type(struct nxctl_traffic_rule_inet, ntri);
1098 }
1099
1100 static void
convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet * ntri,struct nxctl_traffic_rule_inet_iocinfo * info)1101 convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet *ntri,
1102 struct nxctl_traffic_rule_inet_iocinfo *info)
1103 {
1104 struct nxctl_traffic_rule *ntr;
1105 struct nxctl_traffic_rule_generic_iocinfo *ginfo;
1106
1107 bzero(info, sizeof(*info));
1108 ntr = &ntri->ntri_common;
1109 ginfo = &info->tri_common;
1110 _CASSERT(sizeof(ntr->ntr_procname) == sizeof(ginfo->trg_procname));
1111 _CASSERT(sizeof(ntr->ntr_ifname) == sizeof(ginfo->trg_ifname));
1112 uuid_copy(ginfo->trg_uuid, ntr->ntr_uuid);
1113 strlcpy(ginfo->trg_procname, ntr->ntr_procname,
1114 sizeof(ginfo->trg_procname));
1115 strlcpy(ginfo->trg_ifname, ntr->ntr_ifname,
1116 sizeof(ginfo->trg_ifname));
1117 bcopy(&ntri->ntri_td, &info->tri_td, sizeof(info->tri_td));
1118 bcopy(&ntri->ntri_ra, &info->tri_ra, sizeof(info->tri_ra));
1119 }
1120
1121 static int
inet_traffic_rule_get_all(struct nxctl_traffic_rule_type * type,uint32_t size,uint32_t * count,user_addr_t uaddr)1122 inet_traffic_rule_get_all(struct nxctl_traffic_rule_type *type, uint32_t size,
1123 uint32_t *count, user_addr_t uaddr)
1124 {
1125 struct nxctl_traffic_rule_inet *ntri = NULL;
1126 struct nxctl_traffic_rule_inet_storage *rs = type->ntrt_storage;
1127 struct nxctl_traffic_rule_inet_if *rif;
1128 struct nxctl_traffic_rule_inet_iocinfo info;
1129 int i, err;
1130
1131 if (size != sizeof(info)) {
1132 SK_ERR("size: actual %d, expected %d", size, sizeof(info));
1133 return EINVAL;
1134 }
1135 if (rs == NULL) {
1136 *count = 0;
1137 return 0;
1138 }
1139 if (*count < rs->ris_count) {
1140 SK_ERR("count: given %d, require: %d", *count, rs->ris_count);
1141 return ENOBUFS;
1142 }
1143 SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
1144 for (i = 0; i < NINETRULEMASKS; i++) {
1145 SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
1146 convert_ntri_to_iocinfo(ntri, &info);
1147 err = copyout(&info, uaddr, sizeof(info));
1148 if (err != 0) {
1149 SK_ERR("copyout failed: %d", err);
1150 return err;
1151 }
1152 uaddr += sizeof(info);
1153 }
1154 }
1155 }
1156 *count = rs->ris_count;
1157 return 0;
1158 }
1159
1160 SK_NO_INLINE_ATTRIBUTE
1161 static void
retain_traffic_rule(struct nxctl_traffic_rule * ntr)1162 retain_traffic_rule(struct nxctl_traffic_rule *ntr)
1163 {
1164 #if (DEVELOPMENT || DEBUG)
1165 os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1166 DTRACE_SKYWALK2(ntr__retain, struct nxctl_traffic_rule *, ntr,
1167 os_ref_count_t, count);
1168 #endif
1169 os_ref_retain(&ntr->ntr_refcnt);
1170 }
1171
1172 SK_NO_INLINE_ATTRIBUTE
1173 static void
release_traffic_rule(struct nxctl_traffic_rule * ntr)1174 release_traffic_rule(struct nxctl_traffic_rule *ntr)
1175 {
1176 #if (DEVELOPMENT || DEBUG)
1177 os_ref_count_t count = os_ref_get_count(&ntr->ntr_refcnt);
1178 DTRACE_SKYWALK2(ntr__release, struct nxctl_traffic_rule *, ntr,
1179 os_ref_count_t, count);
1180 #endif
1181 if (os_ref_release(&ntr->ntr_refcnt) == 0) {
1182 ntr->ntr_type->ntrt_destroy(ntr);
1183 }
1184 }
1185
1186 SK_NO_INLINE_ATTRIBUTE
1187 static int
notify_traffic_rule(struct nxctl_traffic_rule * ntr,uint32_t flags)1188 notify_traffic_rule(struct nxctl_traffic_rule *ntr, uint32_t flags)
1189 {
1190 return ntr->ntr_type->ntrt_notify(ntr, flags);
1191 }
1192
1193 static void
link_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1194 link_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1195 {
1196 /*
1197 * The persist flag means: do not clean up rule upon nxctl fd close.
1198 * This means we only add the rule to the nxctl list if persist
1199 * is not set.
1200 */
1201 if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1202 add_traffic_rule_to_nxctl(nxctl, ntr);
1203 }
1204 ntr->ntr_type->ntrt_link(ntr);
1205 }
1206
1207 static void
unlink_traffic_rule(struct nxctl * nxctl,struct nxctl_traffic_rule * ntr)1208 unlink_traffic_rule(struct nxctl *nxctl, struct nxctl_traffic_rule *ntr)
1209 {
1210 if ((ntr->ntr_flags & NTR_FLAG_PERSIST) == 0) {
1211 remove_traffic_rule_from_nxctl(nxctl, ntr);
1212 }
1213 ntr->ntr_type->ntrt_unlink(ntr);
1214 }
1215
1216 static int
find_traffic_rule_by_uuid(uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1217 find_traffic_rule_by_uuid(uuid_t uuid, struct nxctl_traffic_rule **ntrp)
1218 {
1219 int i, err;
1220 struct nxctl_traffic_rule_type *ntrt;
1221 struct nxctl_traffic_rule *ntr = NULL;
1222
1223 for (i = 0; i < NRULETYPES; i++) {
1224 ntrt = &nxctl_rule_types[i];
1225 err = ntrt->ntrt_find_by_uuid(ntrt, uuid, &ntr);
1226 if (err == 0) {
1227 ASSERT(ntr != NULL);
1228 *ntrp = ntr;
1229 return 0;
1230 }
1231 }
1232 return ENOENT;
1233 }
1234
1235 static struct nxctl_traffic_rule_type *
find_traffic_rule_type(uint8_t type)1236 find_traffic_rule_type(uint8_t type)
1237 {
1238 int i;
1239 struct nxctl_traffic_rule_type *ntrt;
1240
1241 for (i = 0; i < NRULETYPES; i++) {
1242 ntrt = &nxctl_rule_types[i];
1243 if (ntrt->ntrt_type == type) {
1244 return ntrt;
1245 }
1246 }
1247 return NULL;
1248 }
1249
1250 SK_NO_INLINE_ATTRIBUTE
1251 static int
add_traffic_rule(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)1252 add_traffic_rule(struct nxctl *nxctl, const char *ifname,
1253 struct ifnet_traffic_descriptor_common *td,
1254 struct ifnet_traffic_rule_action *ra,
1255 uint32_t flags,
1256 struct nxctl_traffic_rule **ntrp)
1257 {
1258 struct nxctl_traffic_rule_type *type = NULL;
1259 struct nxctl_traffic_rule *ntr = NULL;
1260 int err;
1261
1262 NXTR_WLOCK();
1263 type = find_traffic_rule_type(td->itd_type);
1264 if (type == NULL) {
1265 SK_ERR("rule type %x not found", td->itd_type);
1266 err = EINVAL;
1267 goto fail;
1268 }
1269 err = type->ntrt_validate(type, ifname, td, ra);
1270 if (err != 0) {
1271 SK_ERR("rule validate failed: %d", err);
1272 goto fail;
1273 }
1274 err = type->ntrt_find(type, ifname, td, NTR_FIND_FLAG_EXACT, &ntr);
1275 if (err == 0) {
1276 SK_ERR("rule already exists");
1277 ASSERT(ntr != NULL);
1278 err = EEXIST;
1279 goto fail;
1280 } else if (err != ENOENT) {
1281 SK_ERR("rule find failed: %d", err);
1282 goto fail;
1283 }
1284 err = type->ntrt_create(type, ifname, td, ra, flags, &ntr);
1285 if (err != 0) {
1286 SK_ERR("rule create failed: %d", err);
1287 goto fail;
1288 }
1289 link_traffic_rule(nxctl, ntr);
1290 if (ntrp != NULL) {
1291 retain_traffic_rule(ntr);
1292 *ntrp = ntr;
1293 }
1294 NXTR_WUNLOCK();
1295 return 0;
1296 fail:
1297 NXTR_WUNLOCK();
1298 return err;
1299 }
1300
1301
1302 SK_NO_INLINE_ATTRIBUTE
1303 static int
remove_traffic_rule(struct nxctl * nxctl,uuid_t uuid,struct nxctl_traffic_rule ** ntrp)1304 remove_traffic_rule(struct nxctl *nxctl, uuid_t uuid,
1305 struct nxctl_traffic_rule **ntrp)
1306 {
1307 struct nxctl_traffic_rule *ntr;
1308 int err;
1309
1310 NXTR_WLOCK();
1311 err = find_traffic_rule_by_uuid(uuid, &ntr);
1312 if (err != 0) {
1313 SK_ERR("traffic rule not found");
1314 NXTR_WUNLOCK();
1315 return err;
1316 }
1317 if (ntrp != NULL) {
1318 retain_traffic_rule(ntr);
1319 *ntrp = ntr;
1320 }
1321 unlink_traffic_rule(nxctl, ntr);
1322 /* release initial reference */
1323 release_traffic_rule(ntr);
1324 NXTR_WUNLOCK();
1325 return 0;
1326 }
1327
1328 static uint32_t
convert_traffic_rule_ioc_flags(uint32_t flags)1329 convert_traffic_rule_ioc_flags(uint32_t flags)
1330 {
1331 uint32_t f = 0;
1332
1333 if ((flags & NXIOC_ADD_TRAFFIC_RULE_FLAG_PERSIST) != 0) {
1334 f |= NTR_FLAG_PERSIST;
1335 }
1336 return f;
1337 }
1338
1339 SK_NO_INLINE_ATTRIBUTE
1340 static int
add_traffic_rule_generic(struct nxctl * nxctl,const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,uuid_t * uuid)1341 add_traffic_rule_generic(struct nxctl *nxctl, const char *ifname,
1342 struct ifnet_traffic_descriptor_common *td,
1343 struct ifnet_traffic_rule_action *ra, uint32_t flags, uuid_t *uuid)
1344 {
1345 struct nxctl_traffic_rule *ntr;
1346 int err;
1347
1348 err = add_traffic_rule(nxctl, ifname, td, ra, flags, &ntr);
1349 if (err != 0) {
1350 return err;
1351 }
1352 (void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_ADD);
1353 uuid_copy(*uuid, ntr->ntr_uuid);
1354 release_traffic_rule(ntr);
1355 return 0;
1356 }
1357
1358 int
nxioctl_add_traffic_rule_inet(struct nxctl * nxctl,caddr_t data,proc_t procp)1359 nxioctl_add_traffic_rule_inet(struct nxctl *nxctl, caddr_t data, proc_t procp)
1360 {
1361 #pragma unused(procp)
1362 struct nxctl_add_traffic_rule_inet_iocargs *args =
1363 (struct nxctl_add_traffic_rule_inet_iocargs *)(void *)data;
1364
1365 return add_traffic_rule_generic(nxctl, args->atri_ifname,
1366 (struct ifnet_traffic_descriptor_common *)&args->atri_td,
1367 (struct ifnet_traffic_rule_action *)&args->atri_ra,
1368 convert_traffic_rule_ioc_flags(args->atri_flags),
1369 &args->atri_uuid);
1370 }
1371
1372 int
nxioctl_remove_traffic_rule(struct nxctl * nxctl,caddr_t data,proc_t procp)1373 nxioctl_remove_traffic_rule(struct nxctl *nxctl, caddr_t data, proc_t procp)
1374 {
1375 #pragma unused(procp)
1376 struct nxctl_remove_traffic_rule_iocargs *args =
1377 (struct nxctl_remove_traffic_rule_iocargs *)(void *)data;
1378 struct nxctl_traffic_rule *ntr;
1379 int err;
1380
1381 err = remove_traffic_rule(nxctl, args->rtr_uuid, &ntr);
1382 if (err != 0) {
1383 return err;
1384 }
1385 (void) notify_traffic_rule(ntr, NTR_NOTIFY_FLAG_REMOVE);
1386 release_traffic_rule(ntr);
1387 return 0;
1388 }
1389
1390 int
nxioctl_get_traffic_rules(struct nxctl * nxctl,caddr_t data,proc_t procp)1391 nxioctl_get_traffic_rules(struct nxctl *nxctl, caddr_t data, proc_t procp)
1392 {
1393 #pragma unused(nxctl)
1394 struct nxctl_get_traffic_rules_iocargs *args =
1395 (struct nxctl_get_traffic_rules_iocargs *)(void *)data;
1396 struct nxctl_traffic_rule_type *type;
1397 user_addr_t uaddr;
1398 int err;
1399
1400 NXTR_RLOCK();
1401 type = find_traffic_rule_type(args->gtr_type);
1402 if (type == NULL) {
1403 SK_ERR("rule type %x not found", args->gtr_type);
1404 err = EINVAL;
1405 goto fail;
1406 }
1407 uaddr = proc_is64bit(procp) ? args->gtr_buf64 :
1408 CAST_USER_ADDR_T(args->gtr_buf);
1409 err = type->ntrt_get_all(type, args->gtr_size, &args->gtr_count, uaddr);
1410 if (err != 0) {
1411 goto fail;
1412 }
1413 NXTR_RUNLOCK();
1414 return 0;
1415 fail:
1416 NXTR_RUNLOCK();
1417 return err;
1418 }
1419