1 /*
2 * Copyright (c) 2022 Apple Computer, Inc. All rights reserved.
3 *
4 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. The rights granted to you under the License
10 * may not be used to create, or enable the creation or redistribution of,
11 * unlawful or unlicensed copies of an Apple operating system, or to
12 * circumvent, violate, or enable the circumvention or violation of, any
13 * terms of an Apple operating system software license agreement.
14 *
15 * Please obtain a copy of the License at
16 * http://www.opensource.apple.com/apsl/ and read it before using this file.
17 *
18 * The Original Code and all software distributed under the License are
19 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23 * Please see the License for the specific language governing rights and
24 * limitations under the License.
25 *
26 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27 */
28
#include <mach/thread_act.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <sys/sysctl.h>

#include "arm_matrix.h"
35
/*
 * Register counts used when sizing the saved-state buffer: the Z area is
 * sized as SME_Z_VECTORS vectors and the P area as SME_P_VECTORS predicates.
 * (Storage class first: placing `static` after `const` is obsolescent C.)
 */
static const unsigned int SME_Z_VECTORS = 32;
static const unsigned int SME_P_VECTORS = 16;
38
/*
 * Report which SME generation the running system advertises.
 *
 * The hw.optional.arm.FEAT_SME and hw.optional.arm.FEAT_SME2 sysctls are
 * queried on the first call and the answer is cached for later calls.
 *
 * Returns 2 when FEAT_SME2 is reported, 1 when only FEAT_SME is reported,
 * and 0 otherwise (including when a sysctl cannot be read).
 */
static unsigned int
sme_version(void)
{
	static unsigned int ret = 0;
	static bool already_read = false;

	if (!already_read) {
		/*
		 * Default to "feature absent": when a lookup fails the output
		 * value cannot be trusted, so it must never be read as-is.
		 */
		unsigned int feat_sme = 0;
		unsigned int feat_sme2 = 0;
		size_t size;

		size = sizeof(feat_sme);
		if (sysctlbyname("hw.optional.arm.FEAT_SME", &feat_sme, &size, NULL, 0) != 0) {
			feat_sme = 0;
		}

		/* The length argument is in/out, so reset it before reuse. */
		size = sizeof(feat_sme2);
		if (sysctlbyname("hw.optional.arm.FEAT_SME2", &feat_sme2, &size, NULL, 0) != 0) {
			feat_sme2 = 0;
		}

		if (feat_sme2) {
			ret = 2;
		} else if (feat_sme) {
			ret = 1;
		} else {
			ret = 0;
		}

		already_read = true;
	}

	return ret;
}
64
65 static uint16_t
arm_sme_svl_b(void)66 arm_sme_svl_b(void)
67 {
68 uint64_t ret = 0;
69 asm volatile (
70 "rdsvl %[ret], #1"
71 : [ret] "=r"(ret)
72 );
73 return (uint16_t)ret;
74 }
75
/*
 * Size in bytes of the ZA area of the state buffer: the streaming vector
 * length (in bytes) squared.
 */
static size_t
sme_za_size(void)
{
	const uint16_t svl_bytes = arm_sme_svl_b();

	return svl_bytes * svl_bytes;
}
81
82 static size_t
sme_z_size(void)83 sme_z_size(void)
84 {
85 return arm_sme_svl_b() * SME_Z_VECTORS;
86 }
87
88 static size_t
sme_p_size(void)89 sme_p_size(void)
90 {
91 return arm_sme_svl_b() * SME_P_VECTORS / 8;
92 }
93
/*
 * Size in bytes of the ZT0 area of the state buffer: 64 when sme_version()
 * reports version 2 or later, otherwise 0 (no ZT0 area at all).
 */
static size_t
sme_zt0_size(void)
{
	return (sme_version() >= 2) ? 64 : 0;
}
103
/*
 * Size in bytes of the TPIDR2_EL0 slot of the state buffer (one uint64_t).
 */
static size_t
sme_tpidr2_size(void)
{
	const size_t tpidr2_bytes = sizeof(uint64_t);

	return tpidr2_bytes;
}
109
/*
 * Locate the ZA area inside a state buffer: it sits at the very start, so
 * this is just addr viewed as a byte pointer.
 */
static inline uint8_t *
sme_za(void *addr)
{
	uint8_t *za_base = addr;

	return za_base;
}
115
/*
 * Read-only counterpart of sme_za(): the ZA area starts at the beginning of
 * the state buffer, so this is addr viewed as a const byte pointer.
 */
static inline const uint8_t *
const_sme_za(const void *addr)
{
	const uint8_t *za_base = addr;

	return za_base;
}
121
/*
 * Locate the ZT0 area inside a state buffer: it immediately follows the ZA
 * area.
 */
static inline uint8_t *
sme_zt0(void *addr)
{
	uint8_t *za_base = sme_za(addr);

	return za_base + sme_za_size();
}
127
/*
 * Read-only counterpart of sme_zt0(): the ZT0 area immediately follows the
 * ZA area.
 */
static inline const uint8_t *
const_sme_zt0(const void *addr)
{
	const uint8_t *za_base = const_sme_za(addr);

	return za_base + sme_za_size();
}
133
/*
 * Locate the TPIDR2_EL0 slot inside a state buffer: it follows the ZT0 area
 * (which is empty when sme_zt0_size() is 0).
 */
static inline uint8_t *
sme_tpidr2_el0(void *addr)
{
	uint8_t *zt0_base = sme_zt0(addr);

	return zt0_base + sme_zt0_size();
}
139
/*
 * Read-only counterpart of sme_tpidr2_el0(): the TPIDR2_EL0 slot follows the
 * ZT0 area (which is empty when sme_zt0_size() is 0).
 */
static inline const uint8_t *
const_sme_tpidr2_el0(const void *addr)
{
	const uint8_t *zt0_base = const_sme_zt0(addr);

	return zt0_base + sme_zt0_size();
}
145
/*
 * Locate the Z register area inside a state buffer: it starts right after
 * the 64-bit TPIDR2_EL0 slot.
 */
static inline uint8_t *
sme_z(void *addr)
{
	uint8_t *tpidr2_slot = sme_tpidr2_el0(addr);

	return tpidr2_slot + sizeof(uint64_t);
}
151
/*
 * Read-only counterpart of sme_z(): the Z register area starts right after
 * the 64-bit TPIDR2_EL0 slot.
 */
static inline const uint8_t *
const_sme_z(const void *addr)
{
	const uint8_t *tpidr2_slot = const_sme_tpidr2_el0(addr);

	return tpidr2_slot + sizeof(uint64_t);
}
157
/*
 * Locate the P register area inside a state buffer: it follows the Z
 * register area.
 */
static inline uint8_t *
sme_p(void *addr)
{
	uint8_t *z_base = sme_z(addr);

	return z_base + sme_z_size();
}
163
/*
 * Read-only counterpart of sme_p(): the P register area follows the Z
 * register area.
 */
static inline const uint8_t *
const_sme_p(const void *addr)
{
	const uint8_t *z_base = const_sme_z(addr);

	return z_base + sme_z_size();
}
169
/*
 * Total size in bytes of a full state buffer: ZA, ZT0, the TPIDR2_EL0 slot,
 * the Z registers and the P registers, in that layout order.
 */
static size_t
sme_data_size(void)
{
	size_t total = sme_za_size();

	total += sme_zt0_size();
	total += sme_tpidr2_size();
	total += sme_z_size();
	total += sme_p_size();

	return total;
}
175
/*
 * Size in bytes of a ZA-only state buffer: ZA, ZT0 and the TPIDR2_EL0 slot
 * (no Z or P register areas).
 */
static size_t
sme_za_data_size(void)
{
	size_t total = sme_za_size();

	total += sme_zt0_size();
	total += sme_tpidr2_size();

	return total;
}
181
/*
 * Store val into the TPIDR2_EL0 slot of the state buffer at addr.
 *
 * The slot is addressed through a byte pointer whose alignment depends on
 * the caller's buffer, so the value is copied bytewise rather than written
 * through a casted uint64_t pointer.
 */
static inline void
set_sme_tpidr2_el0(void *addr, uint64_t val)
{
	memcpy(sme_tpidr2_el0(addr), &val, sizeof(val));
}
188
/*
 * Return the value held in the TPIDR2_EL0 slot of the state buffer at addr.
 *
 * The slot is addressed through a byte pointer whose alignment depends on
 * the caller's buffer, so the value is copied bytewise rather than read
 * through a casted uint64_t pointer.
 */
static inline uint64_t
get_sme_tpidr2_el0(const void *addr)
{
	uint64_t val;

	memcpy(&val, const_sme_tpidr2_el0(addr), sizeof(val));
	return val;
}
195
/*
 * Allocate an uninitialized buffer large enough for the full SME state
 * (sme_data_size() bytes). Returns whatever malloc() returns, so the result
 * may be NULL.
 */
static void *
sme_alloc_data(void)
{
	const size_t bytes = sme_data_size();

	return malloc(bytes);
}
201
/*
 * Allocate an uninitialized buffer large enough for the ZA-only state
 * (sme_za_data_size() bytes). Returns whatever malloc() returns, so the
 * result may be NULL.
 */
static void *
sme_za_alloc_data(void)
{
	const size_t bytes = sme_za_data_size();

	return malloc(bytes);
}
207
/*
 * True when sme_version() reports any SME version (i.e. a nonzero value).
 */
static bool
sme_is_available(void)
{
	return sme_version() != 0;
}
213
/*
 * Issue SMSTART with no operand (architecturally: enter streaming mode and
 * enable ZA).
 */
static void
sme_start(void)
{
	__asm__ volatile ("smstart");
}
219
/*
 * Issue SMSTART ZA (architecturally: enable ZA only, without entering
 * streaming mode).
 */
static void
sme_za_start(void)
{
	__asm__ volatile ("smstart za");
}
225
/*
 * Issue SMSTOP with no operand (architecturally: leave streaming mode and
 * disable ZA).
 */
static void
sme_stop(void)
{
	__asm__ volatile ("smstop");
}
231
/*
 * Issue SMSTOP ZA (architecturally: disable ZA only).
 */
static void
sme_za_stop(void)
{
	__asm__ volatile ("smstop za");
}
237
238 static void
sme_load_one_vector(const void * addr)239 sme_load_one_vector(const void *addr)
240 {
241 asm volatile (
242 "mov w12, #0" "\n"
243 "ldr za[w12, #0], [%[addr]]" "\n"
244 :
245 : [addr] "r"(addr)
246 : "w12"
247 );
248 }
249
250 static void
sme_za_load_data(const void * addr)251 sme_za_load_data(const void *addr)
252 {
253 const uint8_t *za = const_sme_za(addr);
254 uint16_t svl_b = arm_sme_svl_b();
255
256 for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
257 asm volatile (
258 "ldr za[%w[i], #0], [%[addr], #0, mul vl]" "\n"
259 "ldr za[%w[i], #1], [%[addr], #1, mul vl]" "\n"
260 "ldr za[%w[i], #2], [%[addr], #2, mul vl]" "\n"
261 "ldr za[%w[i], #3], [%[addr], #3, mul vl]" "\n"
262 "ldr za[%w[i], #4], [%[addr], #4, mul vl]" "\n"
263 "ldr za[%w[i], #5], [%[addr], #5, mul vl]" "\n"
264 "ldr za[%w[i], #6], [%[addr], #6, mul vl]" "\n"
265 "ldr za[%w[i], #7], [%[addr], #7, mul vl]" "\n"
266 "ldr za[%w[i], #8], [%[addr], #8, mul vl]" "\n"
267 "ldr za[%w[i], #9], [%[addr], #9, mul vl]" "\n"
268 "ldr za[%w[i], #10], [%[addr], #10, mul vl]" "\n"
269 "ldr za[%w[i], #11], [%[addr], #11, mul vl]" "\n"
270 "ldr za[%w[i], #12], [%[addr], #12, mul vl]" "\n"
271 "ldr za[%w[i], #13], [%[addr], #13, mul vl]" "\n"
272 "ldr za[%w[i], #14], [%[addr], #14, mul vl]" "\n"
273 "ldr za[%w[i], #15], [%[addr], #15, mul vl]" "\n"
274 :
275 : [i] "r"(i),
276 [addr] "r"(za + (i * svl_b))
277 );
278 }
279
280 if (sme_zt0_size()) {
281 const uint8_t *zt0 = const_sme_zt0(addr);
282 asm volatile (
283 "ldr zt0, [%[zt0]]"
284 :
285 : [zt0] "r"(zt0)
286 );
287 }
288
289 __builtin_arm_wsr64("TPIDR2_EL0", get_sme_tpidr2_el0(addr));
290 }
291
292 static void
sme_load_data(const void * addr)293 sme_load_data(const void *addr)
294 {
295 const uint8_t *z = const_sme_z(addr);
296 const uint8_t *p = const_sme_p(addr);
297
298 sme_za_load_data(addr);
299
300 asm volatile (
301 "ldr z0, [%[z], #0, mul vl]" "\n"
302 "ldr z1, [%[z], #1, mul vl]" "\n"
303 "ldr z2, [%[z], #2, mul vl]" "\n"
304 "ldr z3, [%[z], #3, mul vl]" "\n"
305 "ldr z4, [%[z], #4, mul vl]" "\n"
306 "ldr z5, [%[z], #5, mul vl]" "\n"
307 "ldr z6, [%[z], #6, mul vl]" "\n"
308 "ldr z7, [%[z], #7, mul vl]" "\n"
309 "ldr z8, [%[z], #8, mul vl]" "\n"
310 "ldr z9, [%[z], #9, mul vl]" "\n"
311 "ldr z10, [%[z], #10, mul vl]" "\n"
312 "ldr z11, [%[z], #11, mul vl]" "\n"
313 "ldr z12, [%[z], #12, mul vl]" "\n"
314 "ldr z13, [%[z], #13, mul vl]" "\n"
315 "ldr z14, [%[z], #14, mul vl]" "\n"
316 "ldr z15, [%[z], #15, mul vl]" "\n"
317 "ldr z16, [%[z], #16, mul vl]" "\n"
318 "ldr z17, [%[z], #17, mul vl]" "\n"
319 "ldr z18, [%[z], #18, mul vl]" "\n"
320 "ldr z19, [%[z], #19, mul vl]" "\n"
321 "ldr z20, [%[z], #20, mul vl]" "\n"
322 "ldr z21, [%[z], #21, mul vl]" "\n"
323 "ldr z22, [%[z], #22, mul vl]" "\n"
324 "ldr z23, [%[z], #23, mul vl]" "\n"
325 "ldr z24, [%[z], #24, mul vl]" "\n"
326 "ldr z25, [%[z], #25, mul vl]" "\n"
327 "ldr z26, [%[z], #26, mul vl]" "\n"
328 "ldr z27, [%[z], #27, mul vl]" "\n"
329 "ldr z28, [%[z], #28, mul vl]" "\n"
330 "ldr z29, [%[z], #29, mul vl]" "\n"
331 "ldr z30, [%[z], #30, mul vl]" "\n"
332 "ldr z31, [%[z], #31, mul vl]" "\n"
333 :
334 : [z] "r"(z)
335 );
336
337 asm volatile (
338 "ldr p0, [%[p], #0, mul vl]" "\n"
339 "ldr p1, [%[p], #1, mul vl]" "\n"
340 "ldr p2, [%[p], #2, mul vl]" "\n"
341 "ldr p3, [%[p], #3, mul vl]" "\n"
342 "ldr p4, [%[p], #4, mul vl]" "\n"
343 "ldr p5, [%[p], #5, mul vl]" "\n"
344 "ldr p6, [%[p], #6, mul vl]" "\n"
345 "ldr p7, [%[p], #7, mul vl]" "\n"
346 "ldr p8, [%[p], #8, mul vl]" "\n"
347 "ldr p9, [%[p], #9, mul vl]" "\n"
348 "ldr p10, [%[p], #10, mul vl]" "\n"
349 "ldr p11, [%[p], #11, mul vl]" "\n"
350 "ldr p12, [%[p], #12, mul vl]" "\n"
351 "ldr p13, [%[p], #13, mul vl]" "\n"
352 "ldr p14, [%[p], #14, mul vl]" "\n"
353 "ldr p15, [%[p], #15, mul vl]" "\n"
354 :
355 : [p] "r"(p)
356 );
357 }
358
359 static void
sme_za_store_data(void * addr)360 sme_za_store_data(void *addr)
361 {
362 uint8_t *za = sme_za(addr);
363 uint16_t svl_b = arm_sme_svl_b();
364
365 for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
366 asm volatile (
367 "str za[%w[i], #0], [%[addr], #0, mul vl]" "\n"
368 "str za[%w[i], #1], [%[addr], #1, mul vl]" "\n"
369 "str za[%w[i], #2], [%[addr], #2, mul vl]" "\n"
370 "str za[%w[i], #3], [%[addr], #3, mul vl]" "\n"
371 "str za[%w[i], #4], [%[addr], #4, mul vl]" "\n"
372 "str za[%w[i], #5], [%[addr], #5, mul vl]" "\n"
373 "str za[%w[i], #6], [%[addr], #6, mul vl]" "\n"
374 "str za[%w[i], #7], [%[addr], #7, mul vl]" "\n"
375 "str za[%w[i], #8], [%[addr], #8, mul vl]" "\n"
376 "str za[%w[i], #9], [%[addr], #9, mul vl]" "\n"
377 "str za[%w[i], #10], [%[addr], #10, mul vl]" "\n"
378 "str za[%w[i], #11], [%[addr], #11, mul vl]" "\n"
379 "str za[%w[i], #12], [%[addr], #12, mul vl]" "\n"
380 "str za[%w[i], #13], [%[addr], #13, mul vl]" "\n"
381 "str za[%w[i], #14], [%[addr], #14, mul vl]" "\n"
382 "str za[%w[i], #15], [%[addr], #15, mul vl]" "\n"
383 :
384 : [i] "r"(i),
385 [addr] "r"(za + (i * svl_b))
386 );
387 }
388
389 if (sme_zt0_size()) {
390 uint8_t *zt0 = sme_zt0(addr);
391 asm volatile (
392 "str zt0, [%[zt0]]"
393 :
394 : [zt0] "r"(zt0)
395 );
396 }
397
398 set_sme_tpidr2_el0(addr, __builtin_arm_rsr64("TPIDR2_EL0"));
399 }
400
401 static void
sme_store_data(void * addr)402 sme_store_data(void *addr)
403 {
404 uint8_t *z = sme_z(addr);
405 uint8_t *p = sme_p(addr);
406
407 sme_za_store_data(addr);
408
409 asm volatile (
410 "str z0, [%[z], #0, mul vl]" "\n"
411 "str z1, [%[z], #1, mul vl]" "\n"
412 "str z2, [%[z], #2, mul vl]" "\n"
413 "str z3, [%[z], #3, mul vl]" "\n"
414 "str z4, [%[z], #4, mul vl]" "\n"
415 "str z5, [%[z], #5, mul vl]" "\n"
416 "str z6, [%[z], #6, mul vl]" "\n"
417 "str z7, [%[z], #7, mul vl]" "\n"
418 "str z8, [%[z], #8, mul vl]" "\n"
419 "str z9, [%[z], #9, mul vl]" "\n"
420 "str z10, [%[z], #10, mul vl]" "\n"
421 "str z11, [%[z], #11, mul vl]" "\n"
422 "str z12, [%[z], #12, mul vl]" "\n"
423 "str z13, [%[z], #13, mul vl]" "\n"
424 "str z14, [%[z], #14, mul vl]" "\n"
425 "str z15, [%[z], #15, mul vl]" "\n"
426 "str z16, [%[z], #16, mul vl]" "\n"
427 "str z17, [%[z], #17, mul vl]" "\n"
428 "str z18, [%[z], #18, mul vl]" "\n"
429 "str z19, [%[z], #19, mul vl]" "\n"
430 "str z20, [%[z], #20, mul vl]" "\n"
431 "str z21, [%[z], #21, mul vl]" "\n"
432 "str z22, [%[z], #22, mul vl]" "\n"
433 "str z23, [%[z], #23, mul vl]" "\n"
434 "str z24, [%[z], #24, mul vl]" "\n"
435 "str z25, [%[z], #25, mul vl]" "\n"
436 "str z26, [%[z], #26, mul vl]" "\n"
437 "str z27, [%[z], #27, mul vl]" "\n"
438 "str z28, [%[z], #28, mul vl]" "\n"
439 "str z29, [%[z], #29, mul vl]" "\n"
440 "str z30, [%[z], #30, mul vl]" "\n"
441 "str z31, [%[z], #31, mul vl]" "\n"
442 :
443 : [z] "r"(z)
444 );
445
446 asm volatile (
447 "str p0, [%[p], #0, mul vl]" "\n"
448 "str p1, [%[p], #1, mul vl]" "\n"
449 "str p2, [%[p], #2, mul vl]" "\n"
450 "str p3, [%[p], #3, mul vl]" "\n"
451 "str p4, [%[p], #4, mul vl]" "\n"
452 "str p5, [%[p], #5, mul vl]" "\n"
453 "str p6, [%[p], #6, mul vl]" "\n"
454 "str p7, [%[p], #7, mul vl]" "\n"
455 "str p8, [%[p], #8, mul vl]" "\n"
456 "str p9, [%[p], #9, mul vl]" "\n"
457 "str p10, [%[p], #10, mul vl]" "\n"
458 "str p11, [%[p], #11, mul vl]" "\n"
459 "str p12, [%[p], #12, mul vl]" "\n"
460 "str p13, [%[p], #13, mul vl]" "\n"
461 "str p14, [%[p], #14, mul vl]" "\n"
462 "str p15, [%[p], #15, mul vl]" "\n"
463 :
464 : [p] "r"(p)
465 );
466 }
467
468 static kern_return_t
sme_thread_get_state(thread_act_t thread,void * addr)469 sme_thread_get_state(thread_act_t thread, void *addr)
470 {
471 uint8_t *za = sme_za(addr);
472 uint8_t *z = sme_z(addr);
473 uint8_t *p = sme_p(addr);
474 uint16_t svl_b = arm_sme_svl_b();
475
476 arm_sme_state_t sme_state;
477 mach_msg_type_number_t sme_count = ARM_SME_STATE_COUNT;
478 kern_return_t err = thread_get_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, &sme_count);
479 if (err) {
480 return err;
481 }
482 set_sme_tpidr2_el0(addr, sme_state.__tpidr2_el0);
483
484 arm_sme_za_state_t za_state;
485 mach_msg_type_number_t za_count = ARM_SME_ZA_STATE_COUNT;
486 err = thread_get_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, &za_count);
487 if (err) {
488 return err;
489 }
490
491 arm_sve_z_state_t z_state1, z_state2;
492 mach_msg_type_number_t z_streaming_count = ARM_SVE_Z_STATE_COUNT;
493 err = thread_get_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, &z_streaming_count);
494 if (err) {
495 return err;
496 }
497 err = thread_get_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, &z_streaming_count);
498 if (err) {
499 return err;
500 }
501
502 arm_sve_p_state_t p_state;
503 mach_msg_type_number_t p_streaming_count = ARM_SVE_P_STATE_COUNT;
504 err = thread_get_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, &p_streaming_count);
505 if (err) {
506 return err;
507 }
508
509 memcpy(za, za_state.__za, svl_b * svl_b);
510
511 size_t z_elem_size = svl_b;
512 for (int i = 0; i < 16; i++) {
513 memcpy(z, z_state1.__z[i], z_elem_size);
514 z += z_elem_size;
515 }
516 for (int i = 0; i < 16; i++) {
517 memcpy(z, z_state2.__z[i], z_elem_size);
518 z += z_elem_size;
519 }
520
521 size_t p_elem_size = svl_b / 8;
522 for (int i = 0; i < 16; i++) {
523 memcpy(p, p_state.__p[i], p_elem_size);
524 p += p_elem_size;
525 }
526
527 if (sme_zt0_size()) {
528 uint8_t *zt0 = sme_zt0(addr);
529
530 arm_sme2_state_t sme2_state;
531 mach_msg_type_number_t sme2_count = ARM_SME2_STATE_COUNT;
532 err = thread_get_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, &sme2_count);
533 if (err) {
534 return err;
535 }
536
537 memcpy(zt0, sme2_state.__zt0, sizeof(sme2_state.__zt0));
538 }
539
540 return KERN_SUCCESS;
541 }
542
543 static kern_return_t
sme_thread_set_state(thread_act_t thread,const void * addr)544 sme_thread_set_state(thread_act_t thread, const void *addr)
545 {
546 const uint8_t *za = const_sme_za(addr);
547 const uint8_t *z = const_sme_z(addr);
548 const uint8_t *p = const_sme_p(addr);
549 uint16_t svl_b = arm_sme_svl_b();
550
551 arm_sme_state_t sme_state;
552 sme_state.__svcr = 0x3;
553 sme_state.__svl_b = svl_b;
554 sme_state.__tpidr2_el0 = get_sme_tpidr2_el0(addr);
555
556 arm_sme_za_state_t za_state;
557 memcpy(za_state.__za, za, svl_b * svl_b);
558
559 arm_sve_z_state_t z_state1, z_state2;
560 size_t z_elem_size = svl_b;
561 for (int i = 0; i < 16; i++) {
562 memcpy(z_state1.__z[i], z, z_elem_size);
563 z += z_elem_size;
564 }
565 for (int i = 0; i < 16; i++) {
566 memcpy(z_state2.__z[i], z, z_elem_size);
567 z += z_elem_size;
568 }
569
570 arm_sve_p_state_t p_state;
571 size_t p_elem_size = svl_b / 8;
572 for (int i = 0; i < 16; i++) {
573 memcpy(p_state.__p[i], p, p_elem_size);
574 p += p_elem_size;
575 }
576
577 kern_return_t err = thread_set_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, ARM_SME_STATE_COUNT);
578 if (err) {
579 return err;
580 }
581
582 err = thread_set_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, ARM_SVE_Z_STATE_COUNT);
583 if (err) {
584 return err;
585 }
586
587 err = thread_set_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, ARM_SVE_Z_STATE_COUNT);
588 if (err) {
589 return err;
590 }
591
592 err = thread_set_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, ARM_SVE_P_STATE_COUNT);
593 if (err) {
594 return err;
595 }
596
597 err = thread_set_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, ARM_SME_ZA_STATE_COUNT);
598 if (err) {
599 return err;
600 }
601
602 if (sme_zt0_size()) {
603 const uint8_t *zt0 = const_sme_zt0(addr);
604
605 arm_sme2_state_t sme2_state;
606 memcpy(sme2_state.__zt0, zt0, sizeof(sme2_state.__zt0));
607
608 err = thread_set_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, ARM_SME2_STATE_COUNT);
609 if (err) {
610 return err;
611 }
612 }
613
614 return KERN_SUCCESS;
615 }
616
617 const struct arm_matrix_operations sme_operations = {
618 .name = "SME",
619
620 .data_size = sme_data_size,
621 .alloc_data = sme_alloc_data,
622
623 .is_available = sme_is_available,
624 .start = sme_start,
625 .stop = sme_stop,
626
627 .load_one_vector = sme_load_one_vector,
628 .load_data = sme_load_data,
629 .store_data = sme_store_data,
630
631 .thread_get_state = sme_thread_get_state,
632 .thread_set_state = sme_thread_set_state,
633 };
634
635 const struct arm_matrix_operations sme_za_operations = {
636 .name = "SME (SVCR.ZA only)",
637
638 .data_size = sme_za_data_size,
639 .alloc_data = sme_za_alloc_data,
640
641 .is_available = sme_is_available,
642 .start = sme_za_start,
643 .stop = sme_za_stop,
644
645 .load_one_vector = NULL, /* currently unused */
646 .load_data = sme_za_load_data,
647 .store_data = sme_za_store_data,
648
649 .thread_get_state = NULL, /* currently unused */
650 .thread_set_state = NULL, /* currently unused */
651 };
652