xref: /xnu-11417.101.15/tests/arm_matrix_sme.c (revision e3723e1f17661b24996789d8afc084c0c3303b26)
1 /*
2  * Copyright (c) 2022 Apple Computer, Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
#include <mach/thread_act.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <sys/sysctl.h>

#include "arm_matrix.h"
35 
/* Number of SVE/SSVE Z (vector) registers. */
static const unsigned int SME_Z_VECTORS = 32;
/* Number of SVE/SSVE P (predicate) registers. */
static const unsigned int SME_P_VECTORS = 16;
38 
/*
 * Returns the supported SME version: 2 if FEAT_SME2 is present, 1 if only
 * FEAT_SME, 0 if SME is unavailable. The result is cached after the first
 * call.
 */
static unsigned int
sme_version(void)
{
	static unsigned int ret = 0;
	static bool already_read = false;

	if (!already_read) {
		/*
		 * Default both flags to 0 so a failing sysctlbyname() reads as
		 * "feature absent" instead of leaving them uninitialized (UB).
		 */
		unsigned int feat_sme = 0, feat_sme2 = 0;
		size_t size = sizeof(feat_sme);
		sysctlbyname("hw.optional.arm.FEAT_SME", &feat_sme, &size, NULL, 0);
		/* sysctlbyname() updates `size` on return; reset it before reuse. */
		size = sizeof(feat_sme2);
		sysctlbyname("hw.optional.arm.FEAT_SME2", &feat_sme2, &size, NULL, 0);

		if (feat_sme2) {
			ret = 2;
		} else if (feat_sme) {
			ret = 1;
		} else {
			ret = 0;
		}

		already_read = true;
	}

	return ret;
}
64 
/*
 * Returns the Streaming SVE vector length (SVL) in bytes.
 * RDSVL <Xd>, #imm yields SVL-in-bytes times the immediate, so #1 reads
 * SVL directly; per the ARM SME spec it can be executed without streaming
 * mode enabled.
 */
static uint16_t
arm_sme_svl_b(void)
{
	uint64_t ret = 0;
	asm volatile (
                "rdsvl	%[ret], #1"
                : [ret] "=r"(ret)
        );
	return (uint16_t)ret;
}
75 
/*
 * Size in bytes of the ZA storage: an SVL x SVL byte array.
 * Widen to size_t before multiplying; uint16_t operands would otherwise be
 * promoted to int, and int multiplication can overflow (UB) for large SVL.
 */
static size_t
sme_za_size(void)
{
	size_t svl_b = arm_sme_svl_b();
	return svl_b * svl_b;
}
81 
82 static size_t
sme_z_size(void)83 sme_z_size(void)
84 {
85 	return arm_sme_svl_b() * SME_Z_VECTORS;
86 }
87 
88 static size_t
sme_p_size(void)89 sme_p_size(void)
90 {
91 	return arm_sme_svl_b() * SME_P_VECTORS / 8;
92 }
93 
/* ZT0 is a fixed 512-bit (64-byte) register; it exists only with SME2. */
static size_t
sme_zt0_size(void)
{
	return (sme_version() >= 2) ? 64 : 0;
}
103 
/* TPIDR2_EL0 is a single 64-bit system register. */
static size_t
sme_tpidr2_size(void)
{
	return sizeof((uint64_t){0});
}
109 
/* The ZA image sits at the very start of the save buffer. */
static inline uint8_t *
sme_za(void *addr)
{
	return (uint8_t *)addr;
}
115 
/* Read-only view of the ZA image at the start of the save buffer. */
static inline const uint8_t *
const_sme_za(const void *addr)
{
	return (const uint8_t *)addr;
}
121 
/* The Z-register image immediately follows ZA in the save buffer. */
static inline uint8_t *
sme_z(void *addr)
{
	uint8_t *base = sme_za(addr);
	return base + sme_za_size();
}
127 
/* Read-only view of the Z-register image, placed right after ZA. */
static inline const uint8_t *
const_sme_z(const void *addr)
{
	const uint8_t *base = const_sme_za(addr);
	return base + sme_za_size();
}
133 
/* The P-register image immediately follows the Z registers. */
static inline uint8_t *
sme_p(void *addr)
{
	uint8_t *z_base = sme_z(addr);
	return z_base + sme_z_size();
}
139 
/* Read-only view of the P-register image, placed right after Z. */
static inline const uint8_t *
const_sme_p(const void *addr)
{
	const uint8_t *z_base = const_sme_z(addr);
	return z_base + sme_z_size();
}
145 
/* The ZT0 image immediately follows the P registers (empty pre-SME2). */
static inline uint8_t *
sme_zt0(void *addr)
{
	uint8_t *p_base = sme_p(addr);
	return p_base + sme_p_size();
}
151 
/* Read-only view of the ZT0 image, placed right after P. */
static inline const uint8_t *
const_sme_zt0(const void *addr)
{
	const uint8_t *p_base = const_sme_p(addr);
	return p_base + sme_p_size();
}
157 
/* Total bytes needed to save the full SME image: ZA | Z | P | ZT0 | TPIDR2. */
static size_t
sme_data_size(void)
{
	size_t total = sme_za_size();
	total += sme_z_size();
	total += sme_p_size();
	total += sme_zt0_size();
	total += sme_tpidr2_size();
	return total;
}
163 
/*
 * Stores `val` into the TPIDR2_EL0 slot, which lives immediately after the
 * ZT0 image in the save buffer. Copy with memcpy instead of dereferencing a
 * cast uint64_t pointer: the slot's offset is a function of SVL and is not
 * guaranteed 8-byte aligned, so a direct store could be a misaligned access.
 */
static inline void
set_sme_tpidr2_el0(void *addr, uint64_t val)
{
	uint8_t *slot = sme_zt0(addr) + sme_zt0_size();
	memcpy(slot, &val, sizeof(val));
}
170 
/*
 * Reads the saved TPIDR2_EL0 value from the slot after the ZT0 image.
 * memcpy avoids a potentially misaligned uint64_t load; the slot's offset
 * depends on SVL and is not guaranteed 8-byte aligned.
 */
static inline uint64_t
get_sme_tpidr2_el0(const void *addr)
{
	uint64_t val;
	memcpy(&val, const_sme_zt0(addr) + sme_zt0_size(), sizeof(val));
	return val;
}
177 
/*
 * Allocates an uninitialized buffer sized for a full SME save image.
 * Returns NULL on allocation failure; the caller owns (and frees) it.
 */
static void *
sme_alloc_data(void)
{
	size_t size = sme_data_size();
	return malloc(size);
}
183 
/* SME is usable iff the reported SME version is nonzero. */
static bool
sme_is_available(void)
{
	return sme_version() != 0;
}
189 
/*
 * Enters streaming SVE mode and enables ZA storage: SMSTART with no operand
 * sets both PSTATE.SM and PSTATE.ZA per the ARM SME spec.
 */
static void
sme_start(void)
{
	asm volatile ("smstart");
}
195 
/*
 * Exits streaming SVE mode and disables ZA storage: SMSTOP with no operand
 * clears both PSTATE.SM and PSTATE.ZA, discarding their contents.
 */
static void
sme_stop(void)
{
	asm volatile ("smstop");
}
201 
202 static void
sme_load_one_vector(const void * addr)203 sme_load_one_vector(const void *addr)
204 {
205 	asm volatile (
206                 "mov    w12, #0"                "\n"
207                 "ldr    za[w12, #0], [%[addr]]" "\n"
208                 :
209                 : [addr] "r"(addr)
210                 : "w12"
211         );
212 }
213 
214 static void
sme_load_data(const void * addr)215 sme_load_data(const void *addr)
216 {
217 	const uint8_t *za = const_sme_za(addr);
218 	const uint8_t *z = const_sme_z(addr);
219 	const uint8_t *p = const_sme_p(addr);
220 	uint16_t svl_b = arm_sme_svl_b();
221 
222 	for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
223 		asm volatile (
224                         "ldr    za[%w[i],  #0], [%[addr],  #0, mul vl]"   "\n"
225                         "ldr    za[%w[i],  #1], [%[addr],  #1, mul vl]"   "\n"
226                         "ldr    za[%w[i],  #2], [%[addr],  #2, mul vl]"   "\n"
227                         "ldr    za[%w[i],  #3], [%[addr],  #3, mul vl]"   "\n"
228                         "ldr    za[%w[i],  #4], [%[addr],  #4, mul vl]"   "\n"
229                         "ldr    za[%w[i],  #5], [%[addr],  #5, mul vl]"   "\n"
230                         "ldr    za[%w[i],  #6], [%[addr],  #6, mul vl]"   "\n"
231                         "ldr    za[%w[i],  #7], [%[addr],  #7, mul vl]"   "\n"
232                         "ldr    za[%w[i],  #8], [%[addr],  #8, mul vl]"   "\n"
233                         "ldr    za[%w[i],  #9], [%[addr],  #9, mul vl]"   "\n"
234                         "ldr    za[%w[i], #10], [%[addr], #10, mul vl]"   "\n"
235                         "ldr    za[%w[i], #11], [%[addr], #11, mul vl]"   "\n"
236                         "ldr    za[%w[i], #12], [%[addr], #12, mul vl]"   "\n"
237                         "ldr    za[%w[i], #13], [%[addr], #13, mul vl]"   "\n"
238                         "ldr    za[%w[i], #14], [%[addr], #14, mul vl]"   "\n"
239                         "ldr    za[%w[i], #15], [%[addr], #15, mul vl]"   "\n"
240                         :
241                         : [i] "r"(i),
242                           [addr] "r"(za + (i * svl_b))
243                 );
244 	}
245 
246 	asm volatile (
247                 "ldr    z0, [%[z],   #0, mul vl]"        "\n"
248                 "ldr    z1, [%[z],   #1, mul vl]"        "\n"
249                 "ldr    z2, [%[z],   #2, mul vl]"        "\n"
250                 "ldr    z3, [%[z],   #3, mul vl]"        "\n"
251                 "ldr    z4, [%[z],   #4, mul vl]"        "\n"
252                 "ldr    z5, [%[z],   #5, mul vl]"        "\n"
253                 "ldr    z6, [%[z],   #6, mul vl]"        "\n"
254                 "ldr    z7, [%[z],   #7, mul vl]"        "\n"
255                 "ldr    z8, [%[z],   #8, mul vl]"        "\n"
256                 "ldr    z9, [%[z],   #9, mul vl]"        "\n"
257                 "ldr   z10, [%[z],  #10, mul vl]"        "\n"
258                 "ldr   z11, [%[z],  #11, mul vl]"        "\n"
259                 "ldr   z12, [%[z],  #12, mul vl]"        "\n"
260                 "ldr   z13, [%[z],  #13, mul vl]"        "\n"
261                 "ldr   z14, [%[z],  #14, mul vl]"        "\n"
262                 "ldr   z15, [%[z],  #15, mul vl]"        "\n"
263                 "ldr   z16, [%[z],  #16, mul vl]"        "\n"
264                 "ldr   z17, [%[z],  #17, mul vl]"        "\n"
265                 "ldr   z18, [%[z],  #18, mul vl]"        "\n"
266                 "ldr   z19, [%[z],  #19, mul vl]"        "\n"
267                 "ldr   z20, [%[z],  #20, mul vl]"        "\n"
268                 "ldr   z21, [%[z],  #21, mul vl]"        "\n"
269                 "ldr   z22, [%[z],  #22, mul vl]"        "\n"
270                 "ldr   z23, [%[z],  #23, mul vl]"        "\n"
271                 "ldr   z24, [%[z],  #24, mul vl]"        "\n"
272                 "ldr   z25, [%[z],  #25, mul vl]"        "\n"
273                 "ldr   z26, [%[z],  #26, mul vl]"        "\n"
274                 "ldr   z27, [%[z],  #27, mul vl]"        "\n"
275                 "ldr   z28, [%[z],  #28, mul vl]"        "\n"
276                 "ldr   z29, [%[z],  #29, mul vl]"        "\n"
277                 "ldr   z30, [%[z],  #30, mul vl]"        "\n"
278                 "ldr   z31, [%[z],  #31, mul vl]"        "\n"
279                 :
280                 : [z] "r"(z)
281         );
282 
283 	asm volatile (
284                 "ldr     p0, [%[p],  #0, mul vl]"        "\n"
285                 "ldr     p1, [%[p],  #1, mul vl]"        "\n"
286                 "ldr     p2, [%[p],  #2, mul vl]"        "\n"
287                 "ldr     p3, [%[p],  #3, mul vl]"        "\n"
288                 "ldr     p4, [%[p],  #4, mul vl]"        "\n"
289                 "ldr     p5, [%[p],  #5, mul vl]"        "\n"
290                 "ldr     p6, [%[p],  #6, mul vl]"        "\n"
291                 "ldr     p7, [%[p],  #7, mul vl]"        "\n"
292                 "ldr     p8, [%[p],  #8, mul vl]"        "\n"
293                 "ldr     p9, [%[p],  #9, mul vl]"        "\n"
294                 "ldr    p10, [%[p], #10, mul vl]"        "\n"
295                 "ldr    p11, [%[p], #11, mul vl]"        "\n"
296                 "ldr    p12, [%[p], #12, mul vl]"        "\n"
297                 "ldr    p13, [%[p], #13, mul vl]"        "\n"
298                 "ldr    p14, [%[p], #14, mul vl]"        "\n"
299                 "ldr    p15, [%[p], #15, mul vl]"        "\n"
300                 :
301                 : [p] "r"(p)
302         );
303 
304 	if (sme_zt0_size()) {
305 		const uint8_t *zt0 = const_sme_zt0(addr);
306 		asm volatile (
307                         "ldr	zt0, [%[zt0]]"
308                         :
309                         : [zt0] "r"(zt0)
310                 );
311 	}
312 
313 	__builtin_arm_wsr64("TPIDR2_EL0", get_sme_tpidr2_el0(addr));
314 }
315 
316 static void
sme_store_data(void * addr)317 sme_store_data(void *addr)
318 {
319 	uint8_t *za = sme_za(addr);
320 	uint8_t *z = sme_z(addr);
321 	uint8_t *p = sme_p(addr);
322 	uint16_t svl_b = arm_sme_svl_b();
323 
324 	for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
325 		asm volatile (
326                         "str    za[%w[i],  #0], [%[addr],  #0, mul vl]"   "\n"
327                         "str    za[%w[i],  #1], [%[addr],  #1, mul vl]"   "\n"
328                         "str    za[%w[i],  #2], [%[addr],  #2, mul vl]"   "\n"
329                         "str    za[%w[i],  #3], [%[addr],  #3, mul vl]"   "\n"
330                         "str    za[%w[i],  #4], [%[addr],  #4, mul vl]"   "\n"
331                         "str    za[%w[i],  #5], [%[addr],  #5, mul vl]"   "\n"
332                         "str    za[%w[i],  #6], [%[addr],  #6, mul vl]"   "\n"
333                         "str    za[%w[i],  #7], [%[addr],  #7, mul vl]"   "\n"
334                         "str    za[%w[i],  #8], [%[addr],  #8, mul vl]"   "\n"
335                         "str    za[%w[i],  #9], [%[addr],  #9, mul vl]"   "\n"
336                         "str    za[%w[i], #10], [%[addr], #10, mul vl]"   "\n"
337                         "str    za[%w[i], #11], [%[addr], #11, mul vl]"   "\n"
338                         "str    za[%w[i], #12], [%[addr], #12, mul vl]"   "\n"
339                         "str    za[%w[i], #13], [%[addr], #13, mul vl]"   "\n"
340                         "str    za[%w[i], #14], [%[addr], #14, mul vl]"   "\n"
341                         "str    za[%w[i], #15], [%[addr], #15, mul vl]"   "\n"
342                         :
343                         : [i] "r"(i),
344                           [addr] "r"(za + (i * svl_b))
345                 );
346 	}
347 
348 	asm volatile (
349                 "str    z0, [%[z],   #0, mul vl]"        "\n"
350                 "str    z1, [%[z],   #1, mul vl]"        "\n"
351                 "str    z2, [%[z],   #2, mul vl]"        "\n"
352                 "str    z3, [%[z],   #3, mul vl]"        "\n"
353                 "str    z4, [%[z],   #4, mul vl]"        "\n"
354                 "str    z5, [%[z],   #5, mul vl]"        "\n"
355                 "str    z6, [%[z],   #6, mul vl]"        "\n"
356                 "str    z7, [%[z],   #7, mul vl]"        "\n"
357                 "str    z8, [%[z],   #8, mul vl]"        "\n"
358                 "str    z9, [%[z],   #9, mul vl]"        "\n"
359                 "str   z10, [%[z],  #10, mul vl]"        "\n"
360                 "str   z11, [%[z],  #11, mul vl]"        "\n"
361                 "str   z12, [%[z],  #12, mul vl]"        "\n"
362                 "str   z13, [%[z],  #13, mul vl]"        "\n"
363                 "str   z14, [%[z],  #14, mul vl]"        "\n"
364                 "str   z15, [%[z],  #15, mul vl]"        "\n"
365                 "str   z16, [%[z],  #16, mul vl]"        "\n"
366                 "str   z17, [%[z],  #17, mul vl]"        "\n"
367                 "str   z18, [%[z],  #18, mul vl]"        "\n"
368                 "str   z19, [%[z],  #19, mul vl]"        "\n"
369                 "str   z20, [%[z],  #20, mul vl]"        "\n"
370                 "str   z21, [%[z],  #21, mul vl]"        "\n"
371                 "str   z22, [%[z],  #22, mul vl]"        "\n"
372                 "str   z23, [%[z],  #23, mul vl]"        "\n"
373                 "str   z24, [%[z],  #24, mul vl]"        "\n"
374                 "str   z25, [%[z],  #25, mul vl]"        "\n"
375                 "str   z26, [%[z],  #26, mul vl]"        "\n"
376                 "str   z27, [%[z],  #27, mul vl]"        "\n"
377                 "str   z28, [%[z],  #28, mul vl]"        "\n"
378                 "str   z29, [%[z],  #29, mul vl]"        "\n"
379                 "str   z30, [%[z],  #30, mul vl]"        "\n"
380                 "str   z31, [%[z],  #31, mul vl]"        "\n"
381                 :
382                 : [z] "r"(z)
383         );
384 
385 	asm volatile (
386                 "str     p0, [%[p],  #0, mul vl]"        "\n"
387                 "str     p1, [%[p],  #1, mul vl]"        "\n"
388                 "str     p2, [%[p],  #2, mul vl]"        "\n"
389                 "str     p3, [%[p],  #3, mul vl]"        "\n"
390                 "str     p4, [%[p],  #4, mul vl]"        "\n"
391                 "str     p5, [%[p],  #5, mul vl]"        "\n"
392                 "str     p6, [%[p],  #6, mul vl]"        "\n"
393                 "str     p7, [%[p],  #7, mul vl]"        "\n"
394                 "str     p8, [%[p],  #8, mul vl]"        "\n"
395                 "str     p9, [%[p],  #9, mul vl]"        "\n"
396                 "str    p10, [%[p], #10, mul vl]"        "\n"
397                 "str    p11, [%[p], #11, mul vl]"        "\n"
398                 "str    p12, [%[p], #12, mul vl]"        "\n"
399                 "str    p13, [%[p], #13, mul vl]"        "\n"
400                 "str    p14, [%[p], #14, mul vl]"        "\n"
401                 "str    p15, [%[p], #15, mul vl]"        "\n"
402                 :
403                 : [p] "r"(p)
404         );
405 
406 	if (sme_zt0_size()) {
407 		uint8_t *zt0 = sme_zt0(addr);
408 		asm volatile (
409                         "str	zt0, [%[zt0]]"
410                         :
411                         : [zt0] "r"(zt0)
412                 );
413 	}
414 
415 	set_sme_tpidr2_el0(addr, __builtin_arm_rsr64("TPIDR2_EL0"));
416 }
417 
/*
 * Fetches the target thread's complete SME/streaming-SVE register state via
 * thread_get_state() and serializes it into the flat save buffer at `addr`
 * (layout: ZA | Z0-Z31 | P0-P15 | [ZT0] | TPIDR2_EL0).
 * Returns the first Mach error encountered, or KERN_SUCCESS.
 *
 * NOTE(review): assumes the remote thread's streaming vector length matches
 * this thread's arm_sme_svl_b() — confirm against the harness.
 */
static kern_return_t
sme_thread_get_state(thread_act_t thread, void *addr)
{
	uint8_t *za = sme_za(addr);
	uint8_t *z = sme_z(addr);
	uint8_t *p = sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	/* ARM_SME_STATE carries SVCR/SVL/TPIDR2; only TPIDR2_EL0 is saved. */
	arm_sme_state_t sme_state;
	mach_msg_type_number_t sme_count = ARM_SME_STATE_COUNT;
	kern_return_t err = thread_get_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, &sme_count);
	if (err) {
		return err;
	}
	set_sme_tpidr2_el0(addr, sme_state.__tpidr2_el0);

	/* ZA tile contents. */
	arm_sme_za_state_t za_state;
	mach_msg_type_number_t za_count = ARM_SME_ZA_STATE_COUNT;
	err = thread_get_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, &za_count);
	if (err) {
		return err;
	}

	/* Z0-Z15 and Z16-Z31 are split across two state flavors. */
	arm_sve_z_state_t z_state1, z_state2;
	mach_msg_type_number_t z_streaming_count = ARM_SVE_Z_STATE_COUNT;
	err = thread_get_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, &z_streaming_count);
	if (err) {
		return err;
	}
	err = thread_get_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, &z_streaming_count);
	if (err) {
		return err;
	}

	/* All 16 predicate registers come from a single flavor. */
	arm_sve_p_state_t p_state;
	mach_msg_type_number_t p_streaming_count = ARM_SVE_P_STATE_COUNT;
	err = thread_get_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, &p_streaming_count);
	if (err) {
		return err;
	}

	/* Flatten into the save-buffer layout. ZA is svl_b x svl_b bytes. */
	memcpy(za, za_state.__za, svl_b * svl_b);

	/* Each Z register occupies svl_b bytes; only the first svl_b bytes
	 * of each (possibly larger) state-struct entry are meaningful. */
	size_t z_elem_size = svl_b;
	for (int i = 0; i < 16; i++) {
		memcpy(z, z_state1.__z[i], z_elem_size);
		z += z_elem_size;
	}
	for (int i = 0; i < 16; i++) {
		memcpy(z, z_state2.__z[i], z_elem_size);
		z += z_elem_size;
	}

	/* Each P register occupies svl_b/8 bytes (one bit per Z byte). */
	size_t p_elem_size = svl_b / 8;
	for (int i = 0; i < 16; i++) {
		memcpy(p, p_state.__p[i], p_elem_size);
		p += p_elem_size;
	}

	/* ZT0 is only present (and only queried) on SME2 hardware. */
	if (sme_zt0_size()) {
		uint8_t *zt0 = sme_zt0(addr);

		arm_sme2_state_t sme2_state;
		mach_msg_type_number_t sme2_count = ARM_SME2_STATE_COUNT;
		err = thread_get_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, &sme2_count);
		if (err) {
			return err;
		}

		memcpy(zt0, sme2_state.__zt0, sizeof(sme2_state.__zt0));
	}

	return KERN_SUCCESS;
}
492 
/*
 * Deserializes a flat SME save buffer (layout: ZA | Z0-Z31 | P0-P15 | [ZT0]
 * | TPIDR2_EL0) and installs it into the target thread via
 * thread_set_state(). Returns the first Mach error, or KERN_SUCCESS.
 *
 * NOTE(review): ARM_SME_STATE (with __svcr = 0x3, i.e. both PSTATE.SM and
 * PSTATE.ZA set) is installed before the Z/P/ZA contents — presumably so
 * the kernel enables streaming mode / ZA storage first; confirm the
 * ordering requirement before reordering these calls.
 */
static kern_return_t
sme_thread_set_state(thread_act_t thread, const void *addr)
{
	const uint8_t *za = const_sme_za(addr);
	const uint8_t *z = const_sme_z(addr);
	const uint8_t *p = const_sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	/* SVCR = 0x3 requests SM+ZA enabled; SVL and TPIDR2 come from us. */
	arm_sme_state_t sme_state;
	sme_state.__svcr = 0x3;
	sme_state.__svl_b = svl_b;
	sme_state.__tpidr2_el0 = get_sme_tpidr2_el0(addr);

	/* ZA tile is svl_b x svl_b bytes. */
	arm_sme_za_state_t za_state;
	memcpy(za_state.__za, za, svl_b * svl_b);

	/* Z0-Z15 and Z16-Z31 are split across two state flavors; each entry
	 * receives the first svl_b bytes from the flat buffer. */
	arm_sve_z_state_t z_state1, z_state2;
	size_t z_elem_size = svl_b;
	for (int i = 0; i < 16; i++) {
		memcpy(z_state1.__z[i], z, z_elem_size);
		z += z_elem_size;
	}
	for (int i = 0; i < 16; i++) {
		memcpy(z_state2.__z[i], z, z_elem_size);
		z += z_elem_size;
	}

	/* Each P register is svl_b/8 bytes (one bit per Z byte). */
	arm_sve_p_state_t p_state;
	size_t p_elem_size = svl_b / 8;
	for (int i = 0; i < 16; i++) {
		memcpy(p_state.__p[i], p, p_elem_size);
		p += p_elem_size;
	}

	kern_return_t err = thread_set_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, ARM_SME_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, ARM_SVE_Z_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, ARM_SVE_Z_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, ARM_SVE_P_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, ARM_SME_ZA_STATE_COUNT);
	if (err) {
		return err;
	}

	/* ZT0 is only present (and only installed) on SME2 hardware. */
	if (sme_zt0_size()) {
		const uint8_t *zt0 = const_sme_zt0(addr);

		arm_sme2_state_t sme2_state;
		memcpy(sme2_state.__zt0, zt0, sizeof(sme2_state.__zt0));

		err = thread_set_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, ARM_SME2_STATE_COUNT);
		if (err) {
			return err;
		}
	}

	return KERN_SUCCESS;
}
566 
/*
 * SME backend for the generic arm_matrix test harness (see arm_matrix.h):
 * binds the static SME-specific implementations defined above into the
 * harness's operations table.
 */
const struct arm_matrix_operations sme_operations = {
	.name = "SME",

	.data_size = sme_data_size,
	.alloc_data = sme_alloc_data,

	.is_available = sme_is_available,
	.start = sme_start,
	.stop = sme_stop,

	.load_one_vector = sme_load_one_vector,
	.load_data = sme_load_data,
	.store_data = sme_store_data,

	.thread_get_state = sme_thread_get_state,
	.thread_set_state = sme_thread_set_state,
};
584