1 // SPDX-License-Identifier: MIT
2 
3 /*
4  * Copyright (C) 2022 Advanced Micro Devices, Inc.
5  */
6 
7 #include <linux/dma-fence.h>
8 #include <linux/dma-fence-array.h>
9 #include <linux/dma-fence-chain.h>
10 #include <linux/dma-fence-unwrap.h>
11 
12 #include "selftest.h"
13 
/* 4096 entries; NOTE(review): not referenced by any test in this file. */
#define CHAIN_SZ (4 * 1024)
15 
16 struct mock_fence {
17 	struct dma_fence base;
18 	spinlock_t lock;
19 };
20 
/*
 * mock_name - name callback shared by the driver and timeline ops.
 * @f: the fence being queried (ignored).
 *
 * Always returns the constant string "mock".
 */
static const char *mock_name(struct dma_fence *f)
{
	(void)f;
	return "mock";
}
25 
26 static const struct dma_fence_ops mock_ops = {
27 	.get_driver_name = mock_name,
28 	.get_timeline_name = mock_name,
29 };
30 
mock_fence(void)31 static struct dma_fence *mock_fence(void)
32 {
33 	struct mock_fence *f;
34 
35 	f = kmalloc(sizeof(*f), GFP_KERNEL);
36 	if (!f)
37 		return NULL;
38 
39 	spin_lock_init(&f->lock);
40 	dma_fence_init(&f->base, &mock_ops, &f->lock,
41 		       dma_fence_context_alloc(1), 1);
42 
43 	return &f->base;
44 }
45 
mock_array(unsigned int num_fences,...)46 static struct dma_fence *mock_array(unsigned int num_fences, ...)
47 {
48 	struct dma_fence_array *array;
49 	struct dma_fence **fences;
50 	va_list valist;
51 	int i;
52 
53 	fences = kcalloc(num_fences, sizeof(*fences), GFP_KERNEL);
54 	if (!fences)
55 		goto error_put;
56 
57 	va_start(valist, num_fences);
58 	for (i = 0; i < num_fences; ++i)
59 		fences[i] = va_arg(valist, typeof(*fences));
60 	va_end(valist);
61 
62 	array = dma_fence_array_create(num_fences, fences,
63 				       dma_fence_context_alloc(1),
64 				       1, false);
65 	if (!array)
66 		goto error_free;
67 	return &array->base;
68 
69 error_free:
70 	kfree(fences);
71 
72 error_put:
73 	va_start(valist, num_fences);
74 	for (i = 0; i < num_fences; ++i)
75 		dma_fence_put(va_arg(valist, typeof(*fences)));
76 	va_end(valist);
77 	return NULL;
78 }
79 
mock_chain(struct dma_fence * prev,struct dma_fence * fence)80 static struct dma_fence *mock_chain(struct dma_fence *prev,
81 				    struct dma_fence *fence)
82 {
83 	struct dma_fence_chain *f;
84 
85 	f = dma_fence_chain_alloc();
86 	if (!f) {
87 		dma_fence_put(prev);
88 		dma_fence_put(fence);
89 		return NULL;
90 	}
91 
92 	dma_fence_chain_init(f, prev, fence, 1);
93 	return &f->base;
94 }
95 
sanitycheck(void * arg)96 static int sanitycheck(void *arg)
97 {
98 	struct dma_fence *f, *chain, *array;
99 	int err = 0;
100 
101 	f = mock_fence();
102 	if (!f)
103 		return -ENOMEM;
104 
105 	dma_fence_enable_sw_signaling(f);
106 
107 	array = mock_array(1, f);
108 	if (!array)
109 		return -ENOMEM;
110 
111 	chain = mock_chain(NULL, array);
112 	if (!chain)
113 		return -ENOMEM;
114 
115 	dma_fence_put(chain);
116 	return err;
117 }
118 
unwrap_array(void * arg)119 static int unwrap_array(void *arg)
120 {
121 	struct dma_fence *fence, *f1, *f2, *array;
122 	struct dma_fence_unwrap iter;
123 	int err = 0;
124 
125 	f1 = mock_fence();
126 	if (!f1)
127 		return -ENOMEM;
128 
129 	dma_fence_enable_sw_signaling(f1);
130 
131 	f2 = mock_fence();
132 	if (!f2) {
133 		dma_fence_put(f1);
134 		return -ENOMEM;
135 	}
136 
137 	dma_fence_enable_sw_signaling(f2);
138 
139 	array = mock_array(2, f1, f2);
140 	if (!array)
141 		return -ENOMEM;
142 
143 	dma_fence_unwrap_for_each(fence, &iter, array) {
144 		if (fence == f1) {
145 			f1 = NULL;
146 		} else if (fence == f2) {
147 			f2 = NULL;
148 		} else {
149 			pr_err("Unexpected fence!\n");
150 			err = -EINVAL;
151 		}
152 	}
153 
154 	if (f1 || f2) {
155 		pr_err("Not all fences seen!\n");
156 		err = -EINVAL;
157 	}
158 
159 	dma_fence_put(array);
160 	return err;
161 }
162 
unwrap_chain(void * arg)163 static int unwrap_chain(void *arg)
164 {
165 	struct dma_fence *fence, *f1, *f2, *chain;
166 	struct dma_fence_unwrap iter;
167 	int err = 0;
168 
169 	f1 = mock_fence();
170 	if (!f1)
171 		return -ENOMEM;
172 
173 	dma_fence_enable_sw_signaling(f1);
174 
175 	f2 = mock_fence();
176 	if (!f2) {
177 		dma_fence_put(f1);
178 		return -ENOMEM;
179 	}
180 
181 	dma_fence_enable_sw_signaling(f2);
182 
183 	chain = mock_chain(f1, f2);
184 	if (!chain)
185 		return -ENOMEM;
186 
187 	dma_fence_unwrap_for_each(fence, &iter, chain) {
188 		if (fence == f1) {
189 			f1 = NULL;
190 		} else if (fence == f2) {
191 			f2 = NULL;
192 		} else {
193 			pr_err("Unexpected fence!\n");
194 			err = -EINVAL;
195 		}
196 	}
197 
198 	if (f1 || f2) {
199 		pr_err("Not all fences seen!\n");
200 		err = -EINVAL;
201 	}
202 
203 	dma_fence_put(chain);
204 	return err;
205 }
206 
unwrap_chain_array(void * arg)207 static int unwrap_chain_array(void *arg)
208 {
209 	struct dma_fence *fence, *f1, *f2, *array, *chain;
210 	struct dma_fence_unwrap iter;
211 	int err = 0;
212 
213 	f1 = mock_fence();
214 	if (!f1)
215 		return -ENOMEM;
216 
217 	dma_fence_enable_sw_signaling(f1);
218 
219 	f2 = mock_fence();
220 	if (!f2) {
221 		dma_fence_put(f1);
222 		return -ENOMEM;
223 	}
224 
225 	dma_fence_enable_sw_signaling(f2);
226 
227 	array = mock_array(2, f1, f2);
228 	if (!array)
229 		return -ENOMEM;
230 
231 	chain = mock_chain(NULL, array);
232 	if (!chain)
233 		return -ENOMEM;
234 
235 	dma_fence_unwrap_for_each(fence, &iter, chain) {
236 		if (fence == f1) {
237 			f1 = NULL;
238 		} else if (fence == f2) {
239 			f2 = NULL;
240 		} else {
241 			pr_err("Unexpected fence!\n");
242 			err = -EINVAL;
243 		}
244 	}
245 
246 	if (f1 || f2) {
247 		pr_err("Not all fences seen!\n");
248 		err = -EINVAL;
249 	}
250 
251 	dma_fence_put(chain);
252 	return err;
253 }
254 
unwrap_merge(void * arg)255 static int unwrap_merge(void *arg)
256 {
257 	struct dma_fence *fence, *f1, *f2, *f3;
258 	struct dma_fence_unwrap iter;
259 	int err = 0;
260 
261 	f1 = mock_fence();
262 	if (!f1)
263 		return -ENOMEM;
264 
265 	dma_fence_enable_sw_signaling(f1);
266 
267 	f2 = mock_fence();
268 	if (!f2) {
269 		err = -ENOMEM;
270 		goto error_put_f1;
271 	}
272 
273 	dma_fence_enable_sw_signaling(f2);
274 
275 	f3 = dma_fence_unwrap_merge(f1, f2);
276 	if (!f3) {
277 		err = -ENOMEM;
278 		goto error_put_f2;
279 	}
280 
281 	dma_fence_unwrap_for_each(fence, &iter, f3) {
282 		if (fence == f1) {
283 			dma_fence_put(f1);
284 			f1 = NULL;
285 		} else if (fence == f2) {
286 			dma_fence_put(f2);
287 			f2 = NULL;
288 		} else {
289 			pr_err("Unexpected fence!\n");
290 			err = -EINVAL;
291 		}
292 	}
293 
294 	if (f1 || f2) {
295 		pr_err("Not all fences seen!\n");
296 		err = -EINVAL;
297 	}
298 
299 	dma_fence_put(f3);
300 error_put_f2:
301 	dma_fence_put(f2);
302 error_put_f1:
303 	dma_fence_put(f1);
304 	return err;
305 }
306 
unwrap_merge_complex(void * arg)307 static int unwrap_merge_complex(void *arg)
308 {
309 	struct dma_fence *fence, *f1, *f2, *f3, *f4, *f5;
310 	struct dma_fence_unwrap iter;
311 	int err = -ENOMEM;
312 
313 	f1 = mock_fence();
314 	if (!f1)
315 		return -ENOMEM;
316 
317 	dma_fence_enable_sw_signaling(f1);
318 
319 	f2 = mock_fence();
320 	if (!f2)
321 		goto error_put_f1;
322 
323 	dma_fence_enable_sw_signaling(f2);
324 
325 	f3 = dma_fence_unwrap_merge(f1, f2);
326 	if (!f3)
327 		goto error_put_f2;
328 
329 	/* The resulting array has the fences in reverse */
330 	f4 = dma_fence_unwrap_merge(f2, f1);
331 	if (!f4)
332 		goto error_put_f3;
333 
334 	/* Signaled fences should be filtered, the two arrays merged. */
335 	f5 = dma_fence_unwrap_merge(f3, f4, dma_fence_get_stub());
336 	if (!f5)
337 		goto error_put_f4;
338 
339 	err = 0;
340 	dma_fence_unwrap_for_each(fence, &iter, f5) {
341 		if (fence == f1) {
342 			dma_fence_put(f1);
343 			f1 = NULL;
344 		} else if (fence == f2) {
345 			dma_fence_put(f2);
346 			f2 = NULL;
347 		} else {
348 			pr_err("Unexpected fence!\n");
349 			err = -EINVAL;
350 		}
351 	}
352 
353 	if (f1 || f2) {
354 		pr_err("Not all fences seen!\n");
355 		err = -EINVAL;
356 	}
357 
358 	dma_fence_put(f5);
359 error_put_f4:
360 	dma_fence_put(f4);
361 error_put_f3:
362 	dma_fence_put(f3);
363 error_put_f2:
364 	dma_fence_put(f2);
365 error_put_f1:
366 	dma_fence_put(f1);
367 	return err;
368 }
369 
dma_fence_unwrap(void)370 int dma_fence_unwrap(void)
371 {
372 	static const struct subtest tests[] = {
373 		SUBTEST(sanitycheck),
374 		SUBTEST(unwrap_array),
375 		SUBTEST(unwrap_chain),
376 		SUBTEST(unwrap_chain_array),
377 		SUBTEST(unwrap_merge),
378 		SUBTEST(unwrap_merge_complex),
379 	};
380 
381 	return subtests(tests, NULL);
382 }
383