1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "ruy/thread_pool.h"
17 
18 #include <atomic>
19 #include <chrono>              // NOLINT(build/c++11)
20 #include <condition_variable>  // NOLINT(build/c++11)
21 #include <cstdint>
22 #include <cstdlib>
23 #include <memory>
24 #include <mutex>   // NOLINT(build/c++11)
25 #include <thread>  // NOLINT(build/c++11)
26 
27 #include "ruy/check_macros.h"
28 #include "ruy/denormal.h"
29 #include "ruy/trace.h"
30 #include "ruy/wait.h"
31 
32 namespace ruy {
33 
34 // A worker thread.
35 class Thread {
36  public:
37   enum class State {
38     Startup,  // The initial state before the thread main loop runs.
39     Ready,    // Is not working, has not yet received new work to do.
40     HasWork,  // Has work to do.
41     ExitAsSoonAsPossible  // Should exit at earliest convenience.
42   };
43 
Thread(BlockingCounter * counter_to_decrement_when_ready,Duration spin_duration)44   explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
45                   Duration spin_duration)
46       : task_(nullptr),
47         state_(State::Startup),
48         counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
49         spin_duration_(spin_duration) {
50     thread_.reset(new std::thread(ThreadFunc, this));
51   }
52 
~Thread()53   ~Thread() {
54     ChangeState(State::ExitAsSoonAsPossible);
55     thread_->join();
56   }
57 
58   // Changes State; may be called from either the worker thread
59   // or the master thread; however, not all state transitions are legal,
60   // which is guarded by assertions.
61   //
62   // The Task argument is to be used only with new_state==HasWork.
63   // It specifies the Task being handed to this Thread.
ChangeState(State new_state,Task * task=nullptr)64   void ChangeState(State new_state, Task* task = nullptr) {
65     state_mutex_.lock();
66     State old_state = state_.load(std::memory_order_relaxed);
67     RUY_DCHECK_NE(old_state, new_state);
68     switch (old_state) {
69       case State::Startup:
70         RUY_DCHECK_EQ(new_state, State::Ready);
71         break;
72       case State::Ready:
73         RUY_DCHECK(new_state == State::HasWork ||
74                    new_state == State::ExitAsSoonAsPossible);
75         break;
76       case State::HasWork:
77         RUY_DCHECK(new_state == State::Ready ||
78                    new_state == State::ExitAsSoonAsPossible);
79         break;
80       default:
81         abort();
82     }
83     switch (new_state) {
84       case State::Ready:
85         if (task_) {
86           // Doing work is part of reverting to 'ready' state.
87           task_->Run();
88           task_ = nullptr;
89         }
90         break;
91       case State::HasWork:
92         RUY_DCHECK(!task_);
93         task_ = task;
94         break;
95       default:
96         break;
97     }
98     state_.store(new_state, std::memory_order_relaxed);
99     state_cond_.notify_all();
100     state_mutex_.unlock();
101     if (new_state == State::Ready) {
102       counter_to_decrement_when_ready_->DecrementCount();
103     }
104   }
105 
ThreadFunc(Thread * arg)106   static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
107 
108   // Called by the master thead to give this thread work to do.
StartWork(Task * task)109   void StartWork(Task* task) { ChangeState(State::HasWork, task); }
110 
111  private:
112   // Thread entry point.
ThreadFuncImpl()113   void ThreadFuncImpl() {
114     RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
115     ChangeState(State::Ready);
116 
117     // Suppress denormals to avoid computation inefficiency.
118     ScopedSuppressDenormals suppress_denormals;
119 
120     // Thread main loop
121     while (true) {
122       RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
123       // In the 'Ready' state, we have nothing to do but to wait until
124       // we switch to another state.
125       const auto& condition = [this]() {
126         return state_.load(std::memory_order_acquire) != State::Ready;
127       };
128       RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
129       Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
130 
131       // Act on new state.
132       switch (state_.load(std::memory_order_acquire)) {
133         case State::HasWork: {
134           RUY_TRACE_SCOPE_NAME("Worker thread task");
135           // Got work to do! So do it, and then revert to 'Ready' state.
136           ChangeState(State::Ready);
137           break;
138         }
139         case State::ExitAsSoonAsPossible:
140           return;
141         default:
142           abort();
143       }
144     }
145   }
146 
147   // The underlying thread.
148   std::unique_ptr<std::thread> thread_;
149 
150   // The task to be worked on.
151   Task* task_;
152 
153   // The condition variable and mutex guarding state changes.
154   std::condition_variable state_cond_;
155   std::mutex state_mutex_;
156 
157   // The state enum tells if we're currently working, waiting for work, etc.
158   // Its concurrent accesses by the thread and main threads are guarded by
159   // state_mutex_, and can thus use memory_order_relaxed. This still needs
160   // to be a std::atomic because we use WaitForVariableChange.
161   std::atomic<State> state_;
162 
163   // pointer to the master's thread BlockingCounter object, to notify the
164   // master thread of when this thread switches to the 'Ready' state.
165   BlockingCounter* const counter_to_decrement_when_ready_;
166 
167   // See ThreadPool::spin_duration_.
168   const Duration spin_duration_;
169 };
170 
ExecuteImpl(int task_count,int stride,Task * tasks)171 void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
172   RUY_TRACE_SCOPE_NAME("ThreadPool::Execute");
173   RUY_DCHECK_GE(task_count, 1);
174 
175   // Case of 1 thread: just run the single task on the current thread.
176   if (task_count == 1) {
177     (tasks + 0)->Run();
178     return;
179   }
180 
181   // Task #0 will be run on the current thread.
182   CreateThreads(task_count - 1);
183   counter_to_decrement_when_ready_.Reset(task_count - 1);
184   for (int i = 1; i < task_count; i++) {
185     RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
186     auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
187     threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
188   }
189 
190   RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD);
191   // Execute task #0 immediately on the current thread.
192   (tasks + 0)->Run();
193 
194   RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
195   // Wait for the threads submitted above to finish.
196   counter_to_decrement_when_ready_.Wait(spin_duration_);
197 }
198 
199 // Ensures that the pool has at least the given count of threads.
200 // If any new thread has to be created, this function waits for it to
201 // be ready.
CreateThreads(int threads_count)202 void ThreadPool::CreateThreads(int threads_count) {
203   RUY_DCHECK_GE(threads_count, 0);
204   unsigned int unsigned_threads_count = threads_count;
205   if (threads_.size() >= unsigned_threads_count) {
206     return;
207   }
208   counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
209   while (threads_.size() < unsigned_threads_count) {
210     threads_.push_back(
211         new Thread(&counter_to_decrement_when_ready_, spin_duration_));
212   }
213   counter_to_decrement_when_ready_.Wait(spin_duration_);
214 }
215 
~ThreadPool()216 ThreadPool::~ThreadPool() {
217   for (auto w : threads_) {
218     delete w;
219   }
220 }
221 
222 }  // end namespace ruy
223