-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathThreadPool.hpp
More file actions
115 lines (101 loc) · 3.19 KB
/
Copy pathThreadPool.hpp
File metadata and controls
115 lines (101 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#pragma once
#include <thread>
#include <vector>
#include <queue>
#include <atomic>
#include <memory>
#include <functional>
#include <iostream>
#include <string>
#include <future>
#include <functional>
#include <cstddef>
#include <type_traits>
#include "ThreadSafeQueue.hpp"
#include "StealingWorkQueue.hpp"
#ifndef __FUNCTION_WRAPPER_
#define __FUNCTION_WRAPPER_
#include "MoveOnlyFunctionWrapper.hpp"
#endif
class ThreadPool {
using task_type = move_only_function_wrapper;
using local_queue_type = StealingWorkQueue;
unsigned int n_threads;
std::vector<std::thread> threads;
std::vector<std::unique_ptr<local_queue_type>> local_queues;
ThreadSafeQueue<task_type> work_queue;
std::atomic_bool done;
inline static thread_local local_queue_type * local_work_queue;
inline static thread_local size_t thread_index;
void worker_thread(size_t index) {
thread_index = index;
local_work_queue = local_queues[thread_index].get();
while (!done) {
run_pending_task();
}
}
bool steal_work_from_other_threads(task_type& task) {
auto local_queue_size = local_queues.size();
for (size_t index = 0; index < local_queue_size - 1; ++index) {
auto other_thread_index = (thread_index + index + 1) % local_queue_size;
if (local_queues[other_thread_index]->try_steal(task)) {
return true;
}
}
return false;
}
public:
ThreadPool(unsigned num_worker_threads):
done(false),
threads (),
local_queues (),
n_threads(std::move(num_worker_threads))
{
try {
for (unsigned i = 0; i < n_threads; ++i) {
threads.push_back(std::thread(&ThreadPool::worker_thread, this, static_cast<size_t>(i)));
local_queues.push_back(std::make_unique<local_queue_type>());
}
}
catch (...) {
std::cout << "some errors" << std::endl;
done = true;
throw;
}
}
template <typename FunctionType, typename ... Args, typename ResultType = typename std::result_of<FunctionType(Args...)>::type>
std::future<ResultType> submit(FunctionType&& f, Args&& ... args) {
std::packaged_task<ResultType()> task(
std::bind(std::forward<FunctionType>(f), std::forward<Args>(args)...)
);
std::future<ResultType> res (task.get_future());
if (local_work_queue) {
local_work_queue->push(std::move(task));
}
else {
work_queue.push(std::move(task));
}
return res;
}
void run_pending_task() {
task_type task;
if (local_work_queue && local_work_queue->try_pop(task)) {
task();
}
else if (steal_work_from_other_threads(task)) {
task();
}
else if (work_queue.try_pop(task)) // will call move assignment
{
task();
}
else {
std::this_thread::yield();
}
}
~ThreadPool() {
done = true;
for (unsigned i = 0; i < threads.size(); ++i) {
if (threads[i].joinable()) {
threads[i].join();
}
}
}
};