1 /*
2  * Copyright 2024 by Garmin Ltd. or its subsidiaries.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 /**
8  * @file
9  * @brief Test complex mutex priority inversion
10  *
11  * This module demonstrates the kernel's priority inheritance algorithm
12  * with two mutexes and four threads, ensuring that boosting priority of
13  * a thread waiting on another mutex does not break assumptions of the
14  * mutex's waitq, causing the incorrect thread to run or a crash.
15  *
16  * Sequence for priority inheritance testing:
17  *  - thread_08 takes mutex_1
18  *  - thread_07 takes mutex_0 then waits on mutex_1
19  *  - thread_06 waits on mutex_1
20  *  - thread_05 waits on mutex_0, boosting priority of thread_07
21  *  - thread_08 gives mutex_1, thread_07 takes mutex_1
22  *  - thread_07 gives mutex_1, thread_06 takes mutex_1
23  *  - thread_07 gives mutex_0, thread_05 takes mutex_0
24  *  - thread_06 gives mutex_1
25  *  - thread_05 gives mutex_0
26  */
27 
28 #include <zephyr/tc_util.h>
29 #include <zephyr/kernel.h>
30 #include <zephyr/ztest.h>
31 #include <zephyr/sys/mutex.h>
32 
/* Stack size of each participant thread: 512 bytes plus the test-wide extra. */
#define STACKSIZE (512 + CONFIG_TEST_EXTRA_STACK_SIZE)

/*
 * Test case return code. A participant thread sets it to TC_FAIL when one of
 * its mutex lock attempts fails; the test body asserts that it is still
 * TC_PASS once the sequence has completed.
 */
static ZTEST_DMEM int tc_rc = TC_PASS;

/* The two mutexes the participant threads contend for. */
static K_MUTEX_DEFINE(mutex_0);
static K_MUTEX_DEFINE(mutex_1);

/* Thread options handed to k_thread_create() for every participant thread. */
#define PARTICIPANT_THREAD_OPTIONS (K_INHERIT_PERMS)
41 
/*
 * Define the file-scope objects that back participant thread <id>:
 *   thread_<id>_stack_area  - its stack (STACKSIZE bytes)
 *   thread_<id>_thread_data - its struct k_thread
 *   thread_<id>_tid         - holds the id returned by k_thread_create()
 *   thread_<id>_wait        - semaphore (initial 0, limit 1) given by the test
 *                             thread to let the participant proceed
 *   thread_<id>_done        - semaphore (initial 0, limit 1) given by the
 *                             participant when it reaches a sync point
 *
 * The expansion intentionally does not end in a semicolon: the invocation
 * supplies it, so `DEFINE_PARTICIPANT_THREAD(05);` does not add an extra empty
 * declaration at file scope.
 */
#define DEFINE_PARTICIPANT_THREAD(id) \
	static K_THREAD_STACK_DEFINE(thread_##id##_stack_area, STACKSIZE); \
	static struct k_thread thread_##id##_thread_data; \
	static k_tid_t thread_##id##_tid; \
	static K_SEM_DEFINE(thread_##id##_wait, 0, 1); \
	static K_SEM_DEFINE(thread_##id##_done, 0, 1)
48 
/*
 * Create participant thread <id> at priority <pri> and name it "thread_<id>".
 *
 * The thread is created with a K_FOREVER start delay, so it does not run until
 * START_PARTICIPANT_THREAD(id) is used. Its entry function thread_<id> receives
 * &thread_<id>_wait and &thread_<id>_done as its first two arguments.
 *
 * Wrapped in do { } while (0) so the two statements behave as one when the
 * macro is used as the body of an unbraced if/else. Invoke with a trailing
 * semicolon.
 */
#define CREATE_PARTICIPANT_THREAD(id, pri) \
	do { \
		thread_##id##_tid = k_thread_create( \
			&thread_##id##_thread_data, thread_##id##_stack_area, \
			K_THREAD_STACK_SIZEOF(thread_##id##_stack_area), \
			(k_thread_entry_t)thread_##id, &thread_##id##_wait, \
			&thread_##id##_done, NULL, (pri), \
			PARTICIPANT_THREAD_OPTIONS, K_FOREVER); \
		k_thread_name_set(thread_##id##_tid, "thread_" STRINGIFY(id)); \
	} while (0)

/* Start participant thread <id>. Invoke with a trailing semicolon. */
#define START_PARTICIPANT_THREAD(id) k_thread_start(&(thread_##id##_thread_data))

/*
 * Wait, without a timeout, on k_thread_join() for participant thread <id>.
 * Invoke with a trailing semicolon.
 */
#define JOIN_PARTICIPANT_THREAD(id) k_thread_join(&(thread_##id##_thread_data), K_FOREVER)
58 
/*
 * Participant side of the handshake: give `done` to report that a sync point
 * has been reached, then block on `wait` until the test thread signals.
 *
 * Relies on `wait` and `done` (both struct k_sem *) being in scope at the
 * point of use. Wrapped in do { } while (0) so it expands to one statement.
 */
#define WAIT_FOR_MAIN() \
	do { \
		k_sem_give(done); \
		k_sem_take(wait, K_FOREVER); \
	} while (0)

/*
 * Test-thread side: signal participant <id>, then require that it reports its
 * next sync point. Wrapped in do { } while (0) so it expands to one statement.
 */
#define ADVANCE_THREAD(id) \
	do { \
		SIGNAL_THREAD(id); \
		WAIT_FOR_THREAD(id); \
	} while (0)

/*
 * Give participant <id>'s wait semaphore. This releases the participant if it
 * is blocked in WAIT_FOR_MAIN(); otherwise the count is left pending so its
 * next WAIT_FOR_MAIN() falls straight through.
 */
#define SIGNAL_THREAD(id) k_sem_give(&thread_##id##_wait)

/* Assert that participant <id> gives its done semaphore within 100 ms. */
#define WAIT_FOR_THREAD(id) zassert_ok(k_sem_take(&thread_##id##_done, K_MSEC(100)))
70 
71 /**
72  *
73  * thread_05 -
74  *
75  */
76 
thread_05(struct k_sem * wait,struct k_sem * done)77 static void thread_05(struct k_sem *wait, struct k_sem *done)
78 {
79 	int rv;
80 
81 	/*
82 	 * Wait for mutex_0, boosting the priority of thread_07 so it will lock mutex_1 first.
83 	 */
84 
85 	WAIT_FOR_MAIN();
86 
87 	rv = k_mutex_lock(&mutex_0, K_FOREVER);
88 	if (rv != 0) {
89 		tc_rc = TC_FAIL;
90 		TC_ERROR("Failed to take mutex %p\n", &mutex_0);
91 		return;
92 	}
93 
94 	WAIT_FOR_MAIN();
95 
96 	k_mutex_unlock(&mutex_0);
97 }
98 
99 /**
100  *
101  * thread_06 -
102  *
103  */
104 
thread_06(struct k_sem * wait,struct k_sem * done)105 static void thread_06(struct k_sem *wait, struct k_sem *done)
106 {
107 	int rv;
108 
109 	/*
110 	 * Wait for mutex_1. Initially it will be the highest priority waiter, but
111 	 * thread_07 will be boosted above thread_06 so thread_07 will lock it first.
112 	 */
113 
114 	WAIT_FOR_MAIN();
115 
116 	rv = k_mutex_lock(&mutex_1, K_FOREVER);
117 	if (rv != 0) {
118 		tc_rc = TC_FAIL;
119 		TC_ERROR("Failed to take mutex %p\n", &mutex_1);
120 		return;
121 	}
122 
123 	WAIT_FOR_MAIN();
124 
125 	k_mutex_unlock(&mutex_1);
126 }
127 
128 /**
129  *
130  * thread_07 -
131  *
132  */
133 
thread_07(struct k_sem * wait,struct k_sem * done)134 static void thread_07(struct k_sem *wait, struct k_sem *done)
135 {
136 	int rv;
137 
138 	/*
139 	 * Lock mutex_0 and wait for mutex_1. After thread_06 is also waiting for
140 	 * mutex_1, thread_05 will wait for mutex_0, boosting the priority for
141 	 * thread_07 so it should lock mutex_1 first when it is unlocked by thread_08.
142 	 */
143 
144 	WAIT_FOR_MAIN();
145 
146 	rv = k_mutex_lock(&mutex_0, K_NO_WAIT);
147 	if (rv != 0) {
148 		tc_rc = TC_FAIL;
149 		TC_ERROR("Failed to take mutex %p\n", &mutex_0);
150 		return;
151 	}
152 
153 	WAIT_FOR_MAIN();
154 
155 	rv = k_mutex_lock(&mutex_1, K_FOREVER);
156 	if (rv != 0) {
157 		tc_rc = TC_FAIL;
158 		TC_ERROR("Failed to take mutex %p\n", &mutex_1);
159 		k_mutex_unlock(&mutex_0);
160 		return;
161 	}
162 
163 	WAIT_FOR_MAIN();
164 
165 	k_mutex_unlock(&mutex_1);
166 	k_mutex_unlock(&mutex_0);
167 }
168 
169 /**
170  *
171  * thread_08 -
172  *
173  */
174 
thread_08(struct k_sem * wait,struct k_sem * done)175 static void thread_08(struct k_sem *wait, struct k_sem *done)
176 {
177 	int rv;
178 
179 	/*
180 	 * Lock mutex_1 and hold until priority has been boosted on thread_07
181 	 * to ensure that thread_07 is the first to lock mutex_1 when thread_08
182 	 * unlocks it.
183 	 */
184 
185 	WAIT_FOR_MAIN();
186 
187 	rv = k_mutex_lock(&mutex_1, K_NO_WAIT);
188 	if (rv != 0) {
189 		tc_rc = TC_FAIL;
190 		TC_ERROR("Failed to take mutex %p\n", &mutex_1);
191 		return;
192 	}
193 
194 	WAIT_FOR_MAIN();
195 
196 	k_mutex_unlock(&mutex_1);
197 }
198 
/*
 * Instantiate the stack, thread object, tid and wait/done semaphores for each
 * participant. The ids are only token-pasted into identifiers and stringified
 * for the thread name; they are never evaluated as integer constants, which is
 * why `08` is acceptable here.
 */
DEFINE_PARTICIPANT_THREAD(05);
DEFINE_PARTICIPANT_THREAD(06);
DEFINE_PARTICIPANT_THREAD(07);
DEFINE_PARTICIPANT_THREAD(08);
203 
/**
 * @brief Create the four participant threads without starting them.
 *
 * thread_05, thread_06, thread_07 and thread_08 are created at priorities
 * 5, 6, 7 and 8 respectively. They are created with a K_FOREVER start delay
 * and only begin running after start_participant_threads().
 */
static void create_participant_threads(void)
{
	CREATE_PARTICIPANT_THREAD(05, 5);
	CREATE_PARTICIPANT_THREAD(06, 6);
	CREATE_PARTICIPANT_THREAD(07, 7);
	CREATE_PARTICIPANT_THREAD(08, 8);
}
211 
/**
 * @brief Start the four participant threads, thread_05 first and thread_08 last.
 *
 * Must be called after create_participant_threads().
 */
static void start_participant_threads(void)
{
	START_PARTICIPANT_THREAD(05);
	START_PARTICIPANT_THREAD(06);
	START_PARTICIPANT_THREAD(07);
	START_PARTICIPANT_THREAD(08);
}
219 
/**
 * @brief Join the four participant threads in turn, with no timeout.
 *
 * Calls k_thread_join() with K_FOREVER on thread_05, thread_06, thread_07 and
 * thread_08, in that order.
 */
static void join_participant_threads(void)
{
	JOIN_PARTICIPANT_THREAD(05);
	JOIN_PARTICIPANT_THREAD(06);
	JOIN_PARTICIPANT_THREAD(07);
	JOIN_PARTICIPANT_THREAD(08);
}
227 
228 /**
229  *
230  * @brief Main thread to test mutex locking
231  *
232  * This thread orchestrates mutex locking on other threads and verifies that
233  * the correct thread is holding mutexes at any given step.
234  *
235  */
236 
ZTEST(mutex_api,test_complex_inversion)237 ZTEST(mutex_api, test_complex_inversion)
238 {
239 	create_participant_threads();
240 	start_participant_threads();
241 
242 	/* Wait for all the threads to start up */
243 	WAIT_FOR_THREAD(08);
244 	WAIT_FOR_THREAD(07);
245 	WAIT_FOR_THREAD(06);
246 	WAIT_FOR_THREAD(05);
247 
248 	ADVANCE_THREAD(08); /* thread_08 takes mutex_1 */
249 	zassert_equal(thread_08_tid, mutex_1.owner, "expected owner %s, not %s\n",
250 		      thread_08_tid->name, mutex_1.owner->name);
251 
252 	ADVANCE_THREAD(07); /* thread_07 takes mutex_0 */
253 	zassert_equal(thread_07_tid, mutex_0.owner, "expected owner %s, not %s\n",
254 		      thread_07_tid->name, mutex_0.owner->name);
255 
256 	SIGNAL_THREAD(07);    /* thread_07 waits on mutex_1 */
257 	k_sleep(K_MSEC(100)); /* Give thread_07 some time to wait on mutex_1 */
258 
259 	SIGNAL_THREAD(06);    /* thread_06 waits on mutex_1 */
260 	k_sleep(K_MSEC(100)); /* Give thread_06 some time to wait on mutex_1 */
261 
262 	SIGNAL_THREAD(05); /* thread_05 waits on mutex_0, boosting priority of thread_07 */
263 
264 	SIGNAL_THREAD(08); /* thread_08 gives mutex_1 */
265 
266 	/* If thread_06 erroneously took mutex_1, giving it could cause a crash
267 	 * when CONFIG_WAITQ_SCALABLE is set. Give it a chance to run to make sure
268 	 * this crash isn't hit.
269 	 */
270 	SIGNAL_THREAD(06);
271 
272 	WAIT_FOR_THREAD(07); /* thread_07 takes mutex_1 */
273 	zassert_equal(thread_07_tid, mutex_1.owner, "expected owner %s, not %s\n",
274 		      thread_07_tid->name, mutex_1.owner->name);
275 
276 	SIGNAL_THREAD(07);   /* thread_07 gives mutex_1 then gives mutex_0 */
277 	WAIT_FOR_THREAD(06); /* thread_06 takes mutex_1 */
278 	WAIT_FOR_THREAD(05); /* thread_05 takes mutex_0 */
279 	zassert_equal(thread_06_tid, mutex_1.owner, "expected owner %s, not %s\n",
280 		      thread_06_tid->name, mutex_1.owner->name);
281 	zassert_equal(thread_05_tid, mutex_0.owner, "expected owner %s, not %s\n",
282 		      thread_05_tid->name, mutex_0.owner->name);
283 
284 	SIGNAL_THREAD(06); /* thread_06 gives mutex_1 */
285 	SIGNAL_THREAD(05); /* thread_05 gives mutex_0 */
286 
287 	zassert_equal(tc_rc, TC_PASS);
288 
289 	join_participant_threads();
290 }
291