1 /* Copyright 2019 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "ruy/thread_pool.h"
17
18 #include <atomic>
19 #include <chrono> // NOLINT(build/c++11)
20 #include <condition_variable> // NOLINT(build/c++11)
21 #include <cstdint>
22 #include <cstdlib>
23 #include <memory>
24 #include <mutex> // NOLINT(build/c++11)
25 #include <thread> // NOLINT(build/c++11)
26
27 #include "ruy/check_macros.h"
28 #include "ruy/denormal.h"
29 #include "ruy/trace.h"
30 #include "ruy/wait.h"
31
32 namespace ruy {
33
34 // A worker thread.
35 class Thread {
36 public:
37 enum class State {
38 Startup, // The initial state before the thread main loop runs.
39 Ready, // Is not working, has not yet received new work to do.
40 HasWork, // Has work to do.
41 ExitAsSoonAsPossible // Should exit at earliest convenience.
42 };
43
Thread(BlockingCounter * counter_to_decrement_when_ready,Duration spin_duration)44 explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
45 Duration spin_duration)
46 : task_(nullptr),
47 state_(State::Startup),
48 counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
49 spin_duration_(spin_duration) {
50 thread_.reset(new std::thread(ThreadFunc, this));
51 }
52
~Thread()53 ~Thread() {
54 ChangeState(State::ExitAsSoonAsPossible);
55 thread_->join();
56 }
57
58 // Changes State; may be called from either the worker thread
59 // or the master thread; however, not all state transitions are legal,
60 // which is guarded by assertions.
61 //
62 // The Task argument is to be used only with new_state==HasWork.
63 // It specifies the Task being handed to this Thread.
ChangeState(State new_state,Task * task=nullptr)64 void ChangeState(State new_state, Task* task = nullptr) {
65 state_mutex_.lock();
66 State old_state = state_.load(std::memory_order_relaxed);
67 RUY_DCHECK_NE(old_state, new_state);
68 switch (old_state) {
69 case State::Startup:
70 RUY_DCHECK_EQ(new_state, State::Ready);
71 break;
72 case State::Ready:
73 RUY_DCHECK(new_state == State::HasWork ||
74 new_state == State::ExitAsSoonAsPossible);
75 break;
76 case State::HasWork:
77 RUY_DCHECK(new_state == State::Ready ||
78 new_state == State::ExitAsSoonAsPossible);
79 break;
80 default:
81 abort();
82 }
83 switch (new_state) {
84 case State::Ready:
85 if (task_) {
86 // Doing work is part of reverting to 'ready' state.
87 task_->Run();
88 task_ = nullptr;
89 }
90 break;
91 case State::HasWork:
92 RUY_DCHECK(!task_);
93 task_ = task;
94 break;
95 default:
96 break;
97 }
98 state_.store(new_state, std::memory_order_relaxed);
99 state_cond_.notify_all();
100 state_mutex_.unlock();
101 if (new_state == State::Ready) {
102 counter_to_decrement_when_ready_->DecrementCount();
103 }
104 }
105
ThreadFunc(Thread * arg)106 static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
107
108 // Called by the master thead to give this thread work to do.
StartWork(Task * task)109 void StartWork(Task* task) { ChangeState(State::HasWork, task); }
110
111 private:
112 // Thread entry point.
ThreadFuncImpl()113 void ThreadFuncImpl() {
114 RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
115 ChangeState(State::Ready);
116
117 // Suppress denormals to avoid computation inefficiency.
118 ScopedSuppressDenormals suppress_denormals;
119
120 // Thread main loop
121 while (true) {
122 RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
123 // In the 'Ready' state, we have nothing to do but to wait until
124 // we switch to another state.
125 const auto& condition = [this]() {
126 return state_.load(std::memory_order_acquire) != State::Ready;
127 };
128 RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
129 Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
130
131 // Act on new state.
132 switch (state_.load(std::memory_order_acquire)) {
133 case State::HasWork: {
134 RUY_TRACE_SCOPE_NAME("Worker thread task");
135 // Got work to do! So do it, and then revert to 'Ready' state.
136 ChangeState(State::Ready);
137 break;
138 }
139 case State::ExitAsSoonAsPossible:
140 return;
141 default:
142 abort();
143 }
144 }
145 }
146
147 // The underlying thread.
148 std::unique_ptr<std::thread> thread_;
149
150 // The task to be worked on.
151 Task* task_;
152
153 // The condition variable and mutex guarding state changes.
154 std::condition_variable state_cond_;
155 std::mutex state_mutex_;
156
157 // The state enum tells if we're currently working, waiting for work, etc.
158 // Its concurrent accesses by the thread and main threads are guarded by
159 // state_mutex_, and can thus use memory_order_relaxed. This still needs
160 // to be a std::atomic because we use WaitForVariableChange.
161 std::atomic<State> state_;
162
163 // pointer to the master's thread BlockingCounter object, to notify the
164 // master thread of when this thread switches to the 'Ready' state.
165 BlockingCounter* const counter_to_decrement_when_ready_;
166
167 // See ThreadPool::spin_duration_.
168 const Duration spin_duration_;
169 };
170
ExecuteImpl(int task_count,int stride,Task * tasks)171 void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
172 RUY_TRACE_SCOPE_NAME("ThreadPool::Execute");
173 RUY_DCHECK_GE(task_count, 1);
174
175 // Case of 1 thread: just run the single task on the current thread.
176 if (task_count == 1) {
177 (tasks + 0)->Run();
178 return;
179 }
180
181 // Task #0 will be run on the current thread.
182 CreateThreads(task_count - 1);
183 counter_to_decrement_when_ready_.Reset(task_count - 1);
184 for (int i = 1; i < task_count; i++) {
185 RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
186 auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
187 threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
188 }
189
190 RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD);
191 // Execute task #0 immediately on the current thread.
192 (tasks + 0)->Run();
193
194 RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
195 // Wait for the threads submitted above to finish.
196 counter_to_decrement_when_ready_.Wait(spin_duration_);
197 }
198
199 // Ensures that the pool has at least the given count of threads.
200 // If any new thread has to be created, this function waits for it to
201 // be ready.
CreateThreads(int threads_count)202 void ThreadPool::CreateThreads(int threads_count) {
203 RUY_DCHECK_GE(threads_count, 0);
204 unsigned int unsigned_threads_count = threads_count;
205 if (threads_.size() >= unsigned_threads_count) {
206 return;
207 }
208 counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
209 while (threads_.size() < unsigned_threads_count) {
210 threads_.push_back(
211 new Thread(&counter_to_decrement_when_ready_, spin_duration_));
212 }
213 counter_to_decrement_when_ready_.Wait(spin_duration_);
214 }
215
~ThreadPool()216 ThreadPool::~ThreadPool() {
217 for (auto w : threads_) {
218 delete w;
219 }
220 }
221
222 } // end namespace ruy
223