1 /*
2 * Copyright (c) 2022 Apple Computer, Inc. All rights reserved.
3 *
4 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. The rights granted to you under the License
10 * may not be used to create, or enable the creation or redistribution of,
11 * unlawful or unlicensed copies of an Apple operating system, or to
12 * circumvent, violate, or enable the circumvention or violation of, any
13 * terms of an Apple operating system software license agreement.
14 *
15 * Please obtain a copy of the License at
16 * http://www.opensource.apple.com/apsl/ and read it before using this file.
17 *
18 * The Original Code and all software distributed under the License are
19 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23 * Please see the License for the specific language governing rights and
24 * limitations under the License.
25 *
26 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27 */
28
29 #include <stdint.h>
30 #include <stdlib.h>
31 #include <sys/sysctl.h>
32
33 #include "arm_matrix.h"
34
/* Number of SVE Z (data) and P (predicate) registers saved/restored. */
static const unsigned int SME_Z_VECTORS = 32;
static const unsigned int SME_P_VECTORS = 16;
37
/*
 * Returns the highest supported SME version: 2 if FEAT_SME2 is present,
 * 1 if only FEAT_SME, 0 if SME is unsupported.  The sysctl lookups are
 * performed once and cached (not thread-safe on first call).
 */
static unsigned int
sme_version(void)
{
	static unsigned int ret = 0;
	static bool already_read = false;

	if (!already_read) {
		/*
		 * Initialize to 0 so a failed sysctl lookup (e.g. key absent on
		 * older kernels) reads as "feature not present" rather than
		 * branching on uninitialized garbage.
		 */
		unsigned int feat_sme = 0, feat_sme2 = 0;
		size_t size = sizeof(feat_sme);
		if (sysctlbyname("hw.optional.arm.FEAT_SME", &feat_sme, &size, NULL, 0) != 0) {
			feat_sme = 0;
		}
		/* size is in/out and may have been updated; reset before reuse. */
		size = sizeof(feat_sme2);
		if (sysctlbyname("hw.optional.arm.FEAT_SME2", &feat_sme2, &size, NULL, 0) != 0) {
			feat_sme2 = 0;
		}

		if (feat_sme2) {
			ret = 2;
		} else if (feat_sme) {
			ret = 1;
		} else {
			ret = 0;
		}

		already_read = true;
	}

	return ret;
}
63
64 static uint16_t
arm_sme_svl_b(void)65 arm_sme_svl_b(void)
66 {
67 uint64_t ret = 0;
68 asm volatile (
69 "rdsvl %[ret], #1"
70 : [ret] "=r"(ret)
71 );
72 return (uint16_t)ret;
73 }
74
/* Size in bytes of the ZA storage: an SVL_B x SVL_B byte tile. */
static size_t
sme_za_size(void)
{
	const size_t svl_b = arm_sme_svl_b();
	return svl_b * svl_b;
}
80
81 static size_t
sme_z_size(void)82 sme_z_size(void)
83 {
84 return arm_sme_svl_b() * SME_Z_VECTORS;
85 }
86
87 static size_t
sme_p_size(void)88 sme_p_size(void)
89 {
90 return arm_sme_svl_b() * SME_P_VECTORS / 8;
91 }
92
/* Size in bytes of ZT0: a fixed 512-bit register present only with SME2. */
static size_t
sme_zt0_size(void)
{
	return (sme_version() >= 2) ? 64 : 0;
}
102
/*
 * Total buffer size needed to hold the full SME state, laid out as
 * [ZA | Z | P | ZT0] (ZT0 only when SME2 is available).
 */
static size_t
sme_data_size(void)
{
	size_t total = sme_za_size();
	total += sme_z_size();
	total += sme_p_size();
	total += sme_zt0_size();
	return total;
}
108
/*
 * Allocates a buffer large enough for the full SME state.
 * Caller owns the buffer and must free() it; may return NULL on OOM.
 */
static void *
sme_alloc_data(void)
{
	const size_t size = sme_data_size();
	return malloc(size);
}
114
/* True if the CPU supports any version of SME. */
static bool
sme_is_available(void)
{
	return sme_version() != 0;
}
120
/*
 * Enters streaming mode and enables ZA storage (SMSTART with no option
 * sets both PSTATE.SM and PSTATE.ZA).
 */
static void
sme_start(void)
{
	asm volatile ("smstart");
}
126
/*
 * Exits streaming mode and disables ZA storage (SMSTOP clears both
 * PSTATE.SM and PSTATE.ZA, discarding SME register contents).
 */
static void
sme_stop(void)
{
	asm volatile ("smstop");
}
132
133 static void
sme_load_one_vector(const void * addr)134 sme_load_one_vector(const void *addr)
135 {
136 asm volatile (
137 "mov w12, #0" "\n"
138 "ldr za[w12, #0], [%[addr]]" "\n"
139 :
140 : [addr] "r"(addr)
141 : "w12"
142 );
143 }
144
145 static void
sme_load_data(const void * addr)146 sme_load_data(const void *addr)
147 {
148 const uint8_t *za = addr;
149 const uint8_t *z = za + sme_za_size();
150 const uint8_t *p = z + sme_z_size();
151 uint16_t svl_b = arm_sme_svl_b();
152
153 for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
154 asm volatile (
155 "ldr za[%w[i], #0], [%[addr], #0, mul vl]" "\n"
156 "ldr za[%w[i], #1], [%[addr], #1, mul vl]" "\n"
157 "ldr za[%w[i], #2], [%[addr], #2, mul vl]" "\n"
158 "ldr za[%w[i], #3], [%[addr], #3, mul vl]" "\n"
159 "ldr za[%w[i], #4], [%[addr], #4, mul vl]" "\n"
160 "ldr za[%w[i], #5], [%[addr], #5, mul vl]" "\n"
161 "ldr za[%w[i], #6], [%[addr], #6, mul vl]" "\n"
162 "ldr za[%w[i], #7], [%[addr], #7, mul vl]" "\n"
163 "ldr za[%w[i], #8], [%[addr], #8, mul vl]" "\n"
164 "ldr za[%w[i], #9], [%[addr], #9, mul vl]" "\n"
165 "ldr za[%w[i], #10], [%[addr], #10, mul vl]" "\n"
166 "ldr za[%w[i], #11], [%[addr], #11, mul vl]" "\n"
167 "ldr za[%w[i], #12], [%[addr], #12, mul vl]" "\n"
168 "ldr za[%w[i], #13], [%[addr], #13, mul vl]" "\n"
169 "ldr za[%w[i], #14], [%[addr], #14, mul vl]" "\n"
170 "ldr za[%w[i], #15], [%[addr], #15, mul vl]" "\n"
171 :
172 : [i] "r"(i),
173 [addr] "r"(za + (i * svl_b))
174 );
175 }
176
177 asm volatile (
178 "ldr z0, [%[z], #0, mul vl]" "\n"
179 "ldr z1, [%[z], #1, mul vl]" "\n"
180 "ldr z2, [%[z], #2, mul vl]" "\n"
181 "ldr z3, [%[z], #3, mul vl]" "\n"
182 "ldr z4, [%[z], #4, mul vl]" "\n"
183 "ldr z5, [%[z], #5, mul vl]" "\n"
184 "ldr z6, [%[z], #6, mul vl]" "\n"
185 "ldr z7, [%[z], #7, mul vl]" "\n"
186 "ldr z8, [%[z], #8, mul vl]" "\n"
187 "ldr z9, [%[z], #9, mul vl]" "\n"
188 "ldr z10, [%[z], #10, mul vl]" "\n"
189 "ldr z11, [%[z], #11, mul vl]" "\n"
190 "ldr z12, [%[z], #12, mul vl]" "\n"
191 "ldr z13, [%[z], #13, mul vl]" "\n"
192 "ldr z14, [%[z], #14, mul vl]" "\n"
193 "ldr z15, [%[z], #15, mul vl]" "\n"
194 "ldr z16, [%[z], #16, mul vl]" "\n"
195 "ldr z17, [%[z], #17, mul vl]" "\n"
196 "ldr z18, [%[z], #18, mul vl]" "\n"
197 "ldr z19, [%[z], #19, mul vl]" "\n"
198 "ldr z20, [%[z], #20, mul vl]" "\n"
199 "ldr z21, [%[z], #21, mul vl]" "\n"
200 "ldr z22, [%[z], #22, mul vl]" "\n"
201 "ldr z23, [%[z], #23, mul vl]" "\n"
202 "ldr z24, [%[z], #24, mul vl]" "\n"
203 "ldr z25, [%[z], #25, mul vl]" "\n"
204 "ldr z26, [%[z], #26, mul vl]" "\n"
205 "ldr z27, [%[z], #27, mul vl]" "\n"
206 "ldr z28, [%[z], #28, mul vl]" "\n"
207 "ldr z29, [%[z], #29, mul vl]" "\n"
208 "ldr z30, [%[z], #30, mul vl]" "\n"
209 "ldr z31, [%[z], #31, mul vl]" "\n"
210 :
211 : [z] "r"(z)
212 );
213
214 asm volatile (
215 "ldr p0, [%[p], #0, mul vl]" "\n"
216 "ldr p1, [%[p], #1, mul vl]" "\n"
217 "ldr p2, [%[p], #2, mul vl]" "\n"
218 "ldr p3, [%[p], #3, mul vl]" "\n"
219 "ldr p4, [%[p], #4, mul vl]" "\n"
220 "ldr p5, [%[p], #5, mul vl]" "\n"
221 "ldr p6, [%[p], #6, mul vl]" "\n"
222 "ldr p7, [%[p], #7, mul vl]" "\n"
223 "ldr p8, [%[p], #8, mul vl]" "\n"
224 "ldr p9, [%[p], #9, mul vl]" "\n"
225 "ldr p10, [%[p], #10, mul vl]" "\n"
226 "ldr p11, [%[p], #11, mul vl]" "\n"
227 "ldr p12, [%[p], #12, mul vl]" "\n"
228 "ldr p13, [%[p], #13, mul vl]" "\n"
229 "ldr p14, [%[p], #14, mul vl]" "\n"
230 "ldr p15, [%[p], #15, mul vl]" "\n"
231 :
232 : [p] "r"(p)
233 );
234
235 if (sme_zt0_size()) {
236 const uint8_t *zt0 = p + sme_p_size();
237 asm volatile (
238 "ldr zt0, [%[zt0]]"
239 :
240 : [zt0] "r"(zt0)
241 );
242 }
243 }
244
245 static void
sme_store_data(void * addr)246 sme_store_data(void *addr)
247 {
248 uint8_t *za = addr;
249 uint8_t *z = za + sme_za_size();
250 uint8_t *p = z + sme_z_size();
251 uint16_t svl_b = arm_sme_svl_b();
252
253 for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
254 asm volatile (
255 "str za[%w[i], #0], [%[addr], #0, mul vl]" "\n"
256 "str za[%w[i], #1], [%[addr], #1, mul vl]" "\n"
257 "str za[%w[i], #2], [%[addr], #2, mul vl]" "\n"
258 "str za[%w[i], #3], [%[addr], #3, mul vl]" "\n"
259 "str za[%w[i], #4], [%[addr], #4, mul vl]" "\n"
260 "str za[%w[i], #5], [%[addr], #5, mul vl]" "\n"
261 "str za[%w[i], #6], [%[addr], #6, mul vl]" "\n"
262 "str za[%w[i], #7], [%[addr], #7, mul vl]" "\n"
263 "str za[%w[i], #8], [%[addr], #8, mul vl]" "\n"
264 "str za[%w[i], #9], [%[addr], #9, mul vl]" "\n"
265 "str za[%w[i], #10], [%[addr], #10, mul vl]" "\n"
266 "str za[%w[i], #11], [%[addr], #11, mul vl]" "\n"
267 "str za[%w[i], #12], [%[addr], #12, mul vl]" "\n"
268 "str za[%w[i], #13], [%[addr], #13, mul vl]" "\n"
269 "str za[%w[i], #14], [%[addr], #14, mul vl]" "\n"
270 "str za[%w[i], #15], [%[addr], #15, mul vl]" "\n"
271 :
272 : [i] "r"(i),
273 [addr] "r"(za + (i * svl_b))
274 );
275 }
276
277 asm volatile (
278 "str z0, [%[z], #0, mul vl]" "\n"
279 "str z1, [%[z], #1, mul vl]" "\n"
280 "str z2, [%[z], #2, mul vl]" "\n"
281 "str z3, [%[z], #3, mul vl]" "\n"
282 "str z4, [%[z], #4, mul vl]" "\n"
283 "str z5, [%[z], #5, mul vl]" "\n"
284 "str z6, [%[z], #6, mul vl]" "\n"
285 "str z7, [%[z], #7, mul vl]" "\n"
286 "str z8, [%[z], #8, mul vl]" "\n"
287 "str z9, [%[z], #9, mul vl]" "\n"
288 "str z10, [%[z], #10, mul vl]" "\n"
289 "str z11, [%[z], #11, mul vl]" "\n"
290 "str z12, [%[z], #12, mul vl]" "\n"
291 "str z13, [%[z], #13, mul vl]" "\n"
292 "str z14, [%[z], #14, mul vl]" "\n"
293 "str z15, [%[z], #15, mul vl]" "\n"
294 "str z16, [%[z], #16, mul vl]" "\n"
295 "str z17, [%[z], #17, mul vl]" "\n"
296 "str z18, [%[z], #18, mul vl]" "\n"
297 "str z19, [%[z], #19, mul vl]" "\n"
298 "str z20, [%[z], #20, mul vl]" "\n"
299 "str z21, [%[z], #21, mul vl]" "\n"
300 "str z22, [%[z], #22, mul vl]" "\n"
301 "str z23, [%[z], #23, mul vl]" "\n"
302 "str z24, [%[z], #24, mul vl]" "\n"
303 "str z25, [%[z], #25, mul vl]" "\n"
304 "str z26, [%[z], #26, mul vl]" "\n"
305 "str z27, [%[z], #27, mul vl]" "\n"
306 "str z28, [%[z], #28, mul vl]" "\n"
307 "str z29, [%[z], #29, mul vl]" "\n"
308 "str z30, [%[z], #30, mul vl]" "\n"
309 "str z31, [%[z], #31, mul vl]" "\n"
310 :
311 : [z] "r"(z)
312 );
313
314 asm volatile (
315 "str p0, [%[p], #0, mul vl]" "\n"
316 "str p1, [%[p], #1, mul vl]" "\n"
317 "str p2, [%[p], #2, mul vl]" "\n"
318 "str p3, [%[p], #3, mul vl]" "\n"
319 "str p4, [%[p], #4, mul vl]" "\n"
320 "str p5, [%[p], #5, mul vl]" "\n"
321 "str p6, [%[p], #6, mul vl]" "\n"
322 "str p7, [%[p], #7, mul vl]" "\n"
323 "str p8, [%[p], #8, mul vl]" "\n"
324 "str p9, [%[p], #9, mul vl]" "\n"
325 "str p10, [%[p], #10, mul vl]" "\n"
326 "str p11, [%[p], #11, mul vl]" "\n"
327 "str p12, [%[p], #12, mul vl]" "\n"
328 "str p13, [%[p], #13, mul vl]" "\n"
329 "str p14, [%[p], #14, mul vl]" "\n"
330 "str p15, [%[p], #15, mul vl]" "\n"
331 :
332 : [p] "r"(p)
333 );
334
335 if (sme_zt0_size()) {
336 uint8_t *zt0 = p + sme_p_size();
337 asm volatile (
338 "str zt0, [%[zt0]]"
339 :
340 : [zt0] "r"(zt0)
341 );
342 }
343 }
344
/*
 * SME implementation of the arm_matrix test operations table.
 * The buffer produced by alloc_data/data_size is consumed by
 * load_data/store_data; see arm_matrix.h for the interface contract.
 */
const struct arm_matrix_operations sme_operations = {
	.name = "SME",

	.data_size = sme_data_size,
	.alloc_data = sme_alloc_data,

	.is_available = sme_is_available,
	.start = sme_start,
	.stop = sme_stop,

	.load_one_vector = sme_load_one_vector,
	.load_data = sme_load_data,
	.store_data = sme_store_data,
};
359