xref: /xnu-12377.41.6/tests/arm_matrix_sme.c (revision bbb1b6f9e71b8cdde6e5cd6f4841f207dee3d828)
1 /*
2  * Copyright (c) 2022 Apple Computer, Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <mach/thread_act.h>
30 #include <stdint.h>
31 #include <stdlib.h>
32 #include <sys/sysctl.h>
33 
34 #include "arm_matrix.h"
35 
/* Number of SVE Z registers (Z0-Z31) saved/restored per snapshot. */
static const unsigned int SME_Z_VECTORS = 32;
/* Number of SVE predicate registers (P0-P15) saved/restored per snapshot. */
static const unsigned int SME_P_VECTORS = 16;
38 
/*
 * Return the SME level advertised by the kernel through sysctl:
 * 2 if hw.optional.arm.FEAT_SME2 is non-zero, 1 if only
 * hw.optional.arm.FEAT_SME is non-zero, 0 otherwise.
 *
 * The answer is computed on the first call and cached in function-scope
 * statics; no synchronization is performed around the cache.
 */
static unsigned int
sme_version(void)
{
	static unsigned int ret = 0;
	static bool already_read = false;

	if (!already_read) {
		unsigned int feat_sme = 0;
		unsigned int feat_sme2 = 0;
		size_t size;

		/*
		 * If a sysctl is not present the call fails; treat that as
		 * "feature absent" instead of reading an uninitialized value.
		 * The length argument is in/out, so reset it for every call.
		 */
		size = sizeof(feat_sme);
		if (sysctlbyname("hw.optional.arm.FEAT_SME", &feat_sme, &size, NULL, 0) != 0) {
			feat_sme = 0;
		}

		size = sizeof(feat_sme2);
		if (sysctlbyname("hw.optional.arm.FEAT_SME2", &feat_sme2, &size, NULL, 0) != 0) {
			feat_sme2 = 0;
		}

		if (feat_sme2) {
			ret = 2;
		} else if (feat_sme) {
			ret = 1;
		} else {
			ret = 0;
		}

		already_read = true;
	}

	return ret;
}
64 
/*
 * Return the streaming vector length in bytes (SVL_B) using
 * "RDSVL Xd, #1", i.e. 1 x SVL expressed in bytes. The 64-bit result is
 * truncated to 16 bits for callers.
 */
static uint16_t
arm_sme_svl_b(void)
{
	uint64_t svl_bytes = 0;

	__asm__ volatile ("rdsvl %0, #1" : "=r"(svl_bytes));

	return (uint16_t)svl_bytes;
}
75 
/*
 * Size in bytes of the ZA array: SVL_B rows of SVL_B bytes each.
 *
 * SVL is read once and the product is formed in size_t. The original
 * multiplied two promoted uint16_t values in int.
 */
static size_t
sme_za_size(void)
{
	size_t svl_b = arm_sme_svl_b();
	return svl_b * svl_b;
}
81 
82 static size_t
sme_z_size(void)83 sme_z_size(void)
84 {
85 	return arm_sme_svl_b() * SME_Z_VECTORS;
86 }
87 
88 static size_t
sme_p_size(void)89 sme_p_size(void)
90 {
91 	return arm_sme_svl_b() * SME_P_VECTORS / 8;
92 }
93 
/* ZT0 occupies 64 bytes when SME2 is available and nothing otherwise. */
static size_t
sme_zt0_size(void)
{
	return (sme_version() >= 2) ? 64 : 0;
}
103 
/* The saved TPIDR2_EL0 value is stored as a single 64-bit word. */
static size_t
sme_tpidr2_size(void)
{
	const size_t tpidr2_bytes = sizeof(uint64_t);
	return tpidr2_bytes;
}
109 
/* ZA is the first region of the state buffer (offset 0). */
static inline uint8_t *
sme_za(void *addr)
{
	uint8_t *base = (uint8_t *)addr;
	return base;
}

/* Read-only counterpart of sme_za(). */
static inline const uint8_t *
const_sme_za(const void *addr)
{
	const uint8_t *base = (const uint8_t *)addr;
	return base;
}
121 
/* ZT0 immediately follows the ZA array in the state buffer. */
static inline uint8_t *
sme_zt0(void *addr)
{
	uint8_t *za_start = sme_za(addr);
	return za_start + sme_za_size();
}

/* Read-only counterpart of sme_zt0(). */
static inline const uint8_t *
const_sme_zt0(const void *addr)
{
	const uint8_t *za_start = const_sme_za(addr);
	return za_start + sme_za_size();
}
133 
/* The saved TPIDR2_EL0 slot follows ZT0 (which is zero bytes without SME2). */
static inline uint8_t *
sme_tpidr2_el0(void *addr)
{
	uint8_t *zt0_start = sme_zt0(addr);
	return zt0_start + sme_zt0_size();
}

/* Read-only counterpart of sme_tpidr2_el0(). */
static inline const uint8_t *
const_sme_tpidr2_el0(const void *addr)
{
	const uint8_t *zt0_start = const_sme_zt0(addr);
	return zt0_start + sme_zt0_size();
}
145 
/*
 * Z0-Z31 follow the TPIDR2_EL0 slot. The slot width comes from
 * sme_tpidr2_size(), the same helper the sme_*data_size() functions use,
 * so the offsets and the total sizes cannot drift apart.
 */
static inline uint8_t *
sme_z(void *addr)
{
	return sme_tpidr2_el0(addr) + sme_tpidr2_size();
}

/* Read-only counterpart of sme_z(). */
static inline const uint8_t *
const_sme_z(const void *addr)
{
	return const_sme_tpidr2_el0(addr) + sme_tpidr2_size();
}
157 
/* P0-P15 are the last region, directly after Z0-Z31. */
static inline uint8_t *
sme_p(void *addr)
{
	uint8_t *z_start = sme_z(addr);
	return z_start + sme_z_size();
}

/* Read-only counterpart of sme_p(). */
static inline const uint8_t *
const_sme_p(const void *addr)
{
	const uint8_t *z_start = const_sme_z(addr);
	return z_start + sme_z_size();
}
169 
/* Bytes needed for the ZA-only subset: ZA, ZT0 (SME2 only) and TPIDR2_EL0. */
static size_t
sme_za_data_size(void)
{
	size_t total = sme_za_size();
	total += sme_zt0_size();
	total += sme_tpidr2_size();
	return total;
}

/* Bytes needed for the full snapshot: the ZA-only subset plus Z0-Z31 and P0-P15. */
static size_t
sme_data_size(void)
{
	return sme_za_data_size() + sme_z_size() + sme_p_size();
}
181 
/* Write `val` into the TPIDR2_EL0 slot of the state buffer at `addr`. */
static inline void
set_sme_tpidr2_el0(void *addr, uint64_t val)
{
	*(uint64_t *)sme_tpidr2_el0(addr) = val;
}

/* Read the saved TPIDR2_EL0 value back out of the state buffer at `addr`. */
static inline uint64_t
get_sme_tpidr2_el0(const void *addr)
{
	return *(const uint64_t *)const_sme_tpidr2_el0(addr);
}
194 }
195 
/* Heap-allocate an uninitialized buffer large enough for the full snapshot. */
static void *
sme_alloc_data(void)
{
	size_t bytes = sme_data_size();
	return malloc(bytes);
}

/* Heap-allocate an uninitialized buffer for the ZA/ZT0/TPIDR2 subset only. */
static void *
sme_za_alloc_data(void)
{
	size_t bytes = sme_za_data_size();
	return malloc(bytes);
}
207 
/* SME is usable when the kernel advertises at least FEAT_SME. */
static bool
sme_is_available(void)
{
	return sme_version() != 0;
}
213 
/* SMSTART with no operand: enter streaming SVE mode and enable ZA. */
static void
sme_start(void)
{
	__asm__ __volatile__ ("smstart");
}

/* SMSTART ZA: enable ZA storage only, without entering streaming mode. */
static void
sme_za_start(void)
{
	__asm__ __volatile__ ("smstart za");
}

/* SMSTOP with no operand: leave streaming mode and disable ZA. */
static void
sme_stop(void)
{
	__asm__ __volatile__ ("smstop");
}

/* SMSTOP ZA: disable ZA storage only. */
static void
sme_za_stop(void)
{
	__asm__ __volatile__ ("smstop za");
}
237 
238 static void
sme_load_one_vector(const void * addr)239 sme_load_one_vector(const void *addr)
240 {
241 	asm volatile (
242                 "mov    w12, #0"                "\n"
243                 "ldr    za[w12, #0], [%[addr]]" "\n"
244                 :
245                 : [addr] "r"(addr)
246                 : "w12"
247         );
248 }
249 
250 static void
sme_za_load_data(const void * addr)251 sme_za_load_data(const void *addr)
252 {
253 	const uint8_t *za = const_sme_za(addr);
254 	uint16_t svl_b = arm_sme_svl_b();
255 
256 	for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
257 		asm volatile (
258                         "ldr    za[%w[i],  #0], [%[addr],  #0, mul vl]"   "\n"
259                         "ldr    za[%w[i],  #1], [%[addr],  #1, mul vl]"   "\n"
260                         "ldr    za[%w[i],  #2], [%[addr],  #2, mul vl]"   "\n"
261                         "ldr    za[%w[i],  #3], [%[addr],  #3, mul vl]"   "\n"
262                         "ldr    za[%w[i],  #4], [%[addr],  #4, mul vl]"   "\n"
263                         "ldr    za[%w[i],  #5], [%[addr],  #5, mul vl]"   "\n"
264                         "ldr    za[%w[i],  #6], [%[addr],  #6, mul vl]"   "\n"
265                         "ldr    za[%w[i],  #7], [%[addr],  #7, mul vl]"   "\n"
266                         "ldr    za[%w[i],  #8], [%[addr],  #8, mul vl]"   "\n"
267                         "ldr    za[%w[i],  #9], [%[addr],  #9, mul vl]"   "\n"
268                         "ldr    za[%w[i], #10], [%[addr], #10, mul vl]"   "\n"
269                         "ldr    za[%w[i], #11], [%[addr], #11, mul vl]"   "\n"
270                         "ldr    za[%w[i], #12], [%[addr], #12, mul vl]"   "\n"
271                         "ldr    za[%w[i], #13], [%[addr], #13, mul vl]"   "\n"
272                         "ldr    za[%w[i], #14], [%[addr], #14, mul vl]"   "\n"
273                         "ldr    za[%w[i], #15], [%[addr], #15, mul vl]"   "\n"
274                         :
275                         : [i] "r"(i),
276                           [addr] "r"(za + (i * svl_b))
277                 );
278 	}
279 
280 	if (sme_zt0_size()) {
281 		const uint8_t *zt0 = const_sme_zt0(addr);
282 		asm volatile (
283                         "ldr	zt0, [%[zt0]]"
284                         :
285                         : [zt0] "r"(zt0)
286                 );
287 	}
288 
289 	__builtin_arm_wsr64("TPIDR2_EL0", get_sme_tpidr2_el0(addr));
290 }
291 
292 static void
sme_load_data(const void * addr)293 sme_load_data(const void *addr)
294 {
295 	const uint8_t *z = const_sme_z(addr);
296 	const uint8_t *p = const_sme_p(addr);
297 
298 	sme_za_load_data(addr);
299 
300 	asm volatile (
301                 "ldr    z0, [%[z],   #0, mul vl]"        "\n"
302                 "ldr    z1, [%[z],   #1, mul vl]"        "\n"
303                 "ldr    z2, [%[z],   #2, mul vl]"        "\n"
304                 "ldr    z3, [%[z],   #3, mul vl]"        "\n"
305                 "ldr    z4, [%[z],   #4, mul vl]"        "\n"
306                 "ldr    z5, [%[z],   #5, mul vl]"        "\n"
307                 "ldr    z6, [%[z],   #6, mul vl]"        "\n"
308                 "ldr    z7, [%[z],   #7, mul vl]"        "\n"
309                 "ldr    z8, [%[z],   #8, mul vl]"        "\n"
310                 "ldr    z9, [%[z],   #9, mul vl]"        "\n"
311                 "ldr   z10, [%[z],  #10, mul vl]"        "\n"
312                 "ldr   z11, [%[z],  #11, mul vl]"        "\n"
313                 "ldr   z12, [%[z],  #12, mul vl]"        "\n"
314                 "ldr   z13, [%[z],  #13, mul vl]"        "\n"
315                 "ldr   z14, [%[z],  #14, mul vl]"        "\n"
316                 "ldr   z15, [%[z],  #15, mul vl]"        "\n"
317                 "ldr   z16, [%[z],  #16, mul vl]"        "\n"
318                 "ldr   z17, [%[z],  #17, mul vl]"        "\n"
319                 "ldr   z18, [%[z],  #18, mul vl]"        "\n"
320                 "ldr   z19, [%[z],  #19, mul vl]"        "\n"
321                 "ldr   z20, [%[z],  #20, mul vl]"        "\n"
322                 "ldr   z21, [%[z],  #21, mul vl]"        "\n"
323                 "ldr   z22, [%[z],  #22, mul vl]"        "\n"
324                 "ldr   z23, [%[z],  #23, mul vl]"        "\n"
325                 "ldr   z24, [%[z],  #24, mul vl]"        "\n"
326                 "ldr   z25, [%[z],  #25, mul vl]"        "\n"
327                 "ldr   z26, [%[z],  #26, mul vl]"        "\n"
328                 "ldr   z27, [%[z],  #27, mul vl]"        "\n"
329                 "ldr   z28, [%[z],  #28, mul vl]"        "\n"
330                 "ldr   z29, [%[z],  #29, mul vl]"        "\n"
331                 "ldr   z30, [%[z],  #30, mul vl]"        "\n"
332                 "ldr   z31, [%[z],  #31, mul vl]"        "\n"
333                 :
334                 : [z] "r"(z)
335         );
336 
337 	asm volatile (
338                 "ldr     p0, [%[p],  #0, mul vl]"        "\n"
339                 "ldr     p1, [%[p],  #1, mul vl]"        "\n"
340                 "ldr     p2, [%[p],  #2, mul vl]"        "\n"
341                 "ldr     p3, [%[p],  #3, mul vl]"        "\n"
342                 "ldr     p4, [%[p],  #4, mul vl]"        "\n"
343                 "ldr     p5, [%[p],  #5, mul vl]"        "\n"
344                 "ldr     p6, [%[p],  #6, mul vl]"        "\n"
345                 "ldr     p7, [%[p],  #7, mul vl]"        "\n"
346                 "ldr     p8, [%[p],  #8, mul vl]"        "\n"
347                 "ldr     p9, [%[p],  #9, mul vl]"        "\n"
348                 "ldr    p10, [%[p], #10, mul vl]"        "\n"
349                 "ldr    p11, [%[p], #11, mul vl]"        "\n"
350                 "ldr    p12, [%[p], #12, mul vl]"        "\n"
351                 "ldr    p13, [%[p], #13, mul vl]"        "\n"
352                 "ldr    p14, [%[p], #14, mul vl]"        "\n"
353                 "ldr    p15, [%[p], #15, mul vl]"        "\n"
354                 :
355                 : [p] "r"(p)
356         );
357 }
358 
359 static void
sme_za_store_data(void * addr)360 sme_za_store_data(void *addr)
361 {
362 	uint8_t *za = sme_za(addr);
363 	uint16_t svl_b = arm_sme_svl_b();
364 
365 	for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
366 		asm volatile (
367                         "str    za[%w[i],  #0], [%[addr],  #0, mul vl]"   "\n"
368                         "str    za[%w[i],  #1], [%[addr],  #1, mul vl]"   "\n"
369                         "str    za[%w[i],  #2], [%[addr],  #2, mul vl]"   "\n"
370                         "str    za[%w[i],  #3], [%[addr],  #3, mul vl]"   "\n"
371                         "str    za[%w[i],  #4], [%[addr],  #4, mul vl]"   "\n"
372                         "str    za[%w[i],  #5], [%[addr],  #5, mul vl]"   "\n"
373                         "str    za[%w[i],  #6], [%[addr],  #6, mul vl]"   "\n"
374                         "str    za[%w[i],  #7], [%[addr],  #7, mul vl]"   "\n"
375                         "str    za[%w[i],  #8], [%[addr],  #8, mul vl]"   "\n"
376                         "str    za[%w[i],  #9], [%[addr],  #9, mul vl]"   "\n"
377                         "str    za[%w[i], #10], [%[addr], #10, mul vl]"   "\n"
378                         "str    za[%w[i], #11], [%[addr], #11, mul vl]"   "\n"
379                         "str    za[%w[i], #12], [%[addr], #12, mul vl]"   "\n"
380                         "str    za[%w[i], #13], [%[addr], #13, mul vl]"   "\n"
381                         "str    za[%w[i], #14], [%[addr], #14, mul vl]"   "\n"
382                         "str    za[%w[i], #15], [%[addr], #15, mul vl]"   "\n"
383                         :
384                         : [i] "r"(i),
385                           [addr] "r"(za + (i * svl_b))
386                 );
387 	}
388 
389 	if (sme_zt0_size()) {
390 		uint8_t *zt0 = sme_zt0(addr);
391 		asm volatile (
392                         "str	zt0, [%[zt0]]"
393                         :
394                         : [zt0] "r"(zt0)
395                 );
396 	}
397 
398 	set_sme_tpidr2_el0(addr, __builtin_arm_rsr64("TPIDR2_EL0"));
399 }
400 
401 static void
sme_store_data(void * addr)402 sme_store_data(void *addr)
403 {
404 	uint8_t *z = sme_z(addr);
405 	uint8_t *p = sme_p(addr);
406 
407 	sme_za_store_data(addr);
408 
409 	asm volatile (
410                 "str    z0, [%[z],   #0, mul vl]"        "\n"
411                 "str    z1, [%[z],   #1, mul vl]"        "\n"
412                 "str    z2, [%[z],   #2, mul vl]"        "\n"
413                 "str    z3, [%[z],   #3, mul vl]"        "\n"
414                 "str    z4, [%[z],   #4, mul vl]"        "\n"
415                 "str    z5, [%[z],   #5, mul vl]"        "\n"
416                 "str    z6, [%[z],   #6, mul vl]"        "\n"
417                 "str    z7, [%[z],   #7, mul vl]"        "\n"
418                 "str    z8, [%[z],   #8, mul vl]"        "\n"
419                 "str    z9, [%[z],   #9, mul vl]"        "\n"
420                 "str   z10, [%[z],  #10, mul vl]"        "\n"
421                 "str   z11, [%[z],  #11, mul vl]"        "\n"
422                 "str   z12, [%[z],  #12, mul vl]"        "\n"
423                 "str   z13, [%[z],  #13, mul vl]"        "\n"
424                 "str   z14, [%[z],  #14, mul vl]"        "\n"
425                 "str   z15, [%[z],  #15, mul vl]"        "\n"
426                 "str   z16, [%[z],  #16, mul vl]"        "\n"
427                 "str   z17, [%[z],  #17, mul vl]"        "\n"
428                 "str   z18, [%[z],  #18, mul vl]"        "\n"
429                 "str   z19, [%[z],  #19, mul vl]"        "\n"
430                 "str   z20, [%[z],  #20, mul vl]"        "\n"
431                 "str   z21, [%[z],  #21, mul vl]"        "\n"
432                 "str   z22, [%[z],  #22, mul vl]"        "\n"
433                 "str   z23, [%[z],  #23, mul vl]"        "\n"
434                 "str   z24, [%[z],  #24, mul vl]"        "\n"
435                 "str   z25, [%[z],  #25, mul vl]"        "\n"
436                 "str   z26, [%[z],  #26, mul vl]"        "\n"
437                 "str   z27, [%[z],  #27, mul vl]"        "\n"
438                 "str   z28, [%[z],  #28, mul vl]"        "\n"
439                 "str   z29, [%[z],  #29, mul vl]"        "\n"
440                 "str   z30, [%[z],  #30, mul vl]"        "\n"
441                 "str   z31, [%[z],  #31, mul vl]"        "\n"
442                 :
443                 : [z] "r"(z)
444         );
445 
446 	asm volatile (
447                 "str     p0, [%[p],  #0, mul vl]"        "\n"
448                 "str     p1, [%[p],  #1, mul vl]"        "\n"
449                 "str     p2, [%[p],  #2, mul vl]"        "\n"
450                 "str     p3, [%[p],  #3, mul vl]"        "\n"
451                 "str     p4, [%[p],  #4, mul vl]"        "\n"
452                 "str     p5, [%[p],  #5, mul vl]"        "\n"
453                 "str     p6, [%[p],  #6, mul vl]"        "\n"
454                 "str     p7, [%[p],  #7, mul vl]"        "\n"
455                 "str     p8, [%[p],  #8, mul vl]"        "\n"
456                 "str     p9, [%[p],  #9, mul vl]"        "\n"
457                 "str    p10, [%[p], #10, mul vl]"        "\n"
458                 "str    p11, [%[p], #11, mul vl]"        "\n"
459                 "str    p12, [%[p], #12, mul vl]"        "\n"
460                 "str    p13, [%[p], #13, mul vl]"        "\n"
461                 "str    p14, [%[p], #14, mul vl]"        "\n"
462                 "str    p15, [%[p], #15, mul vl]"        "\n"
463                 :
464                 : [p] "r"(p)
465         );
466 }
467 
468 static kern_return_t
sme_thread_get_state(thread_act_t thread,void * addr)469 sme_thread_get_state(thread_act_t thread, void *addr)
470 {
471 	uint8_t *za = sme_za(addr);
472 	uint8_t *z = sme_z(addr);
473 	uint8_t *p = sme_p(addr);
474 	uint16_t svl_b = arm_sme_svl_b();
475 
476 	arm_sme_state_t sme_state;
477 	mach_msg_type_number_t sme_count = ARM_SME_STATE_COUNT;
478 	kern_return_t err = thread_get_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, &sme_count);
479 	if (err) {
480 		return err;
481 	}
482 	set_sme_tpidr2_el0(addr, sme_state.__tpidr2_el0);
483 
484 	arm_sme_za_state_t za_state;
485 	mach_msg_type_number_t za_count = ARM_SME_ZA_STATE_COUNT;
486 	err = thread_get_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, &za_count);
487 	if (err) {
488 		return err;
489 	}
490 
491 	arm_sve_z_state_t z_state1, z_state2;
492 	mach_msg_type_number_t z_streaming_count = ARM_SVE_Z_STATE_COUNT;
493 	err = thread_get_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, &z_streaming_count);
494 	if (err) {
495 		return err;
496 	}
497 	err = thread_get_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, &z_streaming_count);
498 	if (err) {
499 		return err;
500 	}
501 
502 	arm_sve_p_state_t p_state;
503 	mach_msg_type_number_t p_streaming_count = ARM_SVE_P_STATE_COUNT;
504 	err = thread_get_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, &p_streaming_count);
505 	if (err) {
506 		return err;
507 	}
508 
509 	memcpy(za, za_state.__za, svl_b * svl_b);
510 
511 	size_t z_elem_size = svl_b;
512 	for (int i = 0; i < 16; i++) {
513 		memcpy(z, z_state1.__z[i], z_elem_size);
514 		z += z_elem_size;
515 	}
516 	for (int i = 0; i < 16; i++) {
517 		memcpy(z, z_state2.__z[i], z_elem_size);
518 		z += z_elem_size;
519 	}
520 
521 	size_t p_elem_size = svl_b / 8;
522 	for (int i = 0; i < 16; i++) {
523 		memcpy(p, p_state.__p[i], p_elem_size);
524 		p += p_elem_size;
525 	}
526 
527 	if (sme_zt0_size()) {
528 		uint8_t *zt0 = sme_zt0(addr);
529 
530 		arm_sme2_state_t sme2_state;
531 		mach_msg_type_number_t sme2_count = ARM_SME2_STATE_COUNT;
532 		err = thread_get_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, &sme2_count);
533 		if (err) {
534 			return err;
535 		}
536 
537 		memcpy(zt0, sme2_state.__zt0, sizeof(sme2_state.__zt0));
538 	}
539 
540 	return KERN_SUCCESS;
541 }
542 
543 static kern_return_t
sme_thread_set_state(thread_act_t thread,const void * addr)544 sme_thread_set_state(thread_act_t thread, const void *addr)
545 {
546 	const uint8_t *za = const_sme_za(addr);
547 	const uint8_t *z = const_sme_z(addr);
548 	const uint8_t *p = const_sme_p(addr);
549 	uint16_t svl_b = arm_sme_svl_b();
550 
551 	arm_sme_state_t sme_state;
552 	sme_state.__svcr = 0x3;
553 	sme_state.__svl_b = svl_b;
554 	sme_state.__tpidr2_el0 = get_sme_tpidr2_el0(addr);
555 
556 	arm_sme_za_state_t za_state;
557 	memcpy(za_state.__za, za, svl_b * svl_b);
558 
559 	arm_sve_z_state_t z_state1, z_state2;
560 	size_t z_elem_size = svl_b;
561 	for (int i = 0; i < 16; i++) {
562 		memcpy(z_state1.__z[i], z, z_elem_size);
563 		z += z_elem_size;
564 	}
565 	for (int i = 0; i < 16; i++) {
566 		memcpy(z_state2.__z[i], z, z_elem_size);
567 		z += z_elem_size;
568 	}
569 
570 	arm_sve_p_state_t p_state;
571 	size_t p_elem_size = svl_b / 8;
572 	for (int i = 0; i < 16; i++) {
573 		memcpy(p_state.__p[i], p, p_elem_size);
574 		p += p_elem_size;
575 	}
576 
577 	kern_return_t err = thread_set_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, ARM_SME_STATE_COUNT);
578 	if (err) {
579 		return err;
580 	}
581 
582 	err = thread_set_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, ARM_SVE_Z_STATE_COUNT);
583 	if (err) {
584 		return err;
585 	}
586 
587 	err = thread_set_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, ARM_SVE_Z_STATE_COUNT);
588 	if (err) {
589 		return err;
590 	}
591 
592 	err = thread_set_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, ARM_SVE_P_STATE_COUNT);
593 	if (err) {
594 		return err;
595 	}
596 
597 	err = thread_set_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, ARM_SME_ZA_STATE_COUNT);
598 	if (err) {
599 		return err;
600 	}
601 
602 	if (sme_zt0_size()) {
603 		const uint8_t *zt0 = const_sme_zt0(addr);
604 
605 		arm_sme2_state_t sme2_state;
606 		memcpy(sme2_state.__zt0, zt0, sizeof(sme2_state.__zt0));
607 
608 		err = thread_set_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, ARM_SME2_STATE_COUNT);
609 		if (err) {
610 			return err;
611 		}
612 	}
613 
614 	return KERN_SUCCESS;
615 }
616 
617 const struct arm_matrix_operations sme_operations = {
618 	.name = "SME",
619 
620 	.data_size = sme_data_size,
621 	.alloc_data = sme_alloc_data,
622 
623 	.is_available = sme_is_available,
624 	.start = sme_start,
625 	.stop = sme_stop,
626 
627 	.load_one_vector = sme_load_one_vector,
628 	.load_data = sme_load_data,
629 	.store_data = sme_store_data,
630 
631 	.thread_get_state = sme_thread_get_state,
632 	.thread_set_state = sme_thread_set_state,
633 };
634 
635 const struct arm_matrix_operations sme_za_operations = {
636 	.name = "SME (SVCR.ZA only)",
637 
638 	.data_size = sme_za_data_size,
639 	.alloc_data = sme_za_alloc_data,
640 
641 	.is_available = sme_is_available,
642 	.start = sme_za_start,
643 	.stop = sme_za_stop,
644 
645 	.load_one_vector = NULL, /* currently unused */
646 	.load_data = sme_za_load_data,
647 	.store_data = sme_za_store_data,
648 
649 	.thread_get_state = NULL, /* currently unused */
650 	.thread_set_state = NULL, /* currently unused */
651 };
652