1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (C) 2020 ARM Limited
3
4 #define _GNU_SOURCE
5
6 #include <errno.h>
7 #include <pthread.h>
8 #include <stdint.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <time.h>
12 #include <unistd.h>
13 #include <sys/auxv.h>
14 #include <sys/mman.h>
15 #include <sys/prctl.h>
16 #include <sys/types.h>
17 #include <sys/wait.h>
18
19 #include "kselftest.h"
20 #include "mte_common_util.h"
21
22 #define PR_SET_TAGGED_ADDR_CTRL 55
23 #define PR_GET_TAGGED_ADDR_CTRL 56
24 # define PR_TAGGED_ADDR_ENABLE (1UL << 0)
25 # define PR_MTE_TCF_SHIFT 1
26 # define PR_MTE_TCF_NONE (0UL << PR_MTE_TCF_SHIFT)
27 # define PR_MTE_TCF_SYNC (1UL << PR_MTE_TCF_SHIFT)
28 # define PR_MTE_TCF_ASYNC (2UL << PR_MTE_TCF_SHIFT)
29 # define PR_MTE_TCF_MASK (3UL << PR_MTE_TCF_SHIFT)
30 # define PR_MTE_TAG_SHIFT 3
31 # define PR_MTE_TAG_MASK (0xffffUL << PR_MTE_TAG_SHIFT)
32
33 #include "mte_def.h"
34
35 #define NUM_ITERATIONS 1024
36 #define MAX_THREADS 5
37 #define THREAD_ITERATIONS 1000
38
execute_thread(void * x)39 void *execute_thread(void *x)
40 {
41 pid_t pid = *((pid_t *)x);
42 pid_t tid = gettid();
43 uint64_t prctl_tag_mask;
44 uint64_t prctl_set;
45 uint64_t prctl_get;
46 uint64_t prctl_tcf;
47
48 srand(time(NULL) ^ (pid << 16) ^ (tid << 16));
49
50 prctl_tag_mask = rand() & 0xffff;
51
52 if (prctl_tag_mask % 2)
53 prctl_tcf = PR_MTE_TCF_SYNC;
54 else
55 prctl_tcf = PR_MTE_TCF_ASYNC;
56
57 prctl_set = PR_TAGGED_ADDR_ENABLE | prctl_tcf | (prctl_tag_mask << PR_MTE_TAG_SHIFT);
58
59 for (int j = 0; j < THREAD_ITERATIONS; j++) {
60 if (prctl(PR_SET_TAGGED_ADDR_CTRL, prctl_set, 0, 0, 0)) {
61 perror("prctl() failed");
62 goto fail;
63 }
64
65 prctl_get = prctl(PR_GET_TAGGED_ADDR_CTRL, 0, 0, 0, 0);
66
67 if (prctl_set != prctl_get) {
68 ksft_print_msg("Error: prctl_set: 0x%lx != prctl_get: 0x%lx\n",
69 prctl_set, prctl_get);
70 goto fail;
71 }
72 }
73
74 return (void *)KSFT_PASS;
75
76 fail:
77 return (void *)KSFT_FAIL;
78 }
79
execute_test(pid_t pid)80 int execute_test(pid_t pid)
81 {
82 pthread_t thread_id[MAX_THREADS];
83 int thread_data[MAX_THREADS];
84
85 for (int i = 0; i < MAX_THREADS; i++)
86 pthread_create(&thread_id[i], NULL,
87 execute_thread, (void *)&pid);
88
89 for (int i = 0; i < MAX_THREADS; i++)
90 pthread_join(thread_id[i], (void *)&thread_data[i]);
91
92 for (int i = 0; i < MAX_THREADS; i++)
93 if (thread_data[i] == KSFT_FAIL)
94 return KSFT_FAIL;
95
96 return KSFT_PASS;
97 }
98
mte_gcr_fork_test(void)99 int mte_gcr_fork_test(void)
100 {
101 pid_t pid;
102 int results[NUM_ITERATIONS];
103 pid_t cpid;
104 int res;
105
106 for (int i = 0; i < NUM_ITERATIONS; i++) {
107 pid = fork();
108
109 if (pid < 0)
110 return KSFT_FAIL;
111
112 if (pid == 0) {
113 cpid = getpid();
114
115 res = execute_test(cpid);
116
117 exit(res);
118 }
119 }
120
121 for (int i = 0; i < NUM_ITERATIONS; i++) {
122 wait(&res);
123
124 if (WIFEXITED(res))
125 results[i] = WEXITSTATUS(res);
126 else
127 --i;
128 }
129
130 for (int i = 0; i < NUM_ITERATIONS; i++)
131 if (results[i] == KSFT_FAIL)
132 return KSFT_FAIL;
133
134 return KSFT_PASS;
135 }
136
main(int argc,char * argv[])137 int main(int argc, char *argv[])
138 {
139 int err;
140
141 err = mte_default_setup();
142 if (err)
143 return err;
144
145 ksft_set_plan(1);
146
147 evaluate_test(mte_gcr_fork_test(),
148 "Verify that GCR_EL1 is set correctly on context switch\n");
149
150 mte_restore_setup();
151 ksft_print_cnts();
152
153 return ksft_get_fail_cnt() == 0 ? KSFT_PASS : KSFT_FAIL;
154 }
155