xref: /xnu-12377.1.9/bsd/skywalk/nexus/nexus_traffic_rule_inet.c (revision f6217f891ac0bb64f3d375211650a4c1ff8ca1ea)
1 /*
2  * Copyright (c) 2022 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <skywalk/nexus/nexus_traffic_rule_inet.h>
30 #include <skywalk/nexus/netif/nx_netif.h>
31 #include <netinet/ip.h>
32 #include <netinet/ip6.h>
33 #include <netinet/tcp.h>
34 #include <netinet/udp.h>
35 
36 /*
37  * Inet-specific traffic rule.
38  */
39 struct nxctl_traffic_rule_inet {
40 	struct nxctl_traffic_rule ntri_common;
41 	SLIST_ENTRY(nxctl_traffic_rule_inet) ntri_storage_link;
42 	struct ifnet_traffic_descriptor_inet ntri_td;
43 	struct ifnet_traffic_rule_action_steer ntri_ra;
44 };
45 
46 /*
47  * Currently supported tuple types.
48  */
49 #define ITRM(proto, laddr, raddr, lport, rport) \
50 	(IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER | \
51 	ITDBIT(proto, IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) | \
52 	ITDBIT(laddr, IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) | \
53 	ITDBIT(raddr, IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) | \
54 	ITDBIT(lport, IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) | \
55 	ITDBIT(rport, IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT))
56 
57 static uint8_t nxctl_inet_traffic_rule_masks[] = {
58 	ITRM(1, 1, 1, 1, 1),
59 	ITRM(1, 1, 1, 1, 0),
60 	ITRM(1, 1, 1, 0, 1),
61 	ITRM(1, 1, 1, 0, 0),
62 	ITRM(1, 1, 0, 1, 1),
63 	ITRM(1, 1, 0, 1, 0),
64 	ITRM(1, 1, 0, 0, 1),
65 	ITRM(1, 1, 0, 0, 0),
66 	ITRM(1, 0, 1, 1, 1),
67 	ITRM(1, 0, 1, 1, 0),
68 	ITRM(1, 0, 1, 0, 1),
69 	ITRM(1, 0, 1, 0, 0),
70 	ITRM(1, 0, 0, 1, 1),
71 	ITRM(1, 0, 0, 1, 0),
72 	ITRM(1, 0, 0, 0, 1),
73 	// ITRM(1, 0, 0, 0, 0), addr or port is required
74 	ITRM(0, 1, 1, 1, 1),
75 	ITRM(0, 1, 1, 1, 0),
76 	ITRM(0, 1, 1, 0, 1),
77 	ITRM(0, 1, 1, 0, 0),
78 	ITRM(0, 1, 0, 1, 1),
79 	ITRM(0, 1, 0, 1, 0),
80 	ITRM(0, 1, 0, 0, 1),
81 	ITRM(0, 1, 0, 0, 0),
82 	ITRM(0, 0, 1, 1, 1),
83 	ITRM(0, 0, 1, 1, 0),
84 	ITRM(0, 0, 1, 0, 1),
85 	ITRM(0, 0, 1, 0, 0),
86 	ITRM(0, 0, 0, 1, 1),
87 	ITRM(0, 0, 0, 1, 0),
88 	ITRM(0, 0, 0, 0, 1),
89 	// ITRM(0, 0, 0, 0, 0),
90 };
91 #define NINETRULEMASKS \
92     (sizeof(nxctl_inet_traffic_rule_masks)/sizeof(uint8_t))
93 
94 /* Per-interface lists of inet traffic rules */
95 SLIST_HEAD(nxctl_traffic_rule_inet_head, nxctl_traffic_rule_inet);
96 struct nxctl_traffic_rule_inet_if {
97 	char rii_ifname[IFNAMSIZ];
98 	struct nxctl_traffic_rule_inet_head rii_lists[NINETRULEMASKS];
99 	uint32_t rii_count;
100 	SLIST_ENTRY(nxctl_traffic_rule_inet_if) rii_link;
101 };
102 
103 /* List of per-interface lists */
104 SLIST_HEAD(nxctl_traffic_rule_inet_if_head, nxctl_traffic_rule_inet_if);
105 struct nxctl_traffic_rule_inet_storage {
106 	struct nxctl_traffic_rule_inet_if_head ris_if_list;
107 	uint32_t ris_count;
108 };
109 
110 static struct nxctl_traffic_rule_inet_storage *rs = NULL;
111 static kern_allocation_name_t nxctl_traffic_rule_tag = NULL;
112 
113 static boolean_t inet_v6addr_cmp(struct ifnet_ip_addr *a1,
114     struct ifnet_ip_addr *a2);
115 
/*
 * Report how many inet traffic rules exist for `ifname`.
 *
 * An interface that attaches after rules were added for it uses this to
 * learn its current rule count. The traffic-rule lock is held for reading
 * around the lookup.
 *
 * Returns 0 and stores the count in *count, or ENOENT when the interface
 * has no rules.
 */
int
nxctl_inet_traffic_rule_get_count(const char *ifname, uint32_t *count)
{
	int error;

	NXTR_RLOCK();
	error = inet_traffic_rule_get_count(ifname, count);
	NXTR_RUNLOCK();

	return error;
}
131 
132 /*
133  * Used for finding the qset id associated with a traffic descriptor.
134  */
135 int
nxctl_inet_traffic_rule_find_qset_id(const char * ifname,struct ifnet_traffic_descriptor_inet * td,uint64_t * qset_id)136 nxctl_inet_traffic_rule_find_qset_id(const char *ifname,
137     struct ifnet_traffic_descriptor_inet *td, uint64_t *qset_id)
138 {
139 	struct nxctl_traffic_rule_inet *__single ntri = NULL;
140 	struct nxctl_traffic_rule *__single ntr = NULL;
141 	int err;
142 
143 	NXTR_RLOCK();
144 	err = inet_traffic_rule_find(ifname, &td->inet_common, 0, &ntr);
145 	if (err != 0) {
146 		goto fail;
147 	}
148 	ntri = __container_of(ntr, struct nxctl_traffic_rule_inet, ntri_common);
149 	*qset_id = ntri->ntri_ra.ras_qset_id;
150 	NXTR_RUNLOCK();
151 	return 0;
152 fail:
153 	NXTR_RUNLOCK();
154 	return err;
155 }
156 
157 /*
158  * Based on flow_pkt_classify().
159  * This function populates struct ifnet_traffic_descriptor_inet instead of struct __flow.
160  */
161 static int
fill_inet_td(struct __kern_packet * pkt,struct ifnet_traffic_descriptor_inet * td)162 fill_inet_td(struct __kern_packet *pkt, struct ifnet_traffic_descriptor_inet *td)
163 {
164 	union {
165 		volatile struct ip *__indexable _iph;
166 		volatile struct ip6_hdr *__indexable _ip6;
167 	} _l3;
168 	#define iph _l3._iph
169 	#define ip6 _l3._ip6
170 	union {
171 		volatile struct tcphdr *_tcph;
172 		volatile struct udphdr *_udph;
173 	} _l4;
174 	#define tcph _l4._tcph
175 	#define udph _l4._udph
176 	uint8_t *pkt_buf, *l3_hdr;
177 	uint32_t bdlen, bdlim, bdoff, cls_len;
178 	size_t pkt_len;
179 	uint8_t ipv, l3hlen = 0; /* IP header length */
180 	uint16_t l3tlen = 0;     /* total length of IP packet */
181 	uint8_t l4hlen = 0;      /* TCP/UDP header length */
182 	uint16_t ulen = 0;       /* user data length */
183 	int err;
184 
185 	ASSERT(pkt->pkt_l2_len <= pkt->pkt_length);
186 	pkt_len = pkt->pkt_length - pkt->pkt_l2_len;
187 
188 	MD_BUFLET_ADDR_ABS_DLEN(pkt, pkt_buf, bdlen, bdlim, bdoff);
189 	cls_len = bdlim - bdoff;
190 	cls_len -= pkt->pkt_l2_len;
191 	cls_len = (uint32_t)MIN(cls_len, pkt_len);
192 	VERIFY(pkt_len >= cls_len);
193 	if (cls_len == 0) {
194 		SK_ERR("cls_len == 0");
195 		err = EINVAL;
196 		goto fail;
197 	}
198 	l3_hdr = pkt_buf + pkt->pkt_headroom + pkt->pkt_l2_len;
199 	iph = (volatile struct ip *)(void *)l3_hdr;
200 	ipv = iph->ip_v;
201 
202 	switch (ipv) {
203 	case 4:
204 		if (cls_len < sizeof(struct ip)) {
205 			SK_ERR("cls_len < sizeof(struct ip) (%d < %lu)",
206 			    cls_len, sizeof(struct ip));
207 			err = EINVAL;
208 			goto fail;
209 		}
210 		l3hlen = (uint8_t)(iph->ip_hl << 2);
211 		if (l3hlen < sizeof(struct ip)) {
212 			SK_ERR("l3hlen < sizeof(struct ip) (%d < %lu)",
213 			    l3hlen, sizeof(struct ip));
214 			err = EINVAL;
215 			goto fail;
216 		}
217 		if (cls_len < l3hlen) {
218 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
219 			err = EINVAL;
220 			goto fail;
221 		}
222 		l3tlen = ntohs(iph->ip_len);
223 		if (l3tlen < l3hlen) {
224 			SK_ERR("l3tlen < l3hlen (%d < %d)", l3tlen, l3hlen);
225 			err = EINVAL;
226 			goto fail;
227 		}
228 		if (pkt_len < l3tlen) {
229 			SK_ERR("pkt_len < l3tlen (%zu < %d)", pkt_len, l3tlen);
230 			err = EINVAL;
231 			goto fail;
232 		}
233 		td->inet_ipver = IPVERSION;
234 		td->inet_proto = iph->ip_p;
235 		bcopy(__DECONST(void *, &iph->ip_src), &td->inet_laddr.iia_v4addr,
236 		    sizeof(iph->ip_src));
237 		bcopy(__DECONST(void *, &iph->ip_dst), &td->inet_raddr.iia_v4addr,
238 		    sizeof(iph->ip_dst));
239 		break;
240 	case 6:
241 		l3hlen = sizeof(struct ip6_hdr);
242 		if (cls_len < l3hlen) {
243 			SK_ERR("cls_len < l3hlen (%d < %d)", cls_len, l3hlen);
244 			err = EINVAL;
245 			goto fail;
246 		}
247 		l3tlen = l3hlen + ntohs(ip6->ip6_plen);
248 		if (pkt_len < l3tlen) {
249 			SK_ERR("pkt_len < l3tlen (%zu < %d)", pkt_len, l3tlen);
250 			err = EINVAL;
251 			goto fail;
252 		}
253 		td->inet_ipver = IPV6_VERSION;
254 		td->inet_proto = ip6->ip6_nxt;
255 		bcopy(__DECONST(void *, &ip6->ip6_src), &td->inet_laddr,
256 		    sizeof(ip6->ip6_src));
257 		bcopy(__DECONST(void *, &ip6->ip6_dst), &td->inet_raddr,
258 		    sizeof(ip6->ip6_dst));
259 		break;
260 	default:
261 		SK_ERR("ipv == %d", ipv);
262 		err = EINVAL;
263 		goto fail;
264 	}
265 	tcph = __DECONST(volatile struct tcphdr *, (volatile uint8_t *)iph + l3hlen);
266 	ulen = (l3tlen - l3hlen);
267 	if (td->inet_proto == IPPROTO_TCP) {
268 		if (cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph)) {
269 			SK_ERR("cls_len < l3hlen + sizeof(*tcph) || ulen < sizeof(*tcph) "
270 			    "(%d < %d + %lu || %d < %lu)", cls_len, l3hlen, sizeof(*tcph),
271 			    ulen, sizeof(*tcph));
272 			err = EINVAL;
273 			goto fail;
274 		}
275 		l4hlen = (uint8_t)(tcph->th_off << 2);
276 		if (l4hlen < sizeof(*tcph)) {
277 			SK_ERR("l4hlen < sizeof(*tcph) (%d < %lu)", l4hlen, sizeof(*tcph));
278 			err = EINVAL;
279 			goto fail;
280 		}
281 		if (l4hlen > ulen) {
282 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
283 			err = EINVAL;
284 			goto fail;
285 		}
286 		bcopy(__DECONST(void *, &tcph->th_sport), &td->inet_lport,
287 		    sizeof(td->inet_lport));
288 		bcopy(__DECONST(void *, &tcph->th_dport), &td->inet_rport,
289 		    sizeof(td->inet_rport));
290 	} else if (td->inet_proto == IPPROTO_UDP) {
291 		if (cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph)) {
292 			SK_ERR("cls_len < l3hlen + sizeof(*udph) || ulen < sizeof(*udph) "
293 			    "(%d < %d + %lu || %d < %lu)", cls_len, l3hlen, sizeof(*udph),
294 			    ulen, sizeof(*udph));
295 			err = EINVAL;
296 			goto fail;
297 		}
298 		l4hlen = sizeof(*udph);
299 		if (l4hlen > ulen) {
300 			SK_ERR("l4hlen > ulen (%d > %d)", l4hlen, ulen);
301 			err = EINVAL;
302 			goto fail;
303 		}
304 		bcopy(__DECONST(void *, &udph->uh_sport), &td->inet_lport,
305 		    sizeof(td->inet_lport));
306 		bcopy(__DECONST(void *, &udph->uh_dport), &td->inet_rport,
307 		    sizeof(td->inet_rport));
308 	} else {
309 		err = ENOTSUP;
310 		goto fail;
311 	}
312 
313 	td->inet_common.itd_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET;
314 	td->inet_common.itd_len = sizeof(*td);
315 	td->inet_common.itd_flags = IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
316 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND;
317 	td->inet_mask |= (IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER |
318 	    IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO |
319 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR |
320 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR |
321 	    IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT |
322 	    IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT);
323 	return 0;
324 fail:
325 	DTRACE_SKYWALK5(classify__failed, struct ip *, iph, size_t, pkt_len,
326 	    uint8_t, pkt->pkt_l2_len, struct ifnet_traffic_descriptor_inet *, td,
327 	    int, err);
328 	bzero(td, sizeof(*td));
329 	return err;
330 	#undef iph
331 	#undef ip6
332 	#undef tcph
333 	#undef udph
334 }
335 
336 int
nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char * ifname,struct __kern_packet * pkt,uint64_t * qset_id)337 nxctl_inet_traffic_rule_find_qset_id_with_pkt(const char *ifname,
338     struct __kern_packet *pkt, uint64_t *qset_id)
339 {
340 	struct ifnet_traffic_descriptor_inet td;
341 	int err;
342 
343 	err = fill_inet_td(pkt, &td);
344 	if (err != 0) {
345 		return err;
346 	}
347 	return nxctl_inet_traffic_rule_find_qset_id(ifname, &td, qset_id);
348 }
349 
350 static struct ifnet_ip_addr v6_zeros_addr = {0};
351 static boolean_t
inet_v6addr_cmp(struct ifnet_ip_addr * a1,struct ifnet_ip_addr * a2)352 inet_v6addr_cmp(struct ifnet_ip_addr *a1, struct ifnet_ip_addr *a2)
353 {
354 	return memcmp(a1, a2, sizeof(*a1)) == 0;
355 }
356 
357 void
inet_traffic_rule_init(kern_allocation_name_t rule_tag)358 inet_traffic_rule_init(kern_allocation_name_t rule_tag)
359 {
360 	ASSERT(nxctl_traffic_rule_tag == NULL);
361 	nxctl_traffic_rule_tag = rule_tag;
362 }
363 
364 int
inet_traffic_rule_validate(const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra)365 inet_traffic_rule_validate(
366 	const char *ifname,
367 	struct ifnet_traffic_descriptor_common *td,
368 	struct ifnet_traffic_rule_action *ra)
369 {
370 	char buf[IFNAMSIZ];
371 	int unit, i;
372 	struct ifnet_traffic_descriptor_inet *tdi;
373 	uint8_t mask = 0, ipver, proto;
374 
375 	if (ifunit_extract(ifname, buf, sizeof(buf), &unit) < 0) {
376 		SK_ERR("invalid ifname: %s", ifname);
377 		return EINVAL;
378 	}
379 	if (td->itd_len != sizeof(*tdi)) {
380 		SK_ERR("invalid td len: expected %lu, actual %d",
381 		    sizeof(*tdi), td->itd_len);
382 		return EINVAL;
383 	}
384 	if (td->itd_flags == 0 ||
385 	    (td->itd_flags &
386 	    ~(IFNET_TRAFFIC_DESCRIPTOR_FLAG_INBOUND |
387 	    IFNET_TRAFFIC_DESCRIPTOR_FLAG_OUTBOUND)) != 0) {
388 		SK_ERR("invalid td flags: 0x%x", td->itd_flags);
389 		return EINVAL;
390 	}
391 	tdi = (struct ifnet_traffic_descriptor_inet *)td;
392 	for (i = 0; i < NINETRULEMASKS; i++) {
393 		if (tdi->inet_mask == nxctl_inet_traffic_rule_masks[i]) {
394 			mask = tdi->inet_mask;
395 			break;
396 		}
397 	}
398 	if (mask == 0) {
399 		SK_ERR("invalid inet mask: 0x%x", tdi->inet_mask);
400 		return EINVAL;
401 	}
402 	ipver = tdi->inet_ipver;
403 	if (ipver != IPVERSION && ipver != IPV6_VERSION) {
404 		SK_ERR("invalid inet ipver: 0x%x", ipver);
405 		return EINVAL;
406 	}
407 	proto = tdi->inet_proto;
408 	if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
409 		SK_ERR("invalid inet proto: %d", proto);
410 		return EINVAL;
411 	}
412 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0) {
413 		if (ipver == IPVERSION) {
414 			if (tdi->inet_laddr.iia_v4addr == INADDR_ANY) {
415 				SK_ERR("inet laddr v4 cannot be unspecified");
416 				return EINVAL;
417 			}
418 		} else {
419 			if (inet_v6addr_cmp(&tdi->inet_laddr, &v6_zeros_addr)) {
420 				SK_ERR("inet laddr v4 cannot be unspecified");
421 				return EINVAL;
422 			}
423 		}
424 	}
425 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0) {
426 		if (ipver == IPVERSION) {
427 			if (tdi->inet_raddr.iia_v4addr == INADDR_ANY) {
428 				SK_ERR("inet raddr v6 cannot be unspecified");
429 				return EINVAL;
430 			}
431 		} else {
432 			if (inet_v6addr_cmp(&tdi->inet_raddr, &v6_zeros_addr)) {
433 				SK_ERR("inet raddr v6 cannot be unspecified");
434 				return EINVAL;
435 			}
436 		}
437 	}
438 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0) {
439 		if (tdi->inet_lport == 0) {
440 			SK_ERR("inet lport cannot be unspecified");
441 			return EINVAL;
442 		}
443 	}
444 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0) {
445 		if (tdi->inet_rport == 0) {
446 			SK_ERR("inet rport cannot be unspecified");
447 			return EINVAL;
448 		}
449 	}
450 	if (ra->ra_len != sizeof(struct ifnet_traffic_rule_action_steer)) {
451 		SK_ERR("invalid ra len: expected %lu, actual %d",
452 		    sizeof(struct ifnet_traffic_rule_action_steer), ra->ra_len);
453 		return EINVAL;
454 	}
455 	return 0;
456 }
457 
458 SK_NO_INLINE_ATTRIBUTE
459 static void
inet_traffic_rule_storage_create(void)460 inet_traffic_rule_storage_create(void)
461 {
462 	rs = sk_alloc_type(struct nxctl_traffic_rule_inet_storage,
463 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
464 	SLIST_INIT(&rs->ris_if_list);
465 	rs->ris_count = 0;
466 	return;
467 }
468 
469 SK_NO_INLINE_ATTRIBUTE
470 static void
inet_traffic_rule_storage_destroy(void)471 inet_traffic_rule_storage_destroy(void)
472 {
473 	ASSERT(rs->ris_count == 0);
474 	ASSERT(SLIST_EMPTY(&rs->ris_if_list));
475 	sk_free_type(struct nxctl_traffic_rule_inet_storage, rs);
476 }
477 
478 SK_NO_INLINE_ATTRIBUTE
479 static struct nxctl_traffic_rule_inet_if *
inet_traffic_rule_if_create(const char * ifname)480 inet_traffic_rule_if_create(const char *ifname)
481 {
482 	struct nxctl_traffic_rule_inet_if *rif;
483 	int i;
484 
485 	rif = sk_alloc_type(struct nxctl_traffic_rule_inet_if,
486 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
487 	for (i = 0; i < NINETRULEMASKS; i++) {
488 		SLIST_INIT(&rif->rii_lists[i]);
489 	}
490 	strlcpy(rif->rii_ifname, ifname, sizeof(rif->rii_ifname));
491 	rif->rii_count = 0;
492 	return rif;
493 }
494 
495 SK_NO_INLINE_ATTRIBUTE
496 static void
inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if * rif)497 inet_traffic_rule_if_destroy(struct nxctl_traffic_rule_inet_if *rif)
498 {
499 	int i;
500 
501 	for (i = 0; i < NINETRULEMASKS; i++) {
502 		ASSERT(SLIST_EMPTY(&rif->rii_lists[i]));
503 	}
504 	ASSERT(rif->rii_count == 0);
505 	sk_free_type(struct nxctl_traffic_rule_inet_if, rif);
506 }
507 
508 SK_NO_INLINE_ATTRIBUTE
509 static boolean_t
inet_traffic_rule_match(struct nxctl_traffic_rule_inet * ntri,const char * ifname,uint32_t flags,struct ifnet_traffic_descriptor_inet * tdi)510 inet_traffic_rule_match(struct nxctl_traffic_rule_inet *ntri, const char *ifname,
511     uint32_t flags, struct ifnet_traffic_descriptor_inet *tdi)
512 {
513 	struct nxctl_traffic_rule *ntr = (struct nxctl_traffic_rule *)ntri;
514 	struct ifnet_traffic_descriptor_inet *tdi0;
515 	uint8_t mask;
516 	boolean_t exact;
517 
518 	VERIFY(strlcmp(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname)) == 0);
519 	tdi0 = &ntri->ntri_td;
520 
521 	exact = ((flags & NTR_FIND_FLAG_EXACT) != 0);
522 	mask = tdi0->inet_mask & tdi->inet_mask;
523 	if (exact) {
524 		ASSERT(tdi0->inet_mask == tdi->inet_mask);
525 	}
526 	ASSERT((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_IPVER) != 0);
527 	if (tdi0->inet_ipver != tdi->inet_ipver) {
528 		DTRACE_SKYWALK2(ipver__mismatch,
529 		    uint8_t, tdi0->inet_ipver, uint8_t, tdi->inet_ipver);
530 		return FALSE;
531 	}
532 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_PROTO) != 0 &&
533 	    tdi0->inet_proto != tdi->inet_proto) {
534 		DTRACE_SKYWALK2(proto__mismatch,
535 		    uint8_t, tdi0->inet_proto, uint8_t, tdi->inet_proto);
536 		return FALSE;
537 	}
538 	if (tdi0->inet_ipver == IPVERSION) {
539 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
540 		    tdi0->inet_laddr.iia_v4addr != tdi->inet_laddr.iia_v4addr) {
541 			DTRACE_SKYWALK2(v4laddr__mismatch,
542 			    in_addr_t, tdi0->inet_laddr.iia_v4addr,
543 			    in_addr_t, tdi->inet_laddr.iia_v4addr);
544 			return FALSE;
545 		}
546 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
547 		    tdi0->inet_raddr.iia_v4addr != tdi->inet_raddr.iia_v4addr) {
548 			DTRACE_SKYWALK2(v4raddr__mismatch,
549 			    in_addr_t, tdi0->inet_raddr.iia_v4addr,
550 			    in_addr_t, tdi->inet_raddr.iia_v4addr);
551 			return FALSE;
552 		}
553 	} else {
554 		ASSERT(tdi0->inet_ipver == IPV6_VERSION);
555 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LADDR) != 0 &&
556 		    !inet_v6addr_cmp(&tdi0->inet_laddr, &tdi->inet_laddr)) {
557 			DTRACE_SKYWALK2(v6laddr__mismatch,
558 			    struct in6_addr *, &tdi0->inet_laddr,
559 			    struct in6_addr *, &tdi->inet_laddr);
560 			return FALSE;
561 		}
562 		if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RADDR) != 0 &&
563 		    !inet_v6addr_cmp(&tdi0->inet_raddr, &tdi->inet_raddr)) {
564 			DTRACE_SKYWALK2(v6raddr__mismatch,
565 			    struct in6_addr *, &tdi0->inet_raddr,
566 			    struct in6_addr *, &tdi->inet_raddr);
567 			return FALSE;
568 		}
569 	}
570 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_LPORT) != 0 &&
571 	    tdi0->inet_lport != tdi->inet_lport) {
572 		DTRACE_SKYWALK2(lport__mismatch,
573 		    uint8_t, tdi0->inet_lport, uint8_t, tdi->inet_lport);
574 		return FALSE;
575 	}
576 	if ((mask & IFNET_TRAFFIC_DESCRIPTOR_INET_RPORT) != 0 &&
577 	    tdi0->inet_rport != tdi->inet_rport) {
578 		DTRACE_SKYWALK2(rport__mismatch,
579 		    uint8_t, tdi0->inet_rport, uint8_t, tdi->inet_rport);
580 		return FALSE;
581 	}
582 	return TRUE;
583 }
584 
585 int
inet_traffic_rule_find(const char * ifname,struct ifnet_traffic_descriptor_common * td,uint32_t flags,struct nxctl_traffic_rule ** ntrp)586 inet_traffic_rule_find(const char *ifname,
587     struct ifnet_traffic_descriptor_common *td, uint32_t flags,
588     struct nxctl_traffic_rule **ntrp)
589 {
590 	struct nxctl_traffic_rule_inet *ntri = NULL;
591 	struct nxctl_traffic_rule_inet_if *rif;
592 	struct ifnet_traffic_descriptor_inet *tdi =
593 	    (struct ifnet_traffic_descriptor_inet *)td;
594 	int i;
595 
596 	if (rs == NULL) {
597 		return ENOENT;
598 	}
599 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
600 		if (strlcmp(rif->rii_ifname, ifname, sizeof(rif->rii_ifname)) != 0) {
601 			continue;
602 		}
603 		for (i = 0; i < NINETRULEMASKS; i++) {
604 			if ((flags & NTR_FIND_FLAG_EXACT) != 0 &&
605 			    tdi->inet_mask != nxctl_inet_traffic_rule_masks[i]) {
606 				continue;
607 			}
608 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
609 				if (inet_traffic_rule_match(ntri, ifname, flags, tdi)) {
610 					*ntrp = (struct nxctl_traffic_rule *)ntri;
611 					return 0;
612 				}
613 			}
614 		}
615 	}
616 	return ENOENT;
617 }
618 
619 int
inet_traffic_rule_find_by_uuid(uuid_t uuid,struct nxctl_traffic_rule ** ntrp)620 inet_traffic_rule_find_by_uuid(
621 	uuid_t uuid, struct nxctl_traffic_rule **ntrp)
622 {
623 	struct nxctl_traffic_rule_inet *ntri;
624 	struct nxctl_traffic_rule *ntr;
625 	struct nxctl_traffic_rule_inet_if *rif;
626 	int i;
627 
628 	if (rs == NULL) {
629 		return ENOENT;
630 	}
631 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
632 		for (i = 0; i < NINETRULEMASKS; i++) {
633 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
634 				ntr = &ntri->ntri_common;
635 				if (uuid_compare(ntr->ntr_uuid, uuid) == 0) {
636 					*ntrp = ntr;
637 					return 0;
638 				}
639 			}
640 		}
641 	}
642 	return ENOENT;
643 }
644 
645 static void
inet_update_ifnet_traffic_rule_count(const char * ifname,uint32_t count)646 inet_update_ifnet_traffic_rule_count(const char *ifname, uint32_t count)
647 {
648 	struct ifnet *ifp;
649 
650 	ifp = ifunit_ref(ifname);
651 	if (ifp == NULL) {
652 		DTRACE_SKYWALK1(ifname__not__found, char *, ifname);
653 		return;
654 	}
655 	ifnet_update_inet_traffic_rule_count(ifp, count);
656 	ifnet_decr_iorefcnt(ifp);
657 }
658 
659 void
inet_traffic_rule_link(struct nxctl_traffic_rule * ntr)660 inet_traffic_rule_link(struct nxctl_traffic_rule *ntr)
661 {
662 	struct nxctl_traffic_rule_inet_if *rif;
663 	struct nxctl_traffic_rule_inet *ntri =
664 	    (struct nxctl_traffic_rule_inet *)ntr;
665 	struct nxctl_traffic_rule_inet_head *list = NULL;
666 	int i;
667 	char *__null_terminated ntr_ifname = NULL;
668 	char *__null_terminated rii_ifname = NULL;
669 
670 	if (rs == NULL) {
671 		inet_traffic_rule_storage_create();
672 	}
673 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
674 		if (strbufcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
675 			break;
676 		}
677 	}
678 	if (rif == NULL) {
679 		ntr_ifname = __unsafe_null_terminated_from_indexable(ntr->ntr_ifname);
680 		rif = inet_traffic_rule_if_create(ntr_ifname);
681 		SLIST_INSERT_HEAD(&rs->ris_if_list, rif, rii_link);
682 	}
683 	for (i = 0; i < NINETRULEMASKS; i++) {
684 		if (ntri->ntri_td.inet_mask ==
685 		    nxctl_inet_traffic_rule_masks[i]) {
686 			list = &rif->rii_lists[i];
687 			break;
688 		}
689 	}
690 	retain_traffic_rule(ntr);
691 	ASSERT(list != NULL);
692 	SLIST_INSERT_HEAD(list, ntri, ntri_storage_link);
693 	/* per-interface count */
694 	rif->rii_count++;
695 	rii_ifname = __unsafe_null_terminated_from_indexable(rif->rii_ifname);
696 	inet_update_ifnet_traffic_rule_count(rii_ifname, rif->rii_count);
697 
698 	/* global count */
699 	rs->ris_count++;
700 }
701 
702 void
inet_traffic_rule_unlink(struct nxctl_traffic_rule * ntr)703 inet_traffic_rule_unlink(struct nxctl_traffic_rule *ntr)
704 {
705 	struct nxctl_traffic_rule_inet_if *rif;
706 	struct nxctl_traffic_rule_inet *ntri =
707 	    (struct nxctl_traffic_rule_inet *)ntr;
708 	struct nxctl_traffic_rule_inet_head *list = NULL;
709 	int i;
710 	char *__null_terminated rii_ifname = NULL;
711 
712 	ASSERT(rs != NULL);
713 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
714 		if (strbufcmp(rif->rii_ifname, ntr->ntr_ifname) == 0) {
715 			break;
716 		}
717 	}
718 	ASSERT(rif != NULL);
719 	for (i = 0; i < NINETRULEMASKS; i++) {
720 		if (ntri->ntri_td.inet_mask ==
721 		    nxctl_inet_traffic_rule_masks[i]) {
722 			list = &rif->rii_lists[i];
723 			break;
724 		}
725 	}
726 	ASSERT(list != NULL);
727 	SLIST_REMOVE(list, ntri, nxctl_traffic_rule_inet, ntri_storage_link);
728 	rif->rii_count--;
729 	rii_ifname = __unsafe_null_terminated_from_indexable(rif->rii_ifname);
730 	inet_update_ifnet_traffic_rule_count(rii_ifname, rif->rii_count);
731 
732 	rs->ris_count--;
733 	release_traffic_rule(ntr);
734 
735 	if (rif->rii_count == 0) {
736 		SLIST_REMOVE(&rs->ris_if_list, rif, nxctl_traffic_rule_inet_if, rii_link);
737 		inet_traffic_rule_if_destroy(rif);
738 	}
739 	if (rs->ris_count == 0) {
740 		inet_traffic_rule_storage_destroy();
741 	}
742 }
743 
744 /*
745  * XXX
746  * This may need additional changes to ensure safety against detach/attach.
747  * This is not an issue for the first consumer of llink interfaces, cellular,
748  * which does not detach.
749  */
750 int
inet_traffic_rule_notify(struct nxctl_traffic_rule * ntr,uint32_t flags)751 inet_traffic_rule_notify(struct nxctl_traffic_rule *ntr, uint32_t flags)
752 {
753 	struct ifnet *ifp;
754 	struct nx_netif *nif;
755 	struct netif_qset *__single qset = NULL;
756 	struct nxctl_traffic_rule_inet *ntri;
757 	int err = 0;
758 	char *__null_terminated ntr_ifname = NULL;
759 
760 	ntr_ifname = __unsafe_null_terminated_from_indexable(ntr->ntr_ifname);
761 	ifp = ifunit_ref(ntr_ifname);
762 	if (ifp == NULL) {
763 		DTRACE_SKYWALK1(ifname__not__found, char *, ntr->ntr_ifname);
764 		err = ENXIO;
765 		goto done;
766 	}
767 	nif = NA(ifp)->nifna_netif;
768 	if (!NX_LLINK_PROV(nif->nif_nx)) {
769 		DTRACE_SKYWALK1(llink__not__enabled, struct ifnet *, ifp);
770 		err = ENOTSUP;
771 		goto done;
772 	}
773 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
774 	qset = nx_netif_find_qset(nif, ntri->ntri_ra.ras_qset_id);
775 	if (qset == NULL || (qset->nqs_flags & NETIF_QSET_FLAG_EXT_INITED) == 0) {
776 		DTRACE_SKYWALK1(qset__not__initialized, struct netif_qset *, qset);
777 		err = ENXIO;
778 		goto done;
779 	}
780 	err = nx_netif_notify_steering_info(nif, qset,
781 	    (struct ifnet_traffic_descriptor_common *)&ntri->ntri_td,
782 	    ((flags & NTR_NOTIFY_FLAG_ADD) != 0));
783 done:
784 	if (qset != NULL) {
785 		nx_netif_qset_release(&qset);
786 	}
787 	if (ifp != NULL) {
788 		ifnet_decr_iorefcnt(ifp);
789 	}
790 	return err;
791 }
792 
793 int
inet_traffic_rule_get_count(const char * ifname,uint32_t * count)794 inet_traffic_rule_get_count(const char *ifname, uint32_t *count)
795 {
796 	struct nxctl_traffic_rule_inet_if *rif;
797 	int err;
798 
799 	if (rs == NULL) {
800 		err = ENOENT;
801 		goto fail;
802 	}
803 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
804 		if (strlcmp(rif->rii_ifname, ifname, sizeof(rif->rii_ifname)) == 0) {
805 			break;
806 		}
807 	}
808 	if (rif == NULL) {
809 		err = ENOENT;
810 		goto fail;
811 	}
812 	*count = rif->rii_count;
813 	return 0;
814 fail:
815 	return err;
816 }
817 
818 int
inet_traffic_rule_create(const char * ifname,struct ifnet_traffic_descriptor_common * td,struct ifnet_traffic_rule_action * ra,uint32_t flags,struct nxctl_traffic_rule ** ntrp)819 inet_traffic_rule_create(
820 	const char *ifname, struct ifnet_traffic_descriptor_common *td,
821 	struct ifnet_traffic_rule_action *ra, uint32_t flags,
822 	struct nxctl_traffic_rule **ntrp)
823 {
824 	struct nxctl_traffic_rule_inet *ntri;
825 	struct nxctl_traffic_rule *ntr;
826 	struct ifnet_traffic_descriptor_inet *tdi;
827 	struct ifnet_traffic_rule_action_steer *ras;
828 
829 	ntri = sk_alloc_type(struct nxctl_traffic_rule_inet,
830 	    Z_WAITOK | Z_NOFAIL, nxctl_traffic_rule_tag);
831 	ntr = &ntri->ntri_common;
832 
833 	ntr->ntrt_type = IFNET_TRAFFIC_DESCRIPTOR_TYPE_INET;
834 	ntr->ntr_flags = flags;
835 	uuid_generate(ntr->ntr_uuid);
836 	os_ref_init(&ntr->ntr_refcnt, NULL);
837 
838 	strlcpy(ntr->ntr_ifname, ifname, sizeof(ntr->ntr_ifname));
839 	proc_selfname(ntr->ntr_procname, sizeof(ntr->ntr_procname));
840 
841 	tdi = __container_of(td, struct ifnet_traffic_descriptor_inet, inet_common);
842 	ras = __container_of(ra, struct ifnet_traffic_rule_action_steer, ras_common);
843 	bcopy(tdi, &ntri->ntri_td, sizeof(ntri->ntri_td));
844 	bcopy(ras, &ntri->ntri_ra, sizeof(ntri->ntri_ra));
845 
846 	*ntrp = ntr;
847 	return 0;
848 }
849 
850 void
inet_traffic_rule_destroy(struct nxctl_traffic_rule * ntr)851 inet_traffic_rule_destroy(struct nxctl_traffic_rule *ntr)
852 {
853 	struct nxctl_traffic_rule_inet *ntri;
854 
855 	ASSERT(os_ref_get_count(&ntr->ntr_refcnt) == 0);
856 	ntri = (struct nxctl_traffic_rule_inet *)ntr;
857 	sk_free_type(struct nxctl_traffic_rule_inet, ntri);
858 }
859 
860 static void
convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet * ntri,struct nxctl_traffic_rule_inet_iocinfo * info)861 convert_ntri_to_iocinfo(struct nxctl_traffic_rule_inet *ntri,
862     struct nxctl_traffic_rule_inet_iocinfo *info)
863 {
864 	struct nxctl_traffic_rule *ntr;
865 	struct nxctl_traffic_rule_generic_iocinfo *ginfo;
866 
867 	bzero(info, sizeof(*info));
868 	ntr = &ntri->ntri_common;
869 	ginfo = &info->tri_common;
870 	static_assert(sizeof(ntr->ntr_procname) == sizeof(ginfo->trg_procname));
871 	static_assert(sizeof(ntr->ntr_ifname) == sizeof(ginfo->trg_ifname));
872 	uuid_copy(ginfo->trg_uuid, ntr->ntr_uuid);
873 	strbufcpy(ginfo->trg_procname, ntr->ntr_procname);
874 	strbufcpy(ginfo->trg_ifname, ntr->ntr_ifname);
875 	bcopy(&ntri->ntri_td, &info->tri_td, sizeof(info->tri_td));
876 	bcopy(&ntri->ntri_ra, &info->tri_ra, sizeof(info->tri_ra));
877 }
878 
879 int
inet_traffic_rule_get_all(uint32_t size,uint32_t * count,user_addr_t uaddr)880 inet_traffic_rule_get_all(uint32_t size,
881     uint32_t *count, user_addr_t uaddr)
882 {
883 	struct nxctl_traffic_rule_inet *ntri = NULL;
884 	struct nxctl_traffic_rule_inet_if *rif;
885 	struct nxctl_traffic_rule_inet_iocinfo info;
886 	int i, err;
887 
888 	if (size != sizeof(info)) {
889 		SK_ERR("size: actual %u, expected %lu", size, sizeof(info));
890 		return EINVAL;
891 	}
892 	if (rs == NULL) {
893 		*count = 0;
894 		return 0;
895 	}
896 	if (*count < rs->ris_count) {
897 		SK_ERR("count: given %d, require: %d", *count, rs->ris_count);
898 		return ENOBUFS;
899 	}
900 	SLIST_FOREACH(rif, &rs->ris_if_list, rii_link) {
901 		for (i = 0; i < NINETRULEMASKS; i++) {
902 			SLIST_FOREACH(ntri, &rif->rii_lists[i], ntri_storage_link) {
903 				convert_ntri_to_iocinfo(ntri, &info);
904 				err = copyout(&info, uaddr, sizeof(info));
905 				if (err != 0) {
906 					SK_ERR("copyout failed: %d", err);
907 					return err;
908 				}
909 				uaddr += sizeof(info);
910 			}
911 		}
912 	}
913 	*count = rs->ris_count;
914 	return 0;
915 }
916