xref: /xnu-12377.41.6/bsd/net/trie_utility.c (revision bbb1b6f9e71b8cdde6e5cd6f4841f207dee3d828)
1 /*
2  * Copyright (c) 2024 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <stdbool.h>
30 #include <sys/types.h>
31 #include <sys/malloc.h>
32 #include <os/base.h>
33 #include <sys/syslog.h>
34 #include <net/sockaddr_utils.h>
35 #include <net/trie_utility.h>
36 
37 int net_trie_log_level = LOG_DEBUG;
38 static os_log_t net_trie_log_handle = NULL;
39 #define NET_TRIE_DEBUG_SEARCH 0
40 
41 #define NET_TRIE_LOG(level, fmt, ...)                                                                                   \
42 do {                                                                                                                    \
43     if (net_trie_log_level >= level && net_trie_log_handle) {                                                           \
44 	if (level == LOG_ERR) {                                                                                         \
45 	    os_log_error(net_trie_log_handle, "NET_TRIE - %s:%d " fmt "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__);     \
46 	} else {                                                                                                        \
47 	    os_log(net_trie_log_handle, "NET_TRIE - %s:%d " fmt "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__);           \
48 	}                                                                                                               \
49     }                                                                                                                   \
50 } while (0)
51 
/*
 * Record `node` as the child of node `i` for the byte `b`.
 *
 * The child map of node `i` is the CHILD_MAP_SIZE-entry slice of
 * (t)->child_maps selected by that node's child_map index; the slot inside
 * the slice is (b - FIRST_PRINTABLE_ASCII).  Bytes outside
 * [FIRST_PRINTABLE_ASCII, LAST_PRINTABLE_ASCII] have no slot: the request is
 * logged and nothing is stored.
 *
 * Wrapped in do { } while (0) so the macro behaves as a single statement
 * (safe in an unbraced if/else), with every argument parenthesized.
 * Note: `b` is still expanded more than once, so it must be free of side
 * effects.
 */
#define TRIE_CHILD_SET(t, i, b, node)                                                                                   \
do {                                                                                                                    \
	if ((b) >= FIRST_PRINTABLE_ASCII && (b) <= LAST_PRINTABLE_ASCII) {                                              \
		(((t)->child_maps + (CHILD_MAP_SIZE * TRIE_NODE(t, (i)).child_map))[((b) - FIRST_PRINTABLE_ASCII)]) = (node); \
	} else {                                                                                                        \
		NET_TRIE_LOG(LOG_ERR, "NETrie - out of printable acsii range <%X>", (b));                               \
	}                                                                                                               \
} while (0)
60 
61 static uint16_t
trie_node_alloc(struct net_trie * trie)62 trie_node_alloc(struct net_trie *trie)
63 {
64 	if (trie->nodes_free_next < trie->nodes_count) {
65 		uint16_t node_idx = trie->nodes_free_next++;
66 		TRIE_NODE(trie, node_idx).child_map = NULL_TRIE_IDX;
67 		return node_idx;
68 	} else {
69 		return NULL_TRIE_IDX;
70 	}
71 }
72 
73 static uint16_t
trie_child_map_alloc(struct net_trie * trie)74 trie_child_map_alloc(struct net_trie *trie)
75 {
76 	if (trie->child_maps_free_next < trie->child_maps_count) {
77 		return trie->child_maps_free_next++;
78 	} else {
79 		return NULL_TRIE_IDX;
80 	}
81 }
82 
83 static uint16_t
trie_bytes_move(struct net_trie * trie,uint16_t bytes_idx,size_t bytes_size)84 trie_bytes_move(struct net_trie *trie, uint16_t bytes_idx, size_t bytes_size)
85 {
86 	uint16_t start = trie->bytes_free_next;
87 	if (start + bytes_size <= trie->bytes_count) {
88 		if (start != bytes_idx) {
89 			memmove(&TRIE_BYTE(trie, start), &TRIE_BYTE(trie, bytes_idx), bytes_size);
90 		}
91 		trie->bytes_free_next += bytes_size;
92 		return start;
93 	} else {
94 		return NULL_TRIE_IDX;
95 	}
96 }
97 
98 static boolean_t
net_trie_has_high_ascii(const uint8_t * __sized_by (string_length)string,size_t string_length)99 net_trie_has_high_ascii(const uint8_t * __sized_by(string_length)string, size_t string_length)
100 {
101 	for (int i = 0; i < (int)string_length; i++) {
102 		if (HIGH_ASCII(string[i])) {
103 			return true;
104 		}
105 	}
106 	return false;
107 }
108 
109 boolean_t
net_trie_init(struct net_trie * new_trie,size_t prefix_count,size_t leaf_count,size_t bytes_count)110 net_trie_init(struct net_trie *new_trie, size_t prefix_count, size_t leaf_count, size_t bytes_count)
111 {
112 	size_t bytes_mem_size;
113 	size_t child_maps_mem_size;
114 	size_t nodes_mem_size;
115 	size_t trie_memory_size = 0;
116 	size_t nodes_count = 0;
117 	size_t maps_count = 0;
118 	int data_memory_offset = 0;
119 
120 	if (new_trie == NULL) {
121 		return false;
122 	}
123 
124 	if (net_trie_log_handle == NULL) {
125 		net_trie_log_handle = os_log_create("com.apple.xnu.net.trie", "net_trie");
126 	}
127 
128 	memset(new_trie, 0, sizeof(struct net_trie));
129 
130 	if (new_trie == NULL || prefix_count <= 0 || leaf_count <= 0 || bytes_count <= 0) {
131 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - null trie, no prefix/leaf count or no byte count", __FUNCTION__);
132 		return false;
133 	}
134 
135 	if (os_add3_overflow(prefix_count, leaf_count, 1, &nodes_count)) { /* + 1 for the root node */
136 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Overflow while computing the number of nodes", __FUNCTION__);
137 		return false;
138 	}
139 
140 	if (os_add_overflow(prefix_count, 1, &maps_count)) { /* + 1 for the root node */
141 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Overflow while computing the number of maps", __FUNCTION__);
142 		return false;
143 	}
144 
145 	if (bytes_count > UINT16_MAX || nodes_count > UINT16_MAX || maps_count > UINT16_MAX) {
146 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Invalid bytes count (%lu), nodes count (%lu) or maps count (%lu)", __FUNCTION__, bytes_count, nodes_count, maps_count);
147 		return false;
148 	}
149 
150 	if (os_mul_overflow(sizeof(*new_trie->nodes), (size_t)nodes_count, &nodes_mem_size) ||
151 	    os_mul3_overflow(sizeof(*new_trie->child_maps), CHILD_MAP_SIZE, (size_t)maps_count, &child_maps_mem_size) ||
152 	    os_mul_overflow(sizeof(*new_trie->bytes), (size_t)bytes_count, &bytes_mem_size) ||
153 	    os_add3_overflow(nodes_mem_size, child_maps_mem_size, bytes_mem_size, &trie_memory_size)) {
154 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Overflow while computing trie memory sizes", __FUNCTION__);
155 		return false;
156 	}
157 
158 	if (trie_memory_size > MAX_TRIE_MEMORY) {
159 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Trie memory size (%lu) is too big (maximum is %u)", __FUNCTION__, trie_memory_size, MAX_TRIE_MEMORY);
160 		return false;
161 	}
162 
163 	NET_TRIE_LOG(LOG_DEBUG, "%s: NET_TRIE - initializing (Nodes count = %lu, child maps count = %lu, bytes_count = %lu, total memory size %lu)", __FUNCTION__, nodes_count, maps_count, bytes_count, trie_memory_size);
164 
165 	void *memory = (u_int8_t *)kalloc_data(trie_memory_size, Z_WAITOK | Z_ZERO);
166 	if (memory == NULL) {
167 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Failed to allocate %lu bytes of memory for the trie", __FUNCTION__, trie_memory_size);
168 		return false;
169 	}
170 	new_trie->memory = memory;
171 	new_trie->trie_memory_size = trie_memory_size;
172 
173 	new_trie->magic = NET_TRIE_MAGIC;
174 	new_trie->version = NET_TRIE_FORMAT_VERSION;
175 
176 	new_trie->nodes_mem_size = nodes_mem_size;
177 	new_trie->child_maps_mem_size = child_maps_mem_size;
178 	new_trie->bytes_mem_size = bytes_mem_size;
179 
180 	/* Initialize the free lists */
181 	uint8_t *data_memory = (uint8_t *)new_trie->memory + data_memory_offset;
182 	new_trie->nodes = (struct net_trie_node *)(void *)(data_memory);
183 	new_trie->nodes_count = (uint16_t)nodes_count;
184 	new_trie->nodes_free_next = 0;
185 	memset(new_trie->nodes, 0, nodes_mem_size);
186 
187 	new_trie->child_maps = (uint16_t *)(void *)(data_memory + nodes_mem_size);
188 	new_trie->child_maps_count = (uint16_t)maps_count;
189 	new_trie->child_maps_free_next = 0;
190 	memset(new_trie->child_maps, 0xff, child_maps_mem_size);
191 
192 	new_trie->bytes = (uint8_t *)(void *)(data_memory + nodes_mem_size + child_maps_mem_size);
193 	new_trie->bytes_count = (uint16_t)bytes_count;
194 	new_trie->bytes_free_next = 0;
195 	memset(new_trie->bytes, 0, bytes_mem_size);
196 
197 	/* The root is an empty node */
198 	new_trie->root = trie_node_alloc(new_trie);
199 
200 	return true;
201 }
202 
203 boolean_t
net_trie_init_with_mem(struct net_trie * new_trie,uint8_t * __sized_by (trie_memory_size)memory,size_t trie_memory_size,size_t nodes_mem_size,size_t child_maps_mem_size,size_t bytes_mem_size,uint16_t nodes_count,uint16_t child_maps_count,uint16_t bytes_count)204 net_trie_init_with_mem(struct net_trie *new_trie, uint8_t * __sized_by(trie_memory_size) memory, size_t trie_memory_size,
205     size_t nodes_mem_size, size_t child_maps_mem_size, size_t bytes_mem_size,
206     uint16_t nodes_count, uint16_t child_maps_count, uint16_t bytes_count)
207 {
208 	size_t test_trie_memory_size = 0;
209 	size_t test_nodes_mem_size = 0;
210 	size_t test_child_maps_mem_size = 0;
211 	size_t test_bytes_mem_size = 0;
212 
213 	if (new_trie == NULL || memory == NULL) {
214 		return false;
215 	}
216 
217 	if (net_trie_log_handle == NULL) {
218 		net_trie_log_handle = os_log_create("com.apple.xnu.net.trie", "net_trie");
219 	}
220 
221 	// Validate all passed in sizes and counts:
222 	if (os_add3_overflow(nodes_mem_size, child_maps_mem_size, bytes_mem_size, &test_trie_memory_size) ||
223 	    os_mul_overflow(sizeof(*new_trie->nodes), (size_t)nodes_count, &test_nodes_mem_size) ||
224 	    os_mul3_overflow(sizeof(*new_trie->child_maps), CHILD_MAP_SIZE, (size_t)child_maps_count, &test_child_maps_mem_size) ||
225 	    os_mul_overflow(sizeof(*new_trie->bytes), (size_t)bytes_count, &test_bytes_mem_size)) {
226 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Overflow while validating trie memory sizes", __FUNCTION__);
227 		return false;
228 	}
229 	if (test_trie_memory_size != trie_memory_size) {
230 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - passed in mem sizes (nodes %zu maps %zu bytes %zu) not equal to total mem %zu",
231 		    __FUNCTION__, nodes_mem_size, child_maps_mem_size, bytes_mem_size, trie_memory_size);
232 		return false;
233 	}
234 	if (test_nodes_mem_size != nodes_mem_size) {
235 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - passed in nodes_count %d not valid", __FUNCTION__, nodes_count);
236 		return false;
237 	}
238 	if (test_child_maps_mem_size != child_maps_mem_size) {
239 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - passed in maps_count %d not valid", __FUNCTION__, child_maps_count);
240 		return false;
241 	}
242 	if (test_bytes_mem_size != bytes_mem_size) {
243 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - passed in bytes_count %d not valid", __FUNCTION__, bytes_count);
244 		return false;
245 	}
246 
247 	memset(new_trie, 0, sizeof(struct net_trie));
248 
249 	new_trie->memory = memory;
250 	new_trie->trie_memory_size = trie_memory_size;
251 	NET_TRIE_LOG(LOG_DEBUG, "%s: NET_TRIE - initialized with malloc %zu", __FUNCTION__, trie_memory_size);
252 
253 	new_trie->magic = NET_TRIE_MAGIC;
254 	new_trie->version = NET_TRIE_FORMAT_VERSION;
255 
256 	new_trie->nodes_mem_size = nodes_mem_size;
257 	new_trie->child_maps_mem_size = child_maps_mem_size;
258 	new_trie->bytes_mem_size = bytes_mem_size;
259 
260 	uint8_t *data_memory = (uint8_t *)new_trie->memory;
261 	new_trie->nodes = (struct net_trie_node *)(void *)(data_memory);
262 	new_trie->nodes_count = (uint16_t)nodes_count;
263 
264 	new_trie->child_maps = (uint16_t *)(void *)(data_memory + nodes_mem_size);
265 	new_trie->child_maps_count = (uint16_t)child_maps_count;
266 
267 	new_trie->bytes = (uint8_t *)(void *)(data_memory + nodes_mem_size + child_maps_mem_size);
268 	new_trie->bytes_count = (uint16_t)bytes_count;
269 
270 	/* The root points to the first node */
271 	new_trie->root = 0;
272 
273 	NET_TRIE_LOG(LOG_DEBUG, "%s: NET_TRIE - initialized - mem %X (size %zu) nodes %X (size %zu count %d) maps %X (size %zu count %d) bytes %X (size %zu count %d)",
274 	    __FUNCTION__,
275 	    (unsigned int)new_trie->memory, new_trie->trie_memory_size,
276 	    (unsigned int)new_trie->nodes, new_trie->nodes_mem_size, new_trie->nodes_count,
277 	    (unsigned int)new_trie->child_maps, new_trie->child_maps_mem_size, new_trie->child_maps_count,
278 	    (unsigned int)new_trie->bytes, new_trie->bytes_mem_size, new_trie->bytes_count);
279 
280 	return true;
281 }
282 
283 void
net_trie_free(struct net_trie * new_trie)284 net_trie_free(struct net_trie *new_trie)
285 {
286 	if (new_trie == NULL || new_trie->memory == NULL) {
287 		return;
288 	}
289 	kfree_data_sized_by(new_trie->memory, new_trie->trie_memory_size);
290 	memset(new_trie, 0, sizeof(struct net_trie));
291 }
292 
293 uint16_t
net_trie_insert(struct net_trie * trie,const uint8_t * __sized_by (string_length)string,size_t string_length,const uint8_t * __sized_by (metadata_length)metadata,size_t metadata_length,boolean_t reverse)294 net_trie_insert(struct net_trie *trie,
295     const uint8_t * __sized_by(string_length) string, size_t string_length,
296     const uint8_t * __sized_by(metadata_length) metadata, size_t metadata_length,
297     boolean_t reverse)
298 {
299 	if (trie->memory == NULL || string == NULL || string_length == 0) {
300 		return NULL_TRIE_IDX;
301 	}
302 
303 	if (string_length > UINT16_MAX || trie->bytes_free_next + (uint16_t)string_length > trie->bytes_count) {
304 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - failed insert - out of allocated memory", __FUNCTION__);
305 		return NULL_TRIE_IDX;
306 	}
307 
308 	if (net_trie_has_high_ascii(string, string_length)) {
309 		NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - failed insert - non-printable ASCII not supported", __FUNCTION__);
310 		return NULL_TRIE_IDX;
311 	}
312 
313 	char *byte = (char *)&TRIE_BYTE(trie, trie->bytes_free_next);
314 
315 	if (reverse) {
316 		for (size_t i = 0, j = string_length - 1; i < string_length; i++, j--) {
317 			byte[i] = string[j];
318 		}
319 	} else {
320 		memcpy(byte, string, string_length);
321 	}
322 
323 	uint16_t current = trie->root;
324 	uint16_t child = trie->root;
325 	uint16_t string_end = trie->bytes_free_next + (uint16_t)string_length;
326 	uint16_t string_idx = trie->bytes_free_next;
327 	uint16_t string_remainder = (uint16_t)string_length;
328 
329 	while (child != NULL_TRIE_IDX) {
330 		uint16_t parent = current;
331 		uint16_t node_idx;
332 		uint16_t current_end;
333 
334 		current = child;
335 		child = NULL_TRIE_IDX;
336 
337 		current_end = TRIE_NODE(trie, current).start + TRIE_NODE(trie, current).length;
338 
339 		for (node_idx = TRIE_NODE(trie, current).start;
340 		    node_idx < current_end &&
341 		    string_idx < string_end &&
342 		    TRIE_BYTE(trie, node_idx) == TRIE_BYTE(trie, string_idx);
343 		    node_idx++, string_idx++) {
344 			;
345 		}
346 
347 		string_remainder = string_end - string_idx;
348 
349 		if (node_idx < (TRIE_NODE(trie, current).start + TRIE_NODE(trie, current).length)) {
350 			/*
351 			 * We did not reach the end of the current node's string.
352 			 * We need to split the current node into two:
353 			 *   1. A new node that contains the prefix of the node that matches
354 			 *      the prefix of the string being inserted.
355 			 *   2. The current node modified to point to the remainder
356 			 *      of the current node's string.
357 			 */
358 			uint16_t prefix = trie_node_alloc(trie);
359 			if (prefix == NULL_TRIE_IDX) {
360 				NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Ran out of trie nodes while splitting an existing node", __FUNCTION__);
361 				return NULL_TRIE_IDX;
362 			}
363 
364 			/*
365 			 * Prefix points to the portion of the current nodes's string that has matched
366 			 * the input string thus far.
367 			 */
368 			TRIE_NODE(trie, prefix).start = TRIE_NODE(trie, current).start;
369 			TRIE_NODE(trie, prefix).length = (node_idx - TRIE_NODE(trie, current).start);
370 			if (string_remainder == 0) {
371 				TRIE_NODE(trie, prefix).is_leaf = true;
372 
373 				/* Store the metadata */
374 				if (metadata && metadata_length > 0) {
375 					char *byte_ptr = (char *)&TRIE_BYTE(trie, trie->bytes_free_next);
376 					memcpy(byte_ptr, metadata, metadata_length);
377 					TRIE_NODE(trie, prefix).metadata = trie_bytes_move(trie, trie->bytes_free_next, metadata_length);
378 					TRIE_NODE(trie, prefix).metadata_length = (uint16_t)metadata_length;
379 				}
380 			}
381 
382 			/*
383 			 * Prefix has the current node as the child corresponding to the first byte
384 			 * after the split.
385 			 */
386 			TRIE_NODE(trie, prefix).child_map = trie_child_map_alloc(trie);
387 			if (TRIE_NODE(trie, prefix).child_map == NULL_TRIE_IDX) {
388 				NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Ran out of child maps while splitting an existing node", __FUNCTION__);
389 				return NULL_TRIE_IDX;
390 			}
391 			TRIE_CHILD_SET(trie, prefix, TRIE_BYTE(trie, node_idx), current);
392 
393 			/* Parent has the prefix as the child correspoding to the first byte in the prefix */
394 			TRIE_CHILD_SET(trie, parent, TRIE_BYTE(trie, TRIE_NODE(trie, prefix).start), prefix);
395 
396 			/* Current node is adjusted to point to the remainder */
397 			TRIE_NODE(trie, current).start = node_idx;
398 			TRIE_NODE(trie, current).length -= TRIE_NODE(trie, prefix).length;
399 
400 			/* We want to insert the new leaf (if any) as a child of the prefix */
401 			current = prefix;
402 		}
403 
404 		if (string_remainder > 0) {
405 			/*
406 			 * We still have bytes in the string that have not been matched yet.
407 			 * If the current node has children, iterate to the child corresponding
408 			 * to the next byte in the string.
409 			 */
410 			if (TRIE_NODE(trie, current).child_map != NULL_TRIE_IDX) {
411 				child = TRIE_CHILD_GET(trie, current, TRIE_BYTE(trie, string_idx));
412 			}
413 		}
414 	} /* while (child != NULL_TRIE_IDX) */
415 
416 	if (string_remainder > 0) {
417 		/* Add a new leaf containing the remainder of the string */
418 		uint16_t leaf = trie_node_alloc(trie);
419 		if (leaf == NULL_TRIE_IDX) {
420 			NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Ran out of trie nodes while inserting a new leaf", __FUNCTION__);
421 			return NULL_TRIE_IDX;
422 		}
423 
424 		TRIE_NODE(trie, leaf).start = trie_bytes_move(trie, string_idx, string_remainder);
425 		if (TRIE_NODE(trie, leaf).start == NULL_TRIE_IDX) {
426 			NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Ran out of bytes while inserting a new leaf", __FUNCTION__);
427 			return NULL_TRIE_IDX;
428 		}
429 		TRIE_NODE(trie, leaf).length = string_remainder;
430 		TRIE_NODE(trie, leaf).is_leaf = true;
431 
432 		/* Store the metadata */
433 		if (metadata && metadata_length > 0) {
434 			char *byte_ptr = (char *)&TRIE_BYTE(trie, trie->bytes_free_next);
435 			memcpy(byte_ptr, metadata, metadata_length);
436 			TRIE_NODE(trie, leaf).metadata = trie_bytes_move(trie, trie->bytes_free_next, metadata_length);
437 			TRIE_NODE(trie, leaf).metadata_length = (uint16_t)metadata_length;
438 		}
439 
440 		/* Set the new leaf as the child of the current node */
441 		if (TRIE_NODE(trie, current).child_map == NULL_TRIE_IDX) {
442 			TRIE_NODE(trie, current).child_map = trie_child_map_alloc(trie);
443 			if (TRIE_NODE(trie, current).child_map == NULL_TRIE_IDX) {
444 				NET_TRIE_LOG(LOG_ERR, "%s: NET_TRIE - Ran out of child maps while inserting a new leaf", __FUNCTION__);
445 				return NULL_TRIE_IDX;
446 			}
447 		}
448 		TRIE_CHILD_SET(trie, current, TRIE_BYTE(trie, TRIE_NODE(trie, leaf).start), leaf);
449 		current = leaf;
450 	} /* else duplicate or this string is a prefix of one of the existing strings */
451 
452 	return current;
453 }
454 
455 uint16_t
net_trie_search(struct net_trie * trie,const uint8_t * __sized_by (string_length)string,size_t string_length,const uint8_t * __sized_by (* metadata_length)* metadata,size_t * metadata_length,boolean_t reverse,boolean_t partial_match_allowed,char partial_match_terminator,boolean_t * high_ascii_detected,check_metadata_func check_metadata)456 net_trie_search(struct net_trie *trie,
457     const uint8_t * __sized_by(string_length) string, size_t string_length,
458     const uint8_t * __sized_by(*metadata_length) * metadata, size_t *metadata_length,
459     boolean_t reverse, boolean_t partial_match_allowed, char partial_match_terminator,
460     boolean_t *high_ascii_detected, check_metadata_func check_metadata)
461 {
462 	if (trie->memory == NULL || string == NULL || string_length == 0) {
463 		return NULL_TRIE_IDX;
464 	}
465 
466 	uint16_t last_matched = NULL_TRIE_IDX;
467 	uint16_t current = trie->root;
468 	int16_t string_idx = reverse ? (int16_t)(string_length - 1) : 0;
469 
470 #if NET_TRIE_DEBUG_SEARCH
471 	NET_TRIE_LOG(LOG_DEBUG, "NET_TRIE - search %s len %zu reverse %d", string, string_length, reverse);
472 #endif
473 
474 	while (current != NULL_TRIE_IDX) {
475 		uint16_t next = NULL_TRIE_IDX;
476 		uint16_t node_end = TRIE_NODE(trie, current).start + TRIE_NODE(trie, current).length;
477 		uint16_t node_idx;
478 
479 		if (reverse) {
480 			for (node_idx = TRIE_NODE(trie, current).start;
481 			    node_idx < node_end && string_idx >= 0 && string[string_idx] == TRIE_BYTE(trie, node_idx);
482 			    node_idx++, string_idx--) {
483 #if NET_TRIE_DEBUG_SEARCH
484 				NET_TRIE_LOG(LOG_DEBUG, "%c", string[string_idx]);
485 #endif
486 				;
487 			}
488 		} else {
489 			for (node_idx = TRIE_NODE(trie, current).start;
490 			    node_idx < node_end && string_idx < (int16_t)string_length && string[string_idx] == TRIE_BYTE(trie, node_idx);
491 			    node_idx++, string_idx++) {
492 #if NET_TRIE_DEBUG_SEARCH
493 				NET_TRIE_LOG(LOG_DEBUG, "%c", string[string_idx]);
494 #endif
495 				;
496 			}
497 		}
498 
499 		// High Ascii detection -
500 		// Any char matching the node string are not high Ascii.  Only need to check mismatched char.
501 		if (string_idx >= 0 && string_idx < (int16_t)string_length && HIGH_ASCII(string[string_idx])) {
502 			if (high_ascii_detected) {
503 				*high_ascii_detected = true;
504 			}
505 			return NULL_TRIE_IDX;
506 		}
507 
508 #if NET_TRIE_DEBUG_SEARCH
509 		NET_TRIE_LOG(LOG_DEBUG, "NET_TRIE - node_idx %d node_end %d", node_idx, node_end);
510 #endif
511 
512 		if (node_idx == node_end) {
513 			boolean_t exact_matched = ((reverse && string_idx < 0) || (string_idx == (int16_t)string_length));
514 			boolean_t partial_matched = (!exact_matched && partial_match_allowed && (string[string_idx] == partial_match_terminator));
515 
516 #if NET_TRIE_DEBUG_SEARCH
517 			NET_TRIE_LOG(LOG_DEBUG, "NET_TRIE - reverse %d string_idx %d byte %d leaf %d (exact_matched %d partial_matched %d)",
518 			    reverse, string_idx, string_idx >= 0 && string_idx < (int16_t)string_length ? string[string_idx] : 888,
519 			    TRIE_NODE(trie, current).is_leaf, exact_matched, partial_matched);
520 #endif
521 
522 			if (TRIE_NODE(trie, current).is_leaf == true) {
523 				uint16_t metadata_idex = TRIE_NODE(trie, current).metadata;
524 				const uint8_t *data = (metadata_idex > 0) ? &TRIE_BYTE(trie, metadata_idex) : NULL;
525 				size_t length = TRIE_NODE(trie, current).metadata_length;
526 
527 				// Consider a match only if the metadata qualifies
528 				if (check_metadata == NULL || check_metadata(data, length)) {
529 					if (exact_matched) {
530 						// Provide access of leaf metadata to caller
531 						if (metadata && metadata_length) {
532 							if (data != NULL && length > 0) {
533 								*metadata = data;
534 								*metadata_length = length;
535 							}
536 						}
537 						return current; /* Got an exact match */
538 					} else if (partial_matched) {
539 						// Remember the last partial match but continue to try exact match
540 						last_matched = current;
541 					}
542 				}
543 			}
544 			if (string_idx >= 0 && string_idx < (int16_t)string_length &&
545 			    TRIE_NODE(trie, current).child_map != NULL_TRIE_IDX) {
546 				next = TRIE_CHILD_GET(trie, current, string[string_idx]);
547 			}
548 		}
549 		current = next;
550 	}
551 
552 	// Couldn't find an exact match, but there is a closest partial match
553 	if (last_matched != NULL_TRIE_IDX) {
554 		// Provide access of leaf metadata to caller
555 		if (metadata && metadata_length) {
556 			uint16_t metadata_idex = TRIE_NODE(trie, last_matched).metadata;
557 			const uint8_t *data = (metadata_idex > 0) ? &TRIE_BYTE(trie, metadata_idex) : NULL;
558 			size_t length = TRIE_NODE(trie, last_matched).metadata_length;
559 			if (data != NULL && length > 0) {
560 				*metadata = data;
561 				*metadata_length = length;
562 			}
563 		}
564 		return last_matched;
565 	}
566 
567 	// High Ascii detection -
568 	// Failed to match entire/partial string, complete the high Ascii check
569 	if (high_ascii_detected) {
570 		if (reverse) {
571 			for (; string_idx >= 0; string_idx--) {
572 				if (HIGH_ASCII(string[string_idx])) {
573 					*high_ascii_detected = true;
574 					break;
575 				}
576 			}
577 		} else {
578 			for (; string_idx < (int16_t)string_length; string_idx++) {
579 				if (HIGH_ASCII(string[string_idx])) {
580 					*high_ascii_detected = true;
581 					break;
582 				}
583 			}
584 		}
585 	}
586 
587 	return NULL_TRIE_IDX;
588 }
589