1 // Copyright 2017 The Fuchsia Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include <pthread.h>

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <unittest/unittest.h>
13
// Test geometry: kNumThreads workers each pass through the shared barrier
// kNumIterations times. Enum constants (rather than #defines) obey C scoping,
// are visible in a debugger, and remain valid as array bounds below.
enum {
    kNumThreads = 16,
    kNumIterations = 128,
};

// The barrier every worker thread waits on.
static pthread_barrier_t barrier;
// Handles of the worker threads, indexed by thread number.
static pthread_t threads[kNumThreads];
// barriers_won[i] counts how many barrier cycles returned
// PTHREAD_BARRIER_SERIAL_THREAD to worker i.
static int barriers_won[kNumThreads];
20
barrier_wait_test(size_t idx)21 static bool barrier_wait_test(size_t idx) {
22 BEGIN_HELPER;
23
24 for (int iteration = 0u; iteration < kNumIterations; iteration++) {
25 int result = pthread_barrier_wait(&barrier);
26 if (result == PTHREAD_BARRIER_SERIAL_THREAD) {
27 barriers_won[idx] += 1;
28 } else {
29 ASSERT_EQ(result, 0,
30 "Invalid return value from pthread_barrier_wait");
31 }
32 }
33
34 END_HELPER;
35 }
36
// pthread start routine for each worker. `arg` carries the worker's index,
// packed into the void* via uintptr_t by the creator. The actual checks live
// in barrier_wait_test() because functions using the ASSERT_* macros must
// return bool, which a pthread start routine cannot.
static void* barrier_wait(void* arg) {
    const size_t thread_index = (uintptr_t)arg;
    bool helper_ok = barrier_wait_test(thread_index);
    (void)helper_ok;  // result intentionally discarded
    return NULL;
}
43
test_barrier(void)44 static bool test_barrier(void) {
45 BEGIN_TEST;
46
47 ASSERT_EQ(pthread_barrier_init(&barrier, NULL, kNumThreads), 0, "Failed to initialize barrier!");
48
49 for (int idx = 0; idx < kNumThreads; ++idx) {
50 ASSERT_EQ(pthread_create(&threads[idx], NULL,
51 &barrier_wait, (void*)(uintptr_t)idx), 0,
52 "Failed to create thread!");
53 }
54
55 for (int idx = 0; idx < kNumThreads; ++idx) {
56 ASSERT_EQ(pthread_join(threads[idx], NULL), 0, "Failed to join thread!");
57 }
58
59 int total_barriers_won = 0;
60 for (int idx = 0; idx < kNumThreads; ++idx) {
61 total_barriers_won += barriers_won[idx];
62 }
63 ASSERT_EQ(total_barriers_won, kNumIterations, "Barrier busted!");
64
65 END_TEST;
66 }
67
// Declares the pthread_barrier_tests test case, containing test_barrier, using
// the test-case macros from unittest/unittest.h.
BEGIN_TEST_CASE(pthread_barrier_tests)
RUN_TEST(test_barrier)
END_TEST_CASE(pthread_barrier_tests)
71
#ifndef BUILD_COMBINED_TESTS
// Standalone entry point, compiled only when BUILD_COMBINED_TESTS is not
// defined (presumably a combined test binary supplies its own main).
// Exits with 0 when all tests pass and -1 otherwise.
int main(int argc, char** argv) {
    if (!unittest_run_all_tests(argc, argv)) {
        return -1;
    }
    return 0;
}
#endif
77