1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Copyright (C) 2021 ARM Limited.
4 */
5
6 #include <errno.h>
7 #include <stdbool.h>
8 #include <stddef.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <unistd.h>
13 #include <sys/auxv.h>
14 #include <sys/prctl.h>
15 #include <asm/hwcap.h>
16 #include <asm/sigcontext.h>
17 #include <asm/unistd.h>
18
19 #include "../../kselftest.h"
20
21 #include "syscall-abi.h"
22
/* Default SME vector length; stays 0 when SME is unsupported */
static int default_sme_vl;

/* Vector lengths discovered by sve_count_vls()/sme_count_vls() below */
static int sve_vl_count;
static unsigned int sve_vls[SVE_VQ_MAX];
static int sme_vl_count;
static unsigned int sme_vls[SVE_VQ_MAX];

/*
 * Loads the register patterns set up below, performs the syscall and
 * captures the resulting register state into the *_out buffers.
 * NOTE(review): implemented outside this file (assembly helper,
 * presumably) — see syscall-abi.h.
 */
extern void do_syscall(int sve_vl, int sme_vl);
31
/*
 * Fill buf with pseudo-random data, one 32 bit word at a time.  Any
 * trailing bytes beyond the last whole uint32_t are left untouched;
 * all callers pass multiple-of-4 sizes.
 */
static void fill_random(void *buf, size_t size)
{
	size_t i;
	uint32_t *lbuf = buf;

	/* random() returns a 32 bit number regardless of the size of long */
	for (i = 0; i < size / sizeof(uint32_t); i++)
		lbuf[i] = random();
}
41
/*
 * We also repeat the test for several syscalls to try to expose different
 * behaviour.
 */
static struct syscall_cfg {
	int syscall_nr;		/* __NR_* value, placed in gpr_in[8] by setup_gpr() */
	const char *name;	/* human readable name used in test reports */
} syscalls[] = {
	{ __NR_getpid, "getpid()" },
	{ __NR_sched_yield, "sched_yield()" },
};
53
/* General purpose registers x0-x30 tracked by the test */
#define NUM_GPR 31
uint64_t gpr_in[NUM_GPR];	/* values loaded before the syscall */
uint64_t gpr_out[NUM_GPR];	/* values captured after the syscall */

/*
 * Randomise the GPR input state and clear the output buffer.  The
 * syscall number goes in gpr_in[8] since x8 carries the syscall
 * number in the arm64 syscall convention.
 */
static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	fill_random(gpr_in, sizeof(gpr_in));
	gpr_in[8] = cfg->syscall_nr;
	memset(gpr_out, 0, sizeof(gpr_out));
}
65
check_gpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)66 static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
67 {
68 int errors = 0;
69 int i;
70
71 /*
72 * GPR x0-x7 may be clobbered, and all others should be preserved.
73 */
74 for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
75 if (gpr_in[i] != gpr_out[i]) {
76 ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
77 cfg->name, sve_vl, i,
78 gpr_in[i], gpr_out[i]);
79 errors++;
80 }
81 }
82
83 return errors;
84 }
85
/* FPSIMD: 32 V registers of 128 bits, stored as pairs of uint64_t */
#define NUM_FPR 32
uint64_t fpr_in[NUM_FPR * 2];	/* values loaded before the syscall */
uint64_t fpr_out[NUM_FPR * 2];	/* values captured after the syscall */
uint64_t fpr_zero[NUM_FPR * 2];	/* all zeroes, memcmp() reference for check_fpr() */

/* Randomise the FPSIMD input state and clear the output buffer */
static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	fill_random(fpr_in, sizeof(fpr_in));
	memset(fpr_out, 0, sizeof(fpr_out));
}
97
check_fpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)98 static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
99 uint64_t svcr)
100 {
101 int errors = 0;
102 int i;
103
104 if (!sve_vl && !(svcr & SVCR_SM_MASK)) {
105 for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
106 if (fpr_in[i] != fpr_out[i]) {
107 ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
108 cfg->name,
109 i / 2, i % 2,
110 fpr_in[i], fpr_out[i]);
111 errors++;
112 }
113 }
114 }
115
116 /*
117 * In streaming mode the whole register set should be cleared
118 * by the transition out of streaming mode.
119 */
120 if (svcr & SVCR_SM_MASK) {
121 if (memcmp(fpr_zero, fpr_out, sizeof(fpr_out)) != 0) {
122 ksft_print_msg("%s FPSIMD registers non-zero exiting SM\n",
123 cfg->name);
124 errors++;
125 }
126 }
127
128 return errors;
129 }
130
/* Bytes of each Z register shared with the FPSIMD V registers (low 128 bits) */
#define SVE_Z_SHARED_BYTES (128 / 8)

static uint8_t z_zero[__SVE_ZREG_SIZE(SVE_VQ_MAX)];	/* all zeroes, memcmp() reference */
uint8_t z_in[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
uint8_t z_out[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];

/*
 * Randomise the Z register input state.  z_out is seeded with noise
 * too, presumably so stale data can't mask a failure to write back —
 * TODO confirm against do_syscall().
 */
static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	fill_random(z_in, sizeof(z_in));
	fill_random(z_out, sizeof(z_out));
}
143
check_z(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)144 static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
145 uint64_t svcr)
146 {
147 size_t reg_size = sve_vl;
148 int errors = 0;
149 int i;
150
151 if (!sve_vl)
152 return 0;
153
154 for (i = 0; i < SVE_NUM_ZREGS; i++) {
155 uint8_t *in = &z_in[reg_size * i];
156 uint8_t *out = &z_out[reg_size * i];
157
158 if (svcr & SVCR_SM_MASK) {
159 /*
160 * In streaming mode the whole register should
161 * be cleared by the transition out of
162 * streaming mode.
163 */
164 if (memcmp(z_zero, out, reg_size) != 0) {
165 ksft_print_msg("%s SVE VL %d Z%d non-zero\n",
166 cfg->name, sve_vl, i);
167 errors++;
168 }
169 } else {
170 /*
171 * For standard SVE the low 128 bits should be
172 * preserved and any additional bits cleared.
173 */
174 if (memcmp(in, out, SVE_Z_SHARED_BYTES) != 0) {
175 ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
176 cfg->name, sve_vl, i);
177 errors++;
178 }
179
180 if (reg_size > SVE_Z_SHARED_BYTES &&
181 (memcmp(z_zero, out + SVE_Z_SHARED_BYTES,
182 reg_size - SVE_Z_SHARED_BYTES) != 0)) {
183 ksft_print_msg("%s SVE VL %d Z%d high bits non-zero\n",
184 cfg->name, sve_vl, i);
185 errors++;
186 }
187 }
188 }
189
190 return errors;
191 }
192
uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];	/* P values loaded before the syscall */
uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];	/* P values captured after the syscall */

/*
 * Randomise the predicate register state.  Both buffers get noise:
 * the syscall is expected to zero the predicates, so a randomised
 * p_out can't accidentally pass check_p().
 */
static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	fill_random(p_in, sizeof(p_in));
	fill_random(p_out, sizeof(p_out));
}
202
check_p(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)203 static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
204 uint64_t svcr)
205 {
206 size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
207
208 int errors = 0;
209 int i;
210
211 if (!sve_vl)
212 return 0;
213
214 /* After a syscall the P registers should be zeroed */
215 for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
216 if (p_out[i])
217 errors++;
218 if (errors)
219 ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
220 cfg->name, sve_vl);
221
222 return errors;
223 }
224
uint8_t ffr_in[__SVE_PREG_SIZE(SVE_VQ_MAX)];	/* FFR value loaded before the syscall */
uint8_t ffr_out[__SVE_PREG_SIZE(SVE_VQ_MAX)];	/* FFR value captured after the syscall */

/*
 * Prepare the FFR test pattern: all bits set when FFR is accessible,
 * all zeroes when it is not (streaming mode without FA64).
 */
static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	/*
	 * If we are in streaming mode and do not have FA64 then FFR
	 * is unavailable.
	 */
	if ((svcr & SVCR_SM_MASK) &&
	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
		memset(&ffr_in, 0, sizeof(ffr_in));
		return;
	}

	/*
	 * It is only valid to set a contiguous set of bits starting
	 * at 0. For now since we're expecting this to be cleared by
	 * a syscall just set all bits.
	 */
	memset(ffr_in, 0xff, sizeof(ffr_in));
	fill_random(ffr_out, sizeof(ffr_out));
}
249
check_ffr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)250 static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
251 uint64_t svcr)
252 {
253 size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
254 int errors = 0;
255 int i;
256
257 if (!sve_vl)
258 return 0;
259
260 if ((svcr & SVCR_SM_MASK) &&
261 !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
262 return 0;
263
264 /* After a syscall FFR should be zeroed */
265 for (i = 0; i < reg_size; i++)
266 if (ffr_out[i])
267 errors++;
268 if (errors)
269 ksft_print_msg("%s SVE VL %d FFR non-zero\n",
270 cfg->name, sve_vl);
271
272 return errors;
273 }
274
uint64_t svcr_in, svcr_out;	/* SVCR value before/after the syscall */

/* Record the SVCR value the test enters the syscall with */
static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		       uint64_t svcr)
{
	svcr_in = svcr;
}
282
check_svcr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)283 static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
284 uint64_t svcr)
285 {
286 int errors = 0;
287
288 if (svcr_out & SVCR_SM_MASK) {
289 ksft_print_msg("%s Still in SM, SVCR %llx\n",
290 cfg->name, svcr_out);
291 errors++;
292 }
293
294 if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
295 ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
296 cfg->name, svcr_in, svcr_out);
297 errors++;
298 }
299
300 return errors;
301 }
302
uint8_t za_in[ZA_SIG_REGS_SIZE(SVE_VQ_MAX)];	/* ZA contents loaded before the syscall */
uint8_t za_out[ZA_SIG_REGS_SIZE(SVE_VQ_MAX)];	/* ZA contents captured after the syscall */

/* Randomise the ZA input state and clear the output buffer */
static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		     uint64_t svcr)
{
	fill_random(za_in, sizeof(za_in));
	memset(za_out, 0, sizeof(za_out));
}
312
check_za(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)313 static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
314 uint64_t svcr)
315 {
316 size_t reg_size = sme_vl * sme_vl;
317 int errors = 0;
318
319 if (!(svcr & SVCR_ZA_MASK))
320 return 0;
321
322 if (memcmp(za_in, za_out, reg_size) != 0) {
323 ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
324 errors++;
325 }
326
327 return errors;
328 }
329
uint8_t zt_in[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));	/* ZT0 loaded before the syscall */
uint8_t zt_out[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));	/* ZT0 captured after the syscall */

/* Randomise the ZT0 input state and clear the output buffer */
static void setup_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		     uint64_t svcr)
{
	fill_random(zt_in, sizeof(zt_in));
	memset(zt_out, 0, sizeof(zt_out));
}
339
check_zt(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)340 static int check_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
341 uint64_t svcr)
342 {
343 int errors = 0;
344
345 if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2))
346 return 0;
347
348 if (!(svcr & SVCR_ZA_MASK))
349 return 0;
350
351 if (memcmp(zt_in, zt_out, sizeof(zt_in)) != 0) {
352 ksft_print_msg("SME VL %d ZT does not match\n", sme_vl);
353 errors++;
354 }
355
356 return errors;
357 }
358
/* Hook run before the syscall to load a register set's test pattern */
typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
			 uint64_t svcr);
/* Hook run after the syscall; returns the number of errors found */
typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
			uint64_t svcr);

/*
 * Each set of registers has a setup function which is called before
 * the syscall to fill values in a global variable for loading by the
 * test code and a check function which validates that the results are
 * as expected. Vector lengths are passed everywhere, a vector length
 * of 0 should be treated as do not test.
 */
static struct {
	setup_fn setup;
	check_fn check;
} regset[] = {
	{ setup_gpr, check_gpr },
	{ setup_fpr, check_fpr },
	{ setup_z, check_z },
	{ setup_p, check_p },
	{ setup_ffr, check_ffr },
	{ setup_svcr, check_svcr },
	{ setup_za, check_za },
	{ setup_zt, check_zt },
};
384
do_test(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)385 static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
386 uint64_t svcr)
387 {
388 int errors = 0;
389 int i;
390
391 for (i = 0; i < ARRAY_SIZE(regset); i++)
392 regset[i].setup(cfg, sve_vl, sme_vl, svcr);
393
394 do_syscall(sve_vl, sme_vl);
395
396 for (i = 0; i < ARRAY_SIZE(regset); i++)
397 errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);
398
399 return errors == 0;
400 }
401
/*
 * Run the full set of configurations for one syscall: FPSIMD only,
 * each SVE VL alone and combined with each SME VL in SM+ZA/SM/ZA
 * modes, then each SME VL without SVE in the same three modes.  The
 * number of ksft_test_result() calls here must match the plan
 * arithmetic in main().
 */
static void test_one_syscall(struct syscall_cfg *cfg)
{
	int sve, sme;
	int ret;

	/* FPSIMD only case */
	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
			 "%s FPSIMD\n", cfg->name);

	for (sve = 0; sve < sve_vl_count; sve++) {
		/* Switch the task to this SVE VL before testing it */
		ret = prctl(PR_SVE_SET_VL, sve_vls[sve]);
		if (ret == -1)
			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
					   strerror(errno), errno);

		ksft_test_result(do_test(cfg, sve_vls[sve], default_sme_vl, 0),
				 "%s SVE VL %d\n", cfg->name, sve_vls[sve]);

		for (sme = 0; sme < sme_vl_count; sme++) {
			ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
			if (ret == -1)
				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
						   strerror(errno), errno);

			/* SVE + SME with streaming mode and ZA both enabled */
			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme],
						 SVCR_ZA_MASK | SVCR_SM_MASK),
					 "%s SVE VL %d/SME VL %d SM+ZA\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
			/* Streaming mode only */
			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme], SVCR_SM_MASK),
					 "%s SVE VL %d/SME VL %d SM\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
			/* ZA only */
			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme], SVCR_ZA_MASK),
					 "%s SVE VL %d/SME VL %d ZA\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
		}
	}

	/* SME without SVE (sve_vl == 0 disables the SVE checks) */
	for (sme = 0; sme < sme_vl_count; sme++) {
		ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
		if (ret == -1)
			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
					   strerror(errno), errno);

		ksft_test_result(do_test(cfg, 0, sme_vls[sme],
					 SVCR_ZA_MASK | SVCR_SM_MASK),
				 "%s SME VL %d SM+ZA\n",
				 cfg->name, sme_vls[sme]);
		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_SM_MASK),
				 "%s SME VL %d SM\n",
				 cfg->name, sme_vls[sme]);
		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_ZA_MASK),
				 "%s SME VL %d ZA\n",
				 cfg->name, sme_vls[sme]);
	}
}
463
sve_count_vls(void)464 void sve_count_vls(void)
465 {
466 unsigned int vq;
467 int vl;
468
469 if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
470 return;
471
472 /*
473 * Enumerate up to SVE_VQ_MAX vector lengths
474 */
475 for (vq = SVE_VQ_MAX; vq > 0; vq /= 2) {
476 vl = prctl(PR_SVE_SET_VL, vq * 16);
477 if (vl == -1)
478 ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
479 strerror(errno), errno);
480
481 vl &= PR_SVE_VL_LEN_MASK;
482
483 if (vq != sve_vq_from_vl(vl))
484 vq = sve_vq_from_vl(vl);
485
486 sve_vls[sve_vl_count++] = vl;
487 }
488 }
489
sme_count_vls(void)490 void sme_count_vls(void)
491 {
492 unsigned int vq;
493 int vl;
494
495 if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
496 return;
497
498 /*
499 * Enumerate up to SVE_VQ_MAX vector lengths
500 */
501 for (vq = SVE_VQ_MAX; vq > 0; vq /= 2) {
502 vl = prctl(PR_SME_SET_VL, vq * 16);
503 if (vl == -1)
504 ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
505 strerror(errno), errno);
506
507 vl &= PR_SME_VL_LEN_MASK;
508
509 /* Found lowest VL */
510 if (sve_vq_from_vl(vl) > vq)
511 break;
512
513 if (vq != sve_vq_from_vl(vl))
514 vq = sve_vq_from_vl(vl);
515
516 sme_vls[sme_vl_count++] = vl;
517 }
518
519 /* Ensure we configure a SME VL, used to flag if SVCR is set */
520 default_sme_vl = sme_vls[0];
521 }
522
main(void)523 int main(void)
524 {
525 int i;
526 int tests = 1; /* FPSIMD */
527 int sme_ver;
528
529 srandom(getpid());
530
531 ksft_print_header();
532
533 sve_count_vls();
534 sme_count_vls();
535
536 tests += sve_vl_count;
537 tests += sme_vl_count * 3;
538 tests += (sve_vl_count * sme_vl_count) * 3;
539 ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
540
541 if (getauxval(AT_HWCAP2) & HWCAP2_SME2)
542 sme_ver = 2;
543 else
544 sme_ver = 1;
545
546 if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
547 ksft_print_msg("SME%d with FA64\n", sme_ver);
548 else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
549 ksft_print_msg("SME%d without FA64\n", sme_ver);
550
551 for (i = 0; i < ARRAY_SIZE(syscalls); i++)
552 test_one_syscall(&syscalls[i]);
553
554 ksft_print_cnts();
555
556 return 0;
557 }
558