xref: /xnu-11215.61.5/tests/arm_matrix_sme.c (revision 4f1223e81cd707a65cc109d0b8ad6653699da3c4)
1 /*
2  * Copyright (c) 2022 Apple Computer, Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29 #include <stdint.h>
30 #include <stdlib.h>
31 #include <sys/sysctl.h>
32 
33 #include "arm_matrix.h"
34 
/* Number of Z (vector) registers saved and restored by this file: z0-z31. */
static const unsigned int SME_Z_VECTORS = 32;
/* Number of P (predicate) registers saved and restored by this file: p0-p15. */
static const unsigned int SME_P_VECTORS = 16;
37 
/*
 * Read one unsigned-int feature flag via sysctlbyname().
 *
 * Returns the value reported by the kernel, or 0 when the lookup fails
 * (for example because the running system does not publish that key), so
 * that an unknown feature is always treated as absent.
 */
static unsigned int
sme_sysctl_flag(const char *name)
{
	unsigned int value = 0;
	size_t size = sizeof(value);

	if (sysctlbyname(name, &value, &size, NULL, 0) != 0) {
		return 0;
	}
	return value;
}

/*
 * Report which SME revision the CPU advertises:
 *   2 - hw.optional.arm.FEAT_SME2 is set
 *   1 - only hw.optional.arm.FEAT_SME is set
 *   0 - neither is set, or the flags could not be read
 *
 * The sysctls are queried on the first call only; the result is cached in
 * function-local statics. The cache is not synchronized, so concurrent
 * first calls may each perform the query.
 */
static unsigned int
sme_version(void)
{
	static unsigned int ret = 0;
	static bool already_read = false;

	if (!already_read) {
		if (sme_sysctl_flag("hw.optional.arm.FEAT_SME2")) {
			ret = 2;
		} else if (sme_sysctl_flag("hw.optional.arm.FEAT_SME")) {
			ret = 1;
		} else {
			ret = 0;
		}

		already_read = true;
	}

	return ret;
}
63 
/*
 * Return the streaming vector length in bytes, as reported by the RDSVL
 * instruction with a multiplier of 1. The 64-bit register result is
 * truncated to 16 bits.
 */
static uint16_t
arm_sme_svl_b(void)
{
	uint64_t svl_bytes;

	__asm__ volatile (
		"rdsvl	%[ret], #1"
		: [ret] "=r"(svl_bytes)
	);

	return (uint16_t)svl_bytes;
}
74 
/*
 * Size in bytes of the ZA save area: svl_b vectors of svl_b bytes each,
 * where svl_b is the value returned by arm_sme_svl_b().
 */
static size_t
sme_za_size(void)
{
	const uint16_t svl_b = arm_sme_svl_b();

	return svl_b * svl_b;
}
80 
81 static size_t
sme_z_size(void)82 sme_z_size(void)
83 {
84 	return arm_sme_svl_b() * SME_Z_VECTORS;
85 }
86 
87 static size_t
sme_p_size(void)88 sme_p_size(void)
89 {
90 	return arm_sme_svl_b() * SME_P_VECTORS / 8;
91 }
92 
/*
 * Size in bytes of the ZT0 save area: 64 bytes when sme_version() reports
 * revision 2 or later, otherwise 0 (no ZT0 state is saved).
 */
static size_t
sme_zt0_size(void)
{
	return (sme_version() >= 2) ? 64 : 0;
}
102 
/*
 * Total size in bytes of one complete save buffer. The buffer holds, in
 * order, the ZA area, the Z registers, the P registers and (when present)
 * ZT0.
 */
static size_t
sme_data_size(void)
{
	size_t total = sme_za_size();

	total += sme_z_size();
	total += sme_p_size();
	total += sme_zt0_size();

	return total;
}
108 
/*
 * Allocate an uninitialized buffer of sme_data_size() bytes with malloc().
 * Returns NULL if the allocation fails; the caller owns the buffer and
 * releases it with free().
 */
static void *
sme_alloc_data(void)
{
	const size_t nbytes = sme_data_size();

	return malloc(nbytes);
}
114 
/*
 * Report whether any SME revision is advertised, i.e. sme_version() is
 * non-zero.
 */
static bool
sme_is_available(void)
{
	return sme_version() != 0;
}
120 
/*
 * Execute SMSTART (no operand) on the calling thread. Per the Arm ISA this
 * enters streaming SVE mode and enables ZA storage.
 */
static void
sme_start(void)
{
	__asm__ volatile ("smstart");
}
126 
/*
 * Execute SMSTOP (no operand) on the calling thread. Per the Arm ISA this
 * leaves streaming SVE mode and disables ZA storage.
 */
static void
sme_stop(void)
{
	__asm__ volatile ("smstop");
}
132 
/*
 * Load a single ZA array vector (vector 0) from memory.
 *
 * addr: source buffer; the LDR reads one vector's worth of bytes from it.
 *
 * NOTE(review): presumably the caller must have executed sme_start() first
 * so that ZA is accessible - confirm against callers.
 */
static void
sme_load_one_vector(const void *addr)
{
	__asm__ volatile (
		"mov    w12, #0"                "\n"
		"ldr    za[w12, #0], [%[addr]]" "\n"
		:
		: [addr] "r"(addr)
		/*
		 * w12 is used as the vector-select register. "memory" is
		 * required because the asm reads the buffer behind addr, which
		 * is not expressed through any operand.
		 */
		: "w12", "memory"
	);
}
144 
/*
 * Load the complete SME register state from a save buffer.
 *
 * addr: buffer of at least sme_data_size() bytes laid out as
 *         [ ZA area | z0-z31 | p0-p15 | ZT0 (only if sme_zt0_size() != 0) ]
 *       which is the layout written by sme_store_data().
 *
 * Every asm statement carries a "memory" clobber because it reads the buffer
 * through a pointer operand; without it the compiler would be free to leave
 * earlier stores to the buffer pending when the loads execute.
 *
 * NOTE(review): the ZA/Z/P/ZT0 registers written here are not listed as
 * clobbers. That looks intentional (the loaded state is meant to stay live
 * after this function returns) - confirm before changing.
 */
static void
sme_load_data(const void *addr)
{
	const uint8_t *za = addr;
	const uint8_t *z = za + sme_za_size();
	const uint8_t *p = z + sme_z_size();
	uint16_t svl_b = arm_sme_svl_b();

	/*
	 * ZA is loaded 16 vectors per iteration: w12 holds the index of the
	 * first vector in the group and the immediate selects vectors
	 * w12+0 .. w12+15. The same immediate, scaled by the vector length,
	 * offsets the memory address, so the base pointer only has to advance
	 * to the start of each group (i * svl_b bytes into the ZA area).
	 * The index variable is pinned to w12 for use as the vector-select
	 * register.
	 */
	for (register uint16_t i __asm__("w12") = 0; i < svl_b; i += 16) {
		__asm__ volatile (
			"ldr    za[%w[i],  #0], [%[addr],  #0, mul vl]"   "\n"
			"ldr    za[%w[i],  #1], [%[addr],  #1, mul vl]"   "\n"
			"ldr    za[%w[i],  #2], [%[addr],  #2, mul vl]"   "\n"
			"ldr    za[%w[i],  #3], [%[addr],  #3, mul vl]"   "\n"
			"ldr    za[%w[i],  #4], [%[addr],  #4, mul vl]"   "\n"
			"ldr    za[%w[i],  #5], [%[addr],  #5, mul vl]"   "\n"
			"ldr    za[%w[i],  #6], [%[addr],  #6, mul vl]"   "\n"
			"ldr    za[%w[i],  #7], [%[addr],  #7, mul vl]"   "\n"
			"ldr    za[%w[i],  #8], [%[addr],  #8, mul vl]"   "\n"
			"ldr    za[%w[i],  #9], [%[addr],  #9, mul vl]"   "\n"
			"ldr    za[%w[i], #10], [%[addr], #10, mul vl]"   "\n"
			"ldr    za[%w[i], #11], [%[addr], #11, mul vl]"   "\n"
			"ldr    za[%w[i], #12], [%[addr], #12, mul vl]"   "\n"
			"ldr    za[%w[i], #13], [%[addr], #13, mul vl]"   "\n"
			"ldr    za[%w[i], #14], [%[addr], #14, mul vl]"   "\n"
			"ldr    za[%w[i], #15], [%[addr], #15, mul vl]"   "\n"
			:
			: [i] "r"(i),
			  [addr] "r"(za + (i * svl_b))
			: "memory"
		);
	}

	/* z0-z31: consecutive vectors starting at the Z area. */
	__asm__ volatile (
		"ldr    z0, [%[z],   #0, mul vl]"        "\n"
		"ldr    z1, [%[z],   #1, mul vl]"        "\n"
		"ldr    z2, [%[z],   #2, mul vl]"        "\n"
		"ldr    z3, [%[z],   #3, mul vl]"        "\n"
		"ldr    z4, [%[z],   #4, mul vl]"        "\n"
		"ldr    z5, [%[z],   #5, mul vl]"        "\n"
		"ldr    z6, [%[z],   #6, mul vl]"        "\n"
		"ldr    z7, [%[z],   #7, mul vl]"        "\n"
		"ldr    z8, [%[z],   #8, mul vl]"        "\n"
		"ldr    z9, [%[z],   #9, mul vl]"        "\n"
		"ldr   z10, [%[z],  #10, mul vl]"        "\n"
		"ldr   z11, [%[z],  #11, mul vl]"        "\n"
		"ldr   z12, [%[z],  #12, mul vl]"        "\n"
		"ldr   z13, [%[z],  #13, mul vl]"        "\n"
		"ldr   z14, [%[z],  #14, mul vl]"        "\n"
		"ldr   z15, [%[z],  #15, mul vl]"        "\n"
		"ldr   z16, [%[z],  #16, mul vl]"        "\n"
		"ldr   z17, [%[z],  #17, mul vl]"        "\n"
		"ldr   z18, [%[z],  #18, mul vl]"        "\n"
		"ldr   z19, [%[z],  #19, mul vl]"        "\n"
		"ldr   z20, [%[z],  #20, mul vl]"        "\n"
		"ldr   z21, [%[z],  #21, mul vl]"        "\n"
		"ldr   z22, [%[z],  #22, mul vl]"        "\n"
		"ldr   z23, [%[z],  #23, mul vl]"        "\n"
		"ldr   z24, [%[z],  #24, mul vl]"        "\n"
		"ldr   z25, [%[z],  #25, mul vl]"        "\n"
		"ldr   z26, [%[z],  #26, mul vl]"        "\n"
		"ldr   z27, [%[z],  #27, mul vl]"        "\n"
		"ldr   z28, [%[z],  #28, mul vl]"        "\n"
		"ldr   z29, [%[z],  #29, mul vl]"        "\n"
		"ldr   z30, [%[z],  #30, mul vl]"        "\n"
		"ldr   z31, [%[z],  #31, mul vl]"        "\n"
		:
		: [z] "r"(z)
		: "memory"
	);

	/* p0-p15: consecutive predicates starting at the P area. */
	__asm__ volatile (
		"ldr     p0, [%[p],  #0, mul vl]"        "\n"
		"ldr     p1, [%[p],  #1, mul vl]"        "\n"
		"ldr     p2, [%[p],  #2, mul vl]"        "\n"
		"ldr     p3, [%[p],  #3, mul vl]"        "\n"
		"ldr     p4, [%[p],  #4, mul vl]"        "\n"
		"ldr     p5, [%[p],  #5, mul vl]"        "\n"
		"ldr     p6, [%[p],  #6, mul vl]"        "\n"
		"ldr     p7, [%[p],  #7, mul vl]"        "\n"
		"ldr     p8, [%[p],  #8, mul vl]"        "\n"
		"ldr     p9, [%[p],  #9, mul vl]"        "\n"
		"ldr    p10, [%[p], #10, mul vl]"        "\n"
		"ldr    p11, [%[p], #11, mul vl]"        "\n"
		"ldr    p12, [%[p], #12, mul vl]"        "\n"
		"ldr    p13, [%[p], #13, mul vl]"        "\n"
		"ldr    p14, [%[p], #14, mul vl]"        "\n"
		"ldr    p15, [%[p], #15, mul vl]"        "\n"
		:
		: [p] "r"(p)
		: "memory"
	);

	/* ZT0 follows the P area and is only present when its size is non-zero. */
	if (sme_zt0_size()) {
		const uint8_t *zt0 = p + sme_p_size();
		__asm__ volatile (
			"ldr	zt0, [%[zt0]]"
			:
			: [zt0] "r"(zt0)
			: "memory"
		);
	}
}
244 
/*
 * Store the complete SME register state into a save buffer.
 *
 * addr: buffer of at least sme_data_size() bytes. It is filled as
 *         [ ZA area | z0-z31 | p0-p15 | ZT0 (only if sme_zt0_size() != 0) ]
 *       which is the layout read back by sme_load_data().
 *
 * Every asm statement carries a "memory" clobber because it writes the
 * buffer through a pointer operand; without it the compiler could assume
 * the buffer contents are unchanged by the asm.
 */
static void
sme_store_data(void *addr)
{
	uint8_t *za = addr;
	uint8_t *z = za + sme_za_size();
	uint8_t *p = z + sme_z_size();
	uint16_t svl_b = arm_sme_svl_b();

	/*
	 * ZA is stored 16 vectors per iteration: w12 holds the index of the
	 * first vector in the group and the immediate selects vectors
	 * w12+0 .. w12+15. The same immediate, scaled by the vector length,
	 * offsets the memory address, so the base pointer only has to advance
	 * to the start of each group (i * svl_b bytes into the ZA area).
	 * The index variable is pinned to w12 for use as the vector-select
	 * register.
	 */
	for (register uint16_t i __asm__("w12") = 0; i < svl_b; i += 16) {
		__asm__ volatile (
			"str    za[%w[i],  #0], [%[addr],  #0, mul vl]"   "\n"
			"str    za[%w[i],  #1], [%[addr],  #1, mul vl]"   "\n"
			"str    za[%w[i],  #2], [%[addr],  #2, mul vl]"   "\n"
			"str    za[%w[i],  #3], [%[addr],  #3, mul vl]"   "\n"
			"str    za[%w[i],  #4], [%[addr],  #4, mul vl]"   "\n"
			"str    za[%w[i],  #5], [%[addr],  #5, mul vl]"   "\n"
			"str    za[%w[i],  #6], [%[addr],  #6, mul vl]"   "\n"
			"str    za[%w[i],  #7], [%[addr],  #7, mul vl]"   "\n"
			"str    za[%w[i],  #8], [%[addr],  #8, mul vl]"   "\n"
			"str    za[%w[i],  #9], [%[addr],  #9, mul vl]"   "\n"
			"str    za[%w[i], #10], [%[addr], #10, mul vl]"   "\n"
			"str    za[%w[i], #11], [%[addr], #11, mul vl]"   "\n"
			"str    za[%w[i], #12], [%[addr], #12, mul vl]"   "\n"
			"str    za[%w[i], #13], [%[addr], #13, mul vl]"   "\n"
			"str    za[%w[i], #14], [%[addr], #14, mul vl]"   "\n"
			"str    za[%w[i], #15], [%[addr], #15, mul vl]"   "\n"
			:
			: [i] "r"(i),
			  [addr] "r"(za + (i * svl_b))
			: "memory"
		);
	}

	/* z0-z31: consecutive vectors starting at the Z area. */
	__asm__ volatile (
		"str    z0, [%[z],   #0, mul vl]"        "\n"
		"str    z1, [%[z],   #1, mul vl]"        "\n"
		"str    z2, [%[z],   #2, mul vl]"        "\n"
		"str    z3, [%[z],   #3, mul vl]"        "\n"
		"str    z4, [%[z],   #4, mul vl]"        "\n"
		"str    z5, [%[z],   #5, mul vl]"        "\n"
		"str    z6, [%[z],   #6, mul vl]"        "\n"
		"str    z7, [%[z],   #7, mul vl]"        "\n"
		"str    z8, [%[z],   #8, mul vl]"        "\n"
		"str    z9, [%[z],   #9, mul vl]"        "\n"
		"str   z10, [%[z],  #10, mul vl]"        "\n"
		"str   z11, [%[z],  #11, mul vl]"        "\n"
		"str   z12, [%[z],  #12, mul vl]"        "\n"
		"str   z13, [%[z],  #13, mul vl]"        "\n"
		"str   z14, [%[z],  #14, mul vl]"        "\n"
		"str   z15, [%[z],  #15, mul vl]"        "\n"
		"str   z16, [%[z],  #16, mul vl]"        "\n"
		"str   z17, [%[z],  #17, mul vl]"        "\n"
		"str   z18, [%[z],  #18, mul vl]"        "\n"
		"str   z19, [%[z],  #19, mul vl]"        "\n"
		"str   z20, [%[z],  #20, mul vl]"        "\n"
		"str   z21, [%[z],  #21, mul vl]"        "\n"
		"str   z22, [%[z],  #22, mul vl]"        "\n"
		"str   z23, [%[z],  #23, mul vl]"        "\n"
		"str   z24, [%[z],  #24, mul vl]"        "\n"
		"str   z25, [%[z],  #25, mul vl]"        "\n"
		"str   z26, [%[z],  #26, mul vl]"        "\n"
		"str   z27, [%[z],  #27, mul vl]"        "\n"
		"str   z28, [%[z],  #28, mul vl]"        "\n"
		"str   z29, [%[z],  #29, mul vl]"        "\n"
		"str   z30, [%[z],  #30, mul vl]"        "\n"
		"str   z31, [%[z],  #31, mul vl]"        "\n"
		:
		: [z] "r"(z)
		: "memory"
	);

	/* p0-p15: consecutive predicates starting at the P area. */
	__asm__ volatile (
		"str     p0, [%[p],  #0, mul vl]"        "\n"
		"str     p1, [%[p],  #1, mul vl]"        "\n"
		"str     p2, [%[p],  #2, mul vl]"        "\n"
		"str     p3, [%[p],  #3, mul vl]"        "\n"
		"str     p4, [%[p],  #4, mul vl]"        "\n"
		"str     p5, [%[p],  #5, mul vl]"        "\n"
		"str     p6, [%[p],  #6, mul vl]"        "\n"
		"str     p7, [%[p],  #7, mul vl]"        "\n"
		"str     p8, [%[p],  #8, mul vl]"        "\n"
		"str     p9, [%[p],  #9, mul vl]"        "\n"
		"str    p10, [%[p], #10, mul vl]"        "\n"
		"str    p11, [%[p], #11, mul vl]"        "\n"
		"str    p12, [%[p], #12, mul vl]"        "\n"
		"str    p13, [%[p], #13, mul vl]"        "\n"
		"str    p14, [%[p], #14, mul vl]"        "\n"
		"str    p15, [%[p], #15, mul vl]"        "\n"
		:
		: [p] "r"(p)
		: "memory"
	);

	/* ZT0 follows the P area and is only present when its size is non-zero. */
	if (sme_zt0_size()) {
		uint8_t *zt0 = p + sme_p_size();
		__asm__ volatile (
			"str	zt0, [%[zt0]]"
			:
			: [zt0] "r"(zt0)
			: "memory"
		);
	}
}
344 
345 const struct arm_matrix_operations sme_operations = {
346 	.name = "SME",
347 
348 	.data_size = sme_data_size,
349 	.alloc_data = sme_alloc_data,
350 
351 	.is_available = sme_is_available,
352 	.start = sme_start,
353 	.stop = sme_stop,
354 
355 	.load_one_vector = sme_load_one_vector,
356 	.load_data = sme_load_data,
357 	.store_data = sme_store_data,
358 };
359