1 /*
2 * Copyright (c) 2022 Apple Computer, Inc. All rights reserved.
3 *
4 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. The rights granted to you under the License
10 * may not be used to create, or enable the creation or redistribution of,
11 * unlawful or unlicensed copies of an Apple operating system, or to
12 * circumvent, violate, or enable the circumvention or violation of, any
13 * terms of an Apple operating system software license agreement.
14 *
15 * Please obtain a copy of the License at
16 * http://www.opensource.apple.com/apsl/ and read it before using this file.
17 *
18 * The Original Code and all software distributed under the License are
19 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23 * Please see the License for the specific language governing rights and
24 * limitations under the License.
25 *
26 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27 */
28
#include <mach/thread_act.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <sys/sysctl.h>
33
34 #include "arm_matrix.h"
35
/* Number of SVE/SME Z (vector) registers. */
static const unsigned int SME_Z_VECTORS = 32;
/* Number of SVE/SME P (predicate) registers. */
static const unsigned int SME_P_VECTORS = 16;
38
/*
 * Returns the SME version supported by the CPU: 0 (no SME), 1 (FEAT_SME),
 * or 2 (FEAT_SME2).  The sysctl queries are performed once and cached.
 */
static unsigned int
sme_version(void)
{
	static unsigned int ret = 0;
	static bool already_read = false;

	if (!already_read) {
		/*
		 * Initialize to 0 so that a failing sysctlbyname() (e.g. the
		 * key is absent on this kernel) reads as "feature not
		 * present" rather than leaving the flag uninitialized.
		 */
		unsigned int feat_sme = 0, feat_sme2 = 0;
		size_t size = sizeof(feat_sme);
		(void)sysctlbyname("hw.optional.arm.FEAT_SME", &feat_sme, &size, NULL, 0);
		/* sysctlbyname() may update `size`; reset it before reuse. */
		size = sizeof(feat_sme2);
		(void)sysctlbyname("hw.optional.arm.FEAT_SME2", &feat_sme2, &size, NULL, 0);

		if (feat_sme2) {
			ret = 2;
		} else if (feat_sme) {
			ret = 1;
		} else {
			ret = 0;
		}

		already_read = true;
	}

	return ret;
}
64
65 static uint16_t
arm_sme_svl_b(void)66 arm_sme_svl_b(void)
67 {
68 uint64_t ret = 0;
69 asm volatile (
70 "rdsvl %[ret], #1"
71 : [ret] "=r"(ret)
72 );
73 return (uint16_t)ret;
74 }
75
/*
 * Size in bytes of the ZA tile storage: SVL_B x SVL_B.  The factor is
 * widened to size_t before multiplying so the product cannot overflow
 * after integer promotion; this also avoids reading the register twice.
 */
static size_t
sme_za_size(void)
{
	size_t svl_b = arm_sme_svl_b();
	return svl_b * svl_b;
}
81
82 static size_t
sme_z_size(void)83 sme_z_size(void)
84 {
85 return arm_sme_svl_b() * SME_Z_VECTORS;
86 }
87
88 static size_t
sme_p_size(void)89 sme_p_size(void)
90 {
91 return arm_sme_svl_b() * SME_P_VECTORS / 8;
92 }
93
/*
 * Size in bytes of the ZT0 image.  ZT0 is a fixed 512-bit (64-byte)
 * register that exists only on SME2 hardware; 0 otherwise.
 */
static size_t
sme_zt0_size(void)
{
	return (sme_version() >= 2) ? 64 : 0;
}
103
/* Size in bytes of the saved TPIDR2_EL0 value (one 64-bit register). */
static size_t
sme_tpidr2_size(void)
{
	return sizeof(uint64_t);
}
109
/* Start of the ZA tile image within a saved-state buffer. */
static inline uint8_t *
sme_za(void *addr)
{
	return (uint8_t *)addr;
}
115
/* Read-only view of the ZA tile image within a saved-state buffer. */
static inline const uint8_t *
const_sme_za(const void *addr)
{
	return (const uint8_t *)addr;
}
121
/* Start of the Z-register images; they immediately follow ZA. */
static inline uint8_t *
sme_z(void *addr)
{
	uint8_t *base = sme_za(addr);
	return &base[sme_za_size()];
}
127
/* Read-only view of the Z-register images; they immediately follow ZA. */
static inline const uint8_t *
const_sme_z(const void *addr)
{
	const uint8_t *base = const_sme_za(addr);
	return &base[sme_za_size()];
}
133
/* Start of the P-register images; they immediately follow the Z block. */
static inline uint8_t *
sme_p(void *addr)
{
	uint8_t *z_base = sme_z(addr);
	return &z_base[sme_z_size()];
}
139
/* Read-only view of the P-register images, following the Z block. */
static inline const uint8_t *
const_sme_p(const void *addr)
{
	const uint8_t *z_base = const_sme_z(addr);
	return &z_base[sme_z_size()];
}
145
/* Start of the ZT0 image; it immediately follows the P block. */
static inline uint8_t *
sme_zt0(void *addr)
{
	uint8_t *p_base = sme_p(addr);
	return &p_base[sme_p_size()];
}
151
/* Read-only view of the ZT0 image, following the P block. */
static inline const uint8_t *
const_sme_zt0(const void *addr)
{
	const uint8_t *p_base = const_sme_p(addr);
	return &p_base[sme_p_size()];
}
157
/*
 * Total buffer size needed for one complete SME state image:
 * ZA tile + Z registers + P registers + ZT0 (SME2 only) + TPIDR2_EL0.
 */
static size_t
sme_data_size(void)
{
	size_t total = sme_za_size();
	total += sme_z_size();
	total += sme_p_size();
	total += sme_zt0_size();
	total += sme_tpidr2_size();
	return total;
}
163
/*
 * Stores `val` into the TPIDR2_EL0 slot at the very end of the buffer.
 * memcpy is used instead of casting the byte pointer to uint64_t * and
 * dereferencing, so the store is well-defined even if the slot happens
 * not to be 8-byte aligned.
 */
static inline void
set_sme_tpidr2_el0(void *addr, uint64_t val)
{
	memcpy(sme_zt0(addr) + sme_zt0_size(), &val, sizeof(val));
}
170
/*
 * Reads the TPIDR2_EL0 value from the very end of the buffer.  memcpy
 * avoids a potentially misaligned uint64_t load through a cast pointer.
 */
static inline uint64_t
get_sme_tpidr2_el0(const void *addr)
{
	uint64_t val;
	memcpy(&val, const_sme_zt0(addr) + sme_zt0_size(), sizeof(val));
	return val;
}
177
/*
 * Allocates a buffer large enough to hold a complete SME state image.
 * The caller owns the buffer and must free() it; may return NULL.
 */
static void *
sme_alloc_data(void)
{
	size_t size = sme_data_size();
	return malloc(size);
}
183
/* True when the CPU supports any SME version. */
static bool
sme_is_available(void)
{
	return sme_version() != 0;
}
189
/*
 * Enters SME execution state.  SMSTART with no option enables both
 * streaming SVE mode (PSTATE.SM) and the ZA storage (PSTATE.ZA).
 */
static void
sme_start(void)
{
	asm volatile ("smstart");
}
195
/*
 * Leaves SME execution state.  SMSTOP disables streaming SVE mode and the
 * ZA storage, discarding the corresponding register contents.
 */
static void
sme_stop(void)
{
	asm volatile ("smstop");
}
201
/*
 * Loads a single ZA array vector (index 0) from `addr`, dirtying SME
 * state with minimal work.  The ZA vector-select index for "ldr za[...]"
 * must live in w12-w15; w12 is used here and declared clobbered.
 */
static void
sme_load_one_vector(const void *addr)
{
	asm volatile (
		"mov w12, #0" "\n"
		"ldr za[w12, #0], [%[addr]]" "\n"
		:
		: [addr] "r"(addr)
		: "w12"
	);
}
213
/*
 * Loads a complete SME state image from `addr` into the machine
 * registers: the ZA tile, Z0-Z31, P0-P15, ZT0 (SME2 only), and
 * TPIDR2_EL0.  The buffer layout matches the sme_za/sme_z/sme_p/sme_zt0
 * accessors.  Must run with streaming mode and ZA enabled (sme_start()).
 */
static void
sme_load_data(const void *addr)
{
	const uint8_t *za = const_sme_za(addr);
	const uint8_t *z = const_sme_z(addr);
	const uint8_t *p = const_sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	/*
	 * ZA is SVL_B vectors of SVL_B bytes.  The vector-select index for
	 * "ldr za[...]" must be in w12-w15, hence the register-pinned loop
	 * counter; the #0..#15 immediates load 16 consecutive vectors per
	 * iteration (SVL_B is architecturally a multiple of 16).
	 */
	for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
		asm volatile (
			"ldr za[%w[i], #0], [%[addr], #0, mul vl]" "\n"
			"ldr za[%w[i], #1], [%[addr], #1, mul vl]" "\n"
			"ldr za[%w[i], #2], [%[addr], #2, mul vl]" "\n"
			"ldr za[%w[i], #3], [%[addr], #3, mul vl]" "\n"
			"ldr za[%w[i], #4], [%[addr], #4, mul vl]" "\n"
			"ldr za[%w[i], #5], [%[addr], #5, mul vl]" "\n"
			"ldr za[%w[i], #6], [%[addr], #6, mul vl]" "\n"
			"ldr za[%w[i], #7], [%[addr], #7, mul vl]" "\n"
			"ldr za[%w[i], #8], [%[addr], #8, mul vl]" "\n"
			"ldr za[%w[i], #9], [%[addr], #9, mul vl]" "\n"
			"ldr za[%w[i], #10], [%[addr], #10, mul vl]" "\n"
			"ldr za[%w[i], #11], [%[addr], #11, mul vl]" "\n"
			"ldr za[%w[i], #12], [%[addr], #12, mul vl]" "\n"
			"ldr za[%w[i], #13], [%[addr], #13, mul vl]" "\n"
			"ldr za[%w[i], #14], [%[addr], #14, mul vl]" "\n"
			"ldr za[%w[i], #15], [%[addr], #15, mul vl]" "\n"
			:
			: [i] "r"(i),
			[addr] "r"(za + (i * svl_b))
		);
	}

	/* Z0-Z31, each SVL_B bytes, addressed with "mul vl" scaling. */
	asm volatile (
		"ldr z0, [%[z], #0, mul vl]" "\n"
		"ldr z1, [%[z], #1, mul vl]" "\n"
		"ldr z2, [%[z], #2, mul vl]" "\n"
		"ldr z3, [%[z], #3, mul vl]" "\n"
		"ldr z4, [%[z], #4, mul vl]" "\n"
		"ldr z5, [%[z], #5, mul vl]" "\n"
		"ldr z6, [%[z], #6, mul vl]" "\n"
		"ldr z7, [%[z], #7, mul vl]" "\n"
		"ldr z8, [%[z], #8, mul vl]" "\n"
		"ldr z9, [%[z], #9, mul vl]" "\n"
		"ldr z10, [%[z], #10, mul vl]" "\n"
		"ldr z11, [%[z], #11, mul vl]" "\n"
		"ldr z12, [%[z], #12, mul vl]" "\n"
		"ldr z13, [%[z], #13, mul vl]" "\n"
		"ldr z14, [%[z], #14, mul vl]" "\n"
		"ldr z15, [%[z], #15, mul vl]" "\n"
		"ldr z16, [%[z], #16, mul vl]" "\n"
		"ldr z17, [%[z], #17, mul vl]" "\n"
		"ldr z18, [%[z], #18, mul vl]" "\n"
		"ldr z19, [%[z], #19, mul vl]" "\n"
		"ldr z20, [%[z], #20, mul vl]" "\n"
		"ldr z21, [%[z], #21, mul vl]" "\n"
		"ldr z22, [%[z], #22, mul vl]" "\n"
		"ldr z23, [%[z], #23, mul vl]" "\n"
		"ldr z24, [%[z], #24, mul vl]" "\n"
		"ldr z25, [%[z], #25, mul vl]" "\n"
		"ldr z26, [%[z], #26, mul vl]" "\n"
		"ldr z27, [%[z], #27, mul vl]" "\n"
		"ldr z28, [%[z], #28, mul vl]" "\n"
		"ldr z29, [%[z], #29, mul vl]" "\n"
		"ldr z30, [%[z], #30, mul vl]" "\n"
		"ldr z31, [%[z], #31, mul vl]" "\n"
		:
		: [z] "r"(z)
	);

	/* P0-P15, each SVL_B / 8 bytes. */
	asm volatile (
		"ldr p0, [%[p], #0, mul vl]" "\n"
		"ldr p1, [%[p], #1, mul vl]" "\n"
		"ldr p2, [%[p], #2, mul vl]" "\n"
		"ldr p3, [%[p], #3, mul vl]" "\n"
		"ldr p4, [%[p], #4, mul vl]" "\n"
		"ldr p5, [%[p], #5, mul vl]" "\n"
		"ldr p6, [%[p], #6, mul vl]" "\n"
		"ldr p7, [%[p], #7, mul vl]" "\n"
		"ldr p8, [%[p], #8, mul vl]" "\n"
		"ldr p9, [%[p], #9, mul vl]" "\n"
		"ldr p10, [%[p], #10, mul vl]" "\n"
		"ldr p11, [%[p], #11, mul vl]" "\n"
		"ldr p12, [%[p], #12, mul vl]" "\n"
		"ldr p13, [%[p], #13, mul vl]" "\n"
		"ldr p14, [%[p], #14, mul vl]" "\n"
		"ldr p15, [%[p], #15, mul vl]" "\n"
		:
		: [p] "r"(p)
	);

	/* ZT0 exists only on SME2 hardware. */
	if (sme_zt0_size()) {
		const uint8_t *zt0 = const_sme_zt0(addr);
		asm volatile (
			"ldr zt0, [%[zt0]]"
			:
			: [zt0] "r"(zt0)
		);
	}

	__builtin_arm_wsr64("TPIDR2_EL0", get_sme_tpidr2_el0(addr));
}
315
/*
 * Stores the complete live SME state into the buffer at `addr`: the ZA
 * tile, Z0-Z31, P0-P15, ZT0 (SME2 only), and TPIDR2_EL0.  Exact mirror
 * of sme_load_data(); the buffer layout matches the sme_za/sme_z/sme_p/
 * sme_zt0 accessors.  Must run with streaming mode and ZA enabled.
 */
static void
sme_store_data(void *addr)
{
	uint8_t *za = sme_za(addr);
	uint8_t *z = sme_z(addr);
	uint8_t *p = sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	/*
	 * ZA vector-select index must be in w12-w15; store 16 consecutive
	 * ZA vectors per iteration (SVL_B is a multiple of 16).
	 */
	for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
		asm volatile (
			"str za[%w[i], #0], [%[addr], #0, mul vl]" "\n"
			"str za[%w[i], #1], [%[addr], #1, mul vl]" "\n"
			"str za[%w[i], #2], [%[addr], #2, mul vl]" "\n"
			"str za[%w[i], #3], [%[addr], #3, mul vl]" "\n"
			"str za[%w[i], #4], [%[addr], #4, mul vl]" "\n"
			"str za[%w[i], #5], [%[addr], #5, mul vl]" "\n"
			"str za[%w[i], #6], [%[addr], #6, mul vl]" "\n"
			"str za[%w[i], #7], [%[addr], #7, mul vl]" "\n"
			"str za[%w[i], #8], [%[addr], #8, mul vl]" "\n"
			"str za[%w[i], #9], [%[addr], #9, mul vl]" "\n"
			"str za[%w[i], #10], [%[addr], #10, mul vl]" "\n"
			"str za[%w[i], #11], [%[addr], #11, mul vl]" "\n"
			"str za[%w[i], #12], [%[addr], #12, mul vl]" "\n"
			"str za[%w[i], #13], [%[addr], #13, mul vl]" "\n"
			"str za[%w[i], #14], [%[addr], #14, mul vl]" "\n"
			"str za[%w[i], #15], [%[addr], #15, mul vl]" "\n"
			:
			: [i] "r"(i),
			[addr] "r"(za + (i * svl_b))
		);
	}

	/* Z0-Z31, each SVL_B bytes. */
	asm volatile (
		"str z0, [%[z], #0, mul vl]" "\n"
		"str z1, [%[z], #1, mul vl]" "\n"
		"str z2, [%[z], #2, mul vl]" "\n"
		"str z3, [%[z], #3, mul vl]" "\n"
		"str z4, [%[z], #4, mul vl]" "\n"
		"str z5, [%[z], #5, mul vl]" "\n"
		"str z6, [%[z], #6, mul vl]" "\n"
		"str z7, [%[z], #7, mul vl]" "\n"
		"str z8, [%[z], #8, mul vl]" "\n"
		"str z9, [%[z], #9, mul vl]" "\n"
		"str z10, [%[z], #10, mul vl]" "\n"
		"str z11, [%[z], #11, mul vl]" "\n"
		"str z12, [%[z], #12, mul vl]" "\n"
		"str z13, [%[z], #13, mul vl]" "\n"
		"str z14, [%[z], #14, mul vl]" "\n"
		"str z15, [%[z], #15, mul vl]" "\n"
		"str z16, [%[z], #16, mul vl]" "\n"
		"str z17, [%[z], #17, mul vl]" "\n"
		"str z18, [%[z], #18, mul vl]" "\n"
		"str z19, [%[z], #19, mul vl]" "\n"
		"str z20, [%[z], #20, mul vl]" "\n"
		"str z21, [%[z], #21, mul vl]" "\n"
		"str z22, [%[z], #22, mul vl]" "\n"
		"str z23, [%[z], #23, mul vl]" "\n"
		"str z24, [%[z], #24, mul vl]" "\n"
		"str z25, [%[z], #25, mul vl]" "\n"
		"str z26, [%[z], #26, mul vl]" "\n"
		"str z27, [%[z], #27, mul vl]" "\n"
		"str z28, [%[z], #28, mul vl]" "\n"
		"str z29, [%[z], #29, mul vl]" "\n"
		"str z30, [%[z], #30, mul vl]" "\n"
		"str z31, [%[z], #31, mul vl]" "\n"
		:
		: [z] "r"(z)
	);

	/* P0-P15, each SVL_B / 8 bytes. */
	asm volatile (
		"str p0, [%[p], #0, mul vl]" "\n"
		"str p1, [%[p], #1, mul vl]" "\n"
		"str p2, [%[p], #2, mul vl]" "\n"
		"str p3, [%[p], #3, mul vl]" "\n"
		"str p4, [%[p], #4, mul vl]" "\n"
		"str p5, [%[p], #5, mul vl]" "\n"
		"str p6, [%[p], #6, mul vl]" "\n"
		"str p7, [%[p], #7, mul vl]" "\n"
		"str p8, [%[p], #8, mul vl]" "\n"
		"str p9, [%[p], #9, mul vl]" "\n"
		"str p10, [%[p], #10, mul vl]" "\n"
		"str p11, [%[p], #11, mul vl]" "\n"
		"str p12, [%[p], #12, mul vl]" "\n"
		"str p13, [%[p], #13, mul vl]" "\n"
		"str p14, [%[p], #14, mul vl]" "\n"
		"str p15, [%[p], #15, mul vl]" "\n"
		:
		: [p] "r"(p)
	);

	/* ZT0 exists only on SME2 hardware. */
	if (sme_zt0_size()) {
		uint8_t *zt0 = sme_zt0(addr);
		asm volatile (
			"str zt0, [%[zt0]]"
			:
			: [zt0] "r"(zt0)
		);
	}

	set_sme_tpidr2_el0(addr, __builtin_arm_rsr64("TPIDR2_EL0"));
}
417
/*
 * Reads the full SME register state of `thread` via the Mach
 * thread_get_state() flavors and serializes it into `addr` using the
 * same buffer layout as sme_store_data().  Returns KERN_SUCCESS or the
 * first Mach error encountered.
 */
static kern_return_t
sme_thread_get_state(thread_act_t thread, void *addr)
{
	uint8_t *za = sme_za(addr);
	uint8_t *z = sme_z(addr);
	uint8_t *p = sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	/* TPIDR2_EL0 (with SVCR/SVL) comes from the base ARM_SME_STATE flavor. */
	arm_sme_state_t sme_state;
	mach_msg_type_number_t sme_count = ARM_SME_STATE_COUNT;
	kern_return_t err = thread_get_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, &sme_count);
	if (err) {
		return err;
	}
	set_sme_tpidr2_el0(addr, sme_state.__tpidr2_el0);

	/* ZA tile storage (SVL_B x SVL_B bytes). */
	arm_sme_za_state_t za_state;
	mach_msg_type_number_t za_count = ARM_SME_ZA_STATE_COUNT;
	err = thread_get_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, &za_count);
	if (err) {
		return err;
	}

	/* Streaming SVE Z registers, split across two flavors of 16 each. */
	arm_sve_z_state_t z_state1, z_state2;
	mach_msg_type_number_t z_streaming_count = ARM_SVE_Z_STATE_COUNT;
	err = thread_get_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, &z_streaming_count);
	if (err) {
		return err;
	}
	err = thread_get_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, &z_streaming_count);
	if (err) {
		return err;
	}

	/* Predicate registers. */
	arm_sve_p_state_t p_state;
	mach_msg_type_number_t p_streaming_count = ARM_SVE_P_STATE_COUNT;
	err = thread_get_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, &p_streaming_count);
	if (err) {
		return err;
	}

	memcpy(za, za_state.__za, svl_b * svl_b);

	/* Each Z register occupies SVL_B bytes in the buffer. */
	size_t z_elem_size = svl_b;
	for (int i = 0; i < 16; i++) {
		memcpy(z, z_state1.__z[i], z_elem_size);
		z += z_elem_size;
	}
	for (int i = 0; i < 16; i++) {
		memcpy(z, z_state2.__z[i], z_elem_size);
		z += z_elem_size;
	}

	/* Each P register occupies SVL_B / 8 bytes (one bit per Z byte). */
	size_t p_elem_size = svl_b / 8;
	for (int i = 0; i < 16; i++) {
		memcpy(p, p_state.__p[i], p_elem_size);
		p += p_elem_size;
	}

	/* ZT0 state is only queryable on SME2 hardware. */
	if (sme_zt0_size()) {
		uint8_t *zt0 = sme_zt0(addr);

		arm_sme2_state_t sme2_state;
		mach_msg_type_number_t sme2_count = ARM_SME2_STATE_COUNT;
		err = thread_get_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, &sme2_count);
		if (err) {
			return err;
		}

		memcpy(zt0, sme2_state.__zt0, sizeof(sme2_state.__zt0));
	}

	return KERN_SUCCESS;
}
492
/*
 * Installs a full SME register state (from a buffer laid out as in
 * sme_store_data()) into `thread` via the Mach thread_set_state()
 * flavors.  Returns KERN_SUCCESS or the first Mach error encountered.
 */
static kern_return_t
sme_thread_set_state(thread_act_t thread, const void *addr)
{
	const uint8_t *za = const_sme_za(addr);
	const uint8_t *z = const_sme_z(addr);
	const uint8_t *p = const_sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	/* SVCR = 0x3 enables both PSTATE.SM and PSTATE.ZA on the target. */
	arm_sme_state_t sme_state;
	sme_state.__svcr = 0x3;
	sme_state.__svl_b = svl_b;
	sme_state.__tpidr2_el0 = get_sme_tpidr2_el0(addr);

	arm_sme_za_state_t za_state;
	memcpy(za_state.__za, za, svl_b * svl_b);

	/* Z registers split across two flavors of 16 each, SVL_B bytes apiece. */
	arm_sve_z_state_t z_state1, z_state2;
	size_t z_elem_size = svl_b;
	for (int i = 0; i < 16; i++) {
		memcpy(z_state1.__z[i], z, z_elem_size);
		z += z_elem_size;
	}
	for (int i = 0; i < 16; i++) {
		memcpy(z_state2.__z[i], z, z_elem_size);
		z += z_elem_size;
	}

	/* P registers, SVL_B / 8 bytes apiece. */
	arm_sve_p_state_t p_state;
	size_t p_elem_size = svl_b / 8;
	for (int i = 0; i < 16; i++) {
		memcpy(p_state.__p[i], p, p_elem_size);
		p += p_elem_size;
	}

	/*
	 * NOTE(review): ARM_SME_STATE is applied first (it enables SM/ZA),
	 * and ARM_SME_ZA_STATE1 is applied after the Z/P flavors —
	 * presumably so the ZA contents cannot be invalidated by the other
	 * updates; confirm against the kernel's ordering requirements.
	 */
	kern_return_t err = thread_set_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, ARM_SME_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, ARM_SVE_Z_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, ARM_SVE_Z_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, ARM_SVE_P_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, ARM_SME_ZA_STATE_COUNT);
	if (err) {
		return err;
	}

	/* ZT0 state is only settable on SME2 hardware. */
	if (sme_zt0_size()) {
		const uint8_t *zt0 = const_sme_zt0(addr);

		arm_sme2_state_t sme2_state;
		memcpy(sme2_state.__zt0, zt0, sizeof(sme2_state.__zt0));

		err = thread_set_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, ARM_SME2_STATE_COUNT);
		if (err) {
			return err;
		}
	}

	return KERN_SUCCESS;
}
566
/*
 * SME backend for the generic arm_matrix test driver (see arm_matrix.h).
 * Each callback maps one driver operation onto the SME implementation
 * above.
 */
const struct arm_matrix_operations sme_operations = {
	.name = "SME",

	/* Buffer sizing and allocation for one full register image. */
	.data_size = sme_data_size,
	.alloc_data = sme_alloc_data,

	/* Feature detection and SMSTART/SMSTOP mode control. */
	.is_available = sme_is_available,
	.start = sme_start,
	.stop = sme_stop,

	/* Direct register-file load/store via inline assembly. */
	.load_one_vector = sme_load_one_vector,
	.load_data = sme_load_data,
	.store_data = sme_store_data,

	/* Remote-thread state access via Mach thread_get/set_state. */
	.thread_get_state = sme_thread_get_state,
	.thread_set_state = sme_thread_set_state,
};
584