代码拉取完成,页面将自动刷新
//Rewrite from https://github.com/David-Haim/concurrencpp
//Test successfully on VS2019 16.11.3
#include <cassert>
#include <chrono>
#include <coroutine>
#include <deque>
#include <iostream>
#include <list>
#include <mutex>
#include <queue>
#include <semaphore>
#include <set>
#include <span>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
#include <Windows.h> //for thread support on OS
using namespace std;
namespace concurrencpp
{
namespace errors //注意构造函数直接 using
{
struct empty_object : public runtime_error { using runtime_error::runtime_error; };
struct empty_result : public empty_object { using empty_object::empty_object; };
struct empty_result_promise : public empty_object { using empty_object::empty_object; };
struct empty_awaitable : public empty_object { using empty_object::empty_object; };
struct broken_task : public runtime_error { using runtime_error::runtime_error; };
struct result_already_retrieved : public runtime_error { using runtime_error::runtime_error; };
struct runtime_shutdown : public runtime_error { using runtime_error::runtime_error; };
} //namespace errors
namespace details
{
class await_context //可接收一个协程句柄随后恢复,或一个异常随后抛出
{
private:
coroutine_handle<void> m_caller_handle;
exception_ptr m_interrupt_exception; //异常指针,可用 make_exception_ptr 构造
public:
void resume() noexcept
{
assert(static_cast<bool>(m_caller_handle));
assert(!m_caller_handle.done());
m_caller_handle();
}
void set_coro_handle(coroutine_handle<void> coro_handle) noexcept
{
assert(!static_cast<bool>(m_caller_handle));
assert(static_cast<bool>(coro_handle));
assert(!coro_handle.done());
m_caller_handle = coro_handle;
}
void set_interrupt(const exception_ptr& interrupt) noexcept
{
assert(m_interrupt_exception == nullptr);
assert(static_cast<bool>(interrupt));
m_interrupt_exception = interrupt;
}
void throw_if_interrupted() const
{
if (m_interrupt_exception != nullptr)
rethrow_exception(m_interrupt_exception); //标准库函数
}
}; //class await_context
class wait_context //可设定多个定时/不定时等待,设定通知后恢复
{
private:
mutex m_lock;
condition_variable m_condition;
bool m_ready = false; //默认等待,直到 notify
public:
void wait() noexcept
{
unique_lock<mutex> lock(m_lock);
m_condition.wait(lock, [this] { return m_ready; });
}
bool wait_for(size_t milliseconds) noexcept //true 等待成功,false 超时
{
unique_lock<mutex> lock(m_lock);
return m_condition.wait_for(lock, chrono::milliseconds(milliseconds), [this] { return m_ready; });
}
void notify() noexcept
{
{ unique_lock<mutex> lock(m_lock); m_ready = true; } //解锁越快越好
m_condition.notify_all();
}
}; //class wait_context
class result_state_base;
class when_any_context //可接收一个协程句柄,并在设定完成结果时恢复它
{
private:
atomic_bool m_fulfilled = false;
result_state_base* m_completed_result = nullptr;
coroutine_handle<void> m_coro_handle;
public:
when_any_context(coroutine_handle<void> coro_handle) noexcept : m_coro_handle(coro_handle) {}
bool fulfilled() const noexcept { return m_fulfilled.load(memory_order_acquire); }
result_state_base* completed_result() const noexcept
{
assert(m_completed_result != nullptr);
return m_completed_result;
}
void try_resume(result_state_base* completed_result) noexcept //对结果赋值后尝试恢复协程
{
assert(completed_result != nullptr);
const auto already_resumed = m_fulfilled.exchange(true, memory_order_acq_rel);
if (already_resumed) return;
assert(m_completed_result == nullptr);
m_completed_result = completed_result;
assert(static_cast<bool>(m_coro_handle));
m_coro_handle();
}
}; //class when_any_context
enum class result_status { idle, value, exception };
template<class type>
class producer_context //负责通过 build_result/exception 生成结果值或异常,由 get 获取
{
union storage
{
type object;
exception_ptr exception;
storage() noexcept {}
~storage() noexcept {}
};
private:
storage m_storage;
result_status m_status = result_status::idle;
public:
~producer_context() noexcept
{
switch (m_status) {
case result_status::value: m_storage.object.~type(); break; //注意这种析构用法,type 是模板参数
case result_status::exception: m_storage.exception.~exception_ptr(); break;
case result_status::idle: break;
default: assert(false);
}
}
producer_context& operator=(producer_context&& rhs) noexcept
{
assert(m_status == result_status::idle);
m_status = exchange(rhs.m_status, result_status::idle); //rhs 换过来,rhs 再变成 idle
switch (m_status) {
case result_status::value:
{
new (addressof(m_storage.object)) type(move(rhs.m_storage.object)); //预分配构造
rhs.m_storage.object.~type();
break;
}
case result_status::exception:
{
new (addressof(m_storage.exception)) exception_ptr(rhs.m_storage.exception);
rhs.m_storage.exception.~exception_ptr();
break;
}
case result_status::idle: break;
default: assert(false);
}
return *this;
}
template<class... argument_types>
void build_result(argument_types&&... arguments) noexcept(noexcept(type(forward<argument_types>(arguments)...)))
{ //生成类型结果。只要相应构造函数 noexcept,则此函数 noexcept
assert(m_status == result_status::idle);
new (addressof(m_storage.object)) type(forward<argument_types>(arguments)...);
m_status = result_status::value;
}
void build_exception(const exception_ptr& exception) noexcept //异常的构造一定是 noexcept
{
assert(m_status == result_status::idle);
new (addressof(m_storage.exception)) exception_ptr(exception);
m_status = result_status::exception;
}
result_status status() const noexcept { return m_status; }
type get() { return move(get_ref()); } //会调用移动构造函数
type& get_ref()
{
assert(m_status != result_status::idle);
if (m_status == result_status::value)
return m_storage.object;
assert(m_status == result_status::exception);
assert(static_cast<bool>(m_storage.exception));
rethrow_exception(m_storage.exception); //发生了异常,就在 get() 时重新抛出
}
};
template<>
class producer_context<void>
{
union storage
{
exception_ptr exception{};
storage() noexcept {}
~storage() noexcept {}
};
private:
storage m_storage;
result_status m_status = result_status::idle;
public:
~producer_context() noexcept
{
if (m_status == result_status::exception)
m_storage.exception.~exception_ptr();
}
producer_context& operator=(producer_context&& rhs) noexcept
{
assert(m_status == result_status::idle);
m_status = exchange(rhs.m_status, result_status::idle);
if (m_status == result_status::exception) {
new (addressof(m_storage.exception)) exception_ptr(rhs.m_storage.exception);
rhs.m_storage.exception.~exception_ptr();
}
return *this;
}
void build_result() noexcept
{
assert(m_status == result_status::idle);
m_status = result_status::value;
}
void build_exception(const exception_ptr& exception) noexcept
{
assert(m_status == result_status::idle);
new (addressof(m_storage.exception)) exception_ptr(exception);
m_status = result_status::exception;
}
result_status status() const noexcept { return m_status; }
void get() const { get_ref(); }
void get_ref() const
{
assert(m_status != result_status::idle);
if (m_status == result_status::exception) {
assert(static_cast<bool>(m_storage.exception));
rethrow_exception(m_storage.exception);
}
}
};
template<class type>
class producer_context<type&>
{
union storage
{
type* pointer; //此时内部保存的是原始指针
exception_ptr exception;
storage() noexcept {}
~storage() noexcept {}
};
private:
storage m_storage;
result_status m_status = result_status::idle;
public:
~producer_context() noexcept
{
if (m_status == result_status::exception)
m_storage.exception.~exception_ptr();
}
producer_context& operator=(producer_context&& rhs) noexcept
{
assert(m_status == result_status::idle);
m_status = exchange(rhs.m_status, result_status::idle);
switch (m_status) {
case result_status::value: m_storage.pointer = rhs.m_storage.pointer; break;
case result_status::exception:
{
new (addressof(m_storage.exception)) exception_ptr(rhs.m_storage.exception);
rhs.m_storage.exception.~exception_ptr();
break;
}
case result_status::idle: break;
default: assert(false);
}
return *this;
}
void build_result(type& reference) noexcept
{
assert(m_status == result_status::idle);
auto pointer = addressof(reference); //注意对引用对象取地址的方法
assert(pointer != nullptr);
assert(reinterpret_cast<size_t>(pointer) % alignof(type) == 0); //注意对齐检查
m_storage.pointer = pointer;
m_status = result_status::value;
}
void build_exception(const exception_ptr& exception) noexcept
{
assert(m_status == result_status::idle);
new (addressof(m_storage.exception)) exception_ptr(exception);
m_status = result_status::exception;
}
result_status status() const noexcept { return m_status; }
type& get() const { return get_ref(); }
type& get_ref() const
{
assert(m_status != result_status::idle);
if (m_status == result_status::value) {
assert(m_storage.pointer != nullptr);
assert(reinterpret_cast<size_t>(m_storage.pointer) % alignof(type) == 0);
return *m_storage.pointer;
}
assert(m_status == result_status::exception);
assert(static_cast<bool>(m_storage.exception));
rethrow_exception(m_storage.exception);
}
}; //class producer_context
class consumer_context //负责结果获取
{
private:
enum class consumer_status { idle, await, wait, when_any }; //决定 storage 使用哪个成员
union storage
{
coroutine_handle<void> caller_handle{}; //给 await_ctx 用的
shared_ptr<wait_context> wait_ctx;
shared_ptr<when_any_context> when_any_ctx; //内部自带协程句柄
template<class type, class... argument_type>
static void build(type& o, argument_type&&... arguments) noexcept
{ //注意这种通用构造方式,根据类型分别构造关联上下文,避免 switch
new (addressof(o)) type(forward<argument_type>(arguments)...);
}
template<class type>
static void destroy(type& o) noexcept { o.~type(); } //按类型(即状态)销毁关联上下文
storage() noexcept {}
~storage() noexcept {}
};
private:
consumer_status m_status = consumer_status::idle;
storage m_storage;
public:
~consumer_context() noexcept { clear(); }
void clear() noexcept
{
const auto status = exchange(m_status, consumer_status::idle);
switch (status) {
case consumer_status::idle: return;
case consumer_status::await: storage::destroy(m_storage.caller_handle); return;
case consumer_status::wait: storage::destroy(m_storage.wait_ctx); return;
case consumer_status::when_any: storage::destroy(m_storage.when_any_ctx); return;
}
assert(false);
}
void resume_consumer(result_state_base* self) const noexcept
{
switch (m_status) {
case consumer_status::idle: return;
case consumer_status::await:
{
auto caller_handle = m_storage.caller_handle;
assert(static_cast<bool>(caller_handle));
assert(!caller_handle.done());
return caller_handle(); //coroutine 继续执行,即 resume
}
case consumer_status::wait:
return m_storage.wait_ctx->notify(); //唤醒继续
case consumer_status::when_any:
return m_storage.when_any_ctx->try_resume(self); //尝试赋值并恢复
}
assert(false);
}
void set_await_handle(coroutine_handle<void> caller_handle) noexcept
{
assert(m_status == consumer_status::idle);
m_status = consumer_status::await;
storage::build(m_storage.caller_handle, caller_handle);
}
void set_wait_context(const shared_ptr<wait_context>& wait_ctx) noexcept
{
assert(m_status == consumer_status::idle);
m_status = consumer_status::wait;
storage::build(m_storage.wait_ctx, wait_ctx); //内部调用了 wait_context 的默认移动构造函数
}
void set_when_any_context(const shared_ptr<when_any_context>& when_any_ctx) noexcept
{
assert(m_status == consumer_status::idle);
m_status = consumer_status::when_any;
storage::build(m_storage.when_any_ctx, when_any_ctx);
}
}; //class consumer_context
class result_state_base
{
public:
enum class pc_state { idle, consumer_set, consumer_done, producer_done }; //生产/消费状态
protected:
atomic<pc_state> m_pc_state{ pc_state::idle };
consumer_context m_consumer; //消费者,负责取结果
coroutine_handle<void> m_done_handle; //用在 complete_producer/consumer
void assert_done() const noexcept { assert(m_pc_state.load(memory_order_relaxed) == pc_state::producer_done); }
public:
//三种方式等待生成结果(即 producer_done),期间将状态设为 consumer_set,后面还有一个 wait_for
// cosumer_set 之后,要么 try_rewind_consumer 倒回 idle,要么 complete_producer 唤醒消费者
void wait() //注意这个函数不是 noexcept,原因是 wait_context 可能创建失败
{
const auto state = m_pc_state.load(memory_order_acquire);
if (state == pc_state::producer_done) return; //已经有结果,就不用再等了
auto wait_ctx = make_shared<wait_context>(); //这里有可能会抛出异常。除此行外其他都是 noexcept
m_consumer.set_wait_context(wait_ctx);
auto expected_state = pc_state::idle; //既然结果没出来,应该还是 idle,设成 consumer_set 表示已经在被等待了
const auto idle = m_pc_state.compare_exchange_strong(expected_state, pc_state::consumer_set, memory_order_acq_rel);
if (!idle) { assert_done(); return; } //非 idle,一定是有结果了,同样不等待直接返回
wait_ctx->wait(); //否则就等待,直到 wait_ctx.m_ready
assert_done(); //结束等待时一定是有结果了
}
bool await(coroutine_handle<void> caller_handle) noexcept
{
const auto state = m_pc_state.load(memory_order_acquire);
if (state == pc_state::producer_done) return false;
m_consumer.set_await_handle(caller_handle); //句柄交给 consumer,以便将来可以 resume
auto expected_state = pc_state::idle;
const auto idle = m_pc_state.compare_exchange_strong(expected_state, pc_state::consumer_set, memory_order_acq_rel);
if (!idle) assert_done();
return idle; //为真表示还没有结果,需要挂起
}
pc_state when_any(const shared_ptr<when_any_context>& when_any_state) noexcept
{
const auto state = m_pc_state.load(memory_order_acquire);
if (state == pc_state::producer_done) return state;
m_consumer.set_when_any_context(when_any_state);
auto expected_state = pc_state::idle;
const auto idle = m_pc_state.compare_exchange_strong(expected_state, pc_state::consumer_set, memory_order_acq_rel);
if (!idle) assert_done();
return state; //调用时的起始状态,大概率是空闲
}
void try_rewind_consumer() noexcept //尝试将状态从 consumer_set 倒回 idle,在 when_any_awaitable.await_resume 里调用
{
const auto pc_state = m_pc_state.load(memory_order_acquire);
if (pc_state != pc_state::consumer_set) return; //仅限于 consumer_set 状态
auto expected_consumer_state = pc_state::consumer_set; //准备设置成空闲
const auto consumer = m_pc_state.compare_exchange_strong(expected_consumer_state, pc_state::idle, memory_order_acq_rel);
if (!consumer) { assert_done(); return; } //不成功,一定是 producer_done 了,直接返回
m_consumer.clear(); //成功设成空闲,清除 consumer
}
};
template<class type>
class result_state : public result_state_base
{
private:
producer_context<type> m_producer; //result_state 同时拥有生产者和消费者。消费者在 base 里
//自删除静态函数,用在 complete_producer/consumer 里,当生产者、消费者均完成时做删除
static void delete_self(coroutine_handle<void> done_handle, result_state<type>* state) noexcept
{
if (static_cast<bool>(done_handle)) { assert(done_handle.done()); return done_handle.destroy(); }
delete state; //基本调用情况:参数一为 m_done_handle(默认空);参数二为 this,所以一般是删除自己
}
template<class callable_type>
void from_callable(true_type /*is_void_type*/, callable_type&& callable) { callable(); set_result(); }
template<class callable_type>
void from_callable(false_type /*is_void_type*/, callable_type&& callable) { set_result(callable()); }
public:
template<class... argument_types>
void set_result(argument_types&&... arguments) noexcept(noexcept(type(forward<argument_types>(arguments)...)))
{
m_producer.build_result(forward<argument_types>(arguments)...);
}
void set_exception(const exception_ptr& error) noexcept
{
assert(error != nullptr);
m_producer.build_exception(error);
}
result_status status() const noexcept //主要供 wait_until 使用
{
const auto state = m_pc_state.load(memory_order_acquire);
assert(state != pc_state::consumer_set); //TODO:一定不是 consumer_set。为啥?
if (state == pc_state::idle) return result_status::idle;
return m_producer.status(); //idle、value、exception
}
template<class duration_unit, class ratio>
result_status wait_for(chrono::duration<duration_unit, ratio> duration)
{
const auto state_0 = m_pc_state.load(memory_order_acquire);
if (state_0 == pc_state::producer_done) return m_producer.status();
auto wait_ctx = make_shared<wait_context>();
m_consumer.set_wait_context(wait_ctx);
auto expected_idle_state = pc_state::idle; //空闲,则转换为 consumer_set,表示消费者已在等待
const auto idle_0 = m_pc_state.compare_exchange_strong(expected_idle_state, pc_state::consumer_set, memory_order_acq_rel);
if (!idle_0) { assert_done(); return m_producer.status(); }
const auto ms = chrono::duration_cast<chrono::milliseconds>(duration).count();
if (wait_ctx->wait_for(static_cast<size_t>(ms + 1))) { assert_done(); return m_producer.status(); } //没有超时就完成了
//超时了,状态应该是 consumer_set,将其设为 idle。若不成功,说明生产者设定了结果,直接返回;若成功就不再等了,清理消费者
auto expected_consumer_state = pc_state::consumer_set;
const auto idle_1 = m_pc_state.compare_exchange_strong(expected_consumer_state, pc_state::idle, memory_order_acq_rel);
if (!idle_1) { assert_done(); return m_producer.status(); }
m_consumer.clear(); //消费者将会重新开始等待。这里先清理掉
return result_status::idle;
}
template<class clock, class duration>
result_status wait_until(const chrono::time_point<clock, duration>& timeout_time) //主要用在时间队列里
{
const auto now = clock::now();
if (timeout_time <= now) return status();
const auto diff = timeout_time - now;
return wait_for(diff);
}
type get() { assert_done(); return m_producer.get(); } //result.get 调用之前,已经先 wait 过了,故必定已经有结果
void initialize_producer_from(producer_context<type>& producer_ctx) noexcept { producer_ctx = move(m_producer); }
template<class callable_type>
void from_callable(callable_type&& callable) noexcept //在 result_promise 的 set_from_function 中调用
{
using is_void = is_same<type, void>;
try {
from_callable(is_void{}, forward<callable_type>(callable)); //成功了,设置结果
} catch (...) {
set_exception(current_exception()); //有异常,保存好等待调用时抛出
}
}
//两个 complete 函数用于清理生产者/消费者的状态,一个可以等另一个,两个都完成时就销毁 this
//complete_producer 主要用在 result_publisher、result_coro_promise 里,生成结果后/删除前调用
void complete_producer(result_state_base* self /*for when_any*/, coroutine_handle<void> done_handle = {}) noexcept
{
m_done_handle = done_handle; //存好 handle,将来 complete_consumer 销毁要用到
const auto state_before = m_pc_state.exchange(pc_state::producer_done, memory_order_acq_rel);
assert(state_before != pc_state::producer_done); //producer_done 只在这里设定,不能二次重复
switch (state_before) {
case pc_state::consumer_set: m_consumer.resume_consumer(self); return; //消费者挂起在等,恢复后由它完成自删除
case pc_state::idle: return;
case pc_state::consumer_done: return delete_self(done_handle, this); //complete_consumer 中设定的,值取走了当然要销毁
default: break;
}
assert(false);
}
void complete_consumer() noexcept //主要用在 awaitable::get() 之后,取走结果值之后调用
{
const auto pc_state = m_pc_state.load(memory_order_acquire);
if (pc_state == pc_state::producer_done) return delete_self(m_done_handle, this); //取完值了,要么 producer_done,要么 idle
const auto pc_state1 = m_pc_state.exchange(pc_state::consumer_done, memory_order_acq_rel);
assert(pc_state1 != pc_state::consumer_set); //一定不是 consumer_set。要么流程没走完不会调用,要么上面已经自删除了
if (pc_state1 == pc_state::producer_done) return delete_self(m_done_handle, this);
assert(pc_state1 == pc_state::idle);
}
}; //class result_state
template<class type>
struct consumer_result_state_deleter //用于各类 result
{
void operator()(result_state<type>* state_ptr) { assert(state_ptr != nullptr); state_ptr->complete_consumer(); }
};
template<class type>
using consumer_result_state_ptr = unique_ptr<result_state<type>, consumer_result_state_deleter<type>>;
template<class type>
struct producer_result_state_deleter //用于各类 promise
{
void operator()(result_state<type>* state_ptr) { assert(state_ptr != nullptr); state_ptr->complete_producer(state_ptr); }
};
template<class type>
using producer_result_state_ptr = unique_ptr<result_state<type>, producer_result_state_deleter<type>>;
template<class callable_type>
auto&& bind(callable_type&& callable) { return forward<callable_type>(callable); } //无参,直接调用
template<class callable_type, class... argument_types>
auto bind(callable_type&& callable, argument_types&&... arguments)
{
constexpr static auto inti = is_nothrow_invocable_v<callable_type, argument_types...>;
return [callable = forward<callable_type>(callable),
tuple = make_tuple(forward<argument_types>(arguments)...)]() mutable noexcept(inti) -> decltype(auto) {
return apply(callable, tuple); //带参数调用,参数打包在 tuple 里
};
}
template<class callable_type>
auto&& bind_with_try_catch_impl(true_type, callable_type&& callable) //noexcept,无需捕获异常
{
return forward<callable_type>(callable);
}
template<class callable_type>
auto bind_with_try_catch_impl(false_type, callable_type&& callable)
{
return [callable = forward<callable_type>(callable)]() mutable noexcept { //异常不会溢出,所以是 noexcept
try { callable(); } catch (...) {} }; //捕获异常无需处理,防止溢出即可
}
template<class callable_type>
auto bind_with_try_catch(callable_type&& callable)
{
using is_noexcept = typename is_nothrow_invocable<callable_type>::type;
return bind_with_try_catch_impl(is_noexcept{}, forward<callable_type>(callable));
}
template<class callable_type, class... argument_types>
auto bind_with_try_catch(callable_type&& callable, argument_types&&... arguments)
{
return bind_with_try_catch(bind(forward<callable_type>(callable), forward<argument_types>(arguments)...));
}
class await_via_functor //主要用于 task,以及 resume_on_awaitable,以恢复协程运行
{
private:
await_context* m_ctx;
public:
await_via_functor(await_context* ctx) noexcept : m_ctx(ctx) {}
await_via_functor(await_via_functor&& rhs) noexcept : m_ctx(rhs.m_ctx) { rhs.m_ctx = nullptr; }
~await_via_functor() noexcept
{
if (m_ctx == nullptr) return; //若上下文非空,设定异常后恢复协程执行
m_ctx->set_interrupt(make_exception_ptr(errors::broken_task("result - Associated task was interrupted abnormally")));
m_ctx->resume();
}
void operator()() noexcept
{
assert(m_ctx != nullptr);
const auto await_context = exchange(m_ctx, nullptr); //m_ctx 置空,赋给临时变量
await_context->resume(); //关联协程恢复执行
}
}; //class await_via_functor
class coroutine_handle_functor //用于 task,直接恢复协程执行。await_via_functor 则多了一层包装
{
private:
coroutine_handle<void> m_coro_handle;
public:
coroutine_handle_functor() noexcept : m_coro_handle() {}
coroutine_handle_functor(const coroutine_handle_functor&) = delete;
coroutine_handle_functor& operator=(const coroutine_handle_functor&) = delete;
coroutine_handle_functor(coroutine_handle<void> coro_handle) noexcept : m_coro_handle(coro_handle) {}
coroutine_handle_functor(coroutine_handle_functor&& rhs) noexcept : m_coro_handle(exchange(rhs.m_coro_handle, {})) {}
~coroutine_handle_functor() noexcept { if (static_cast<bool>(m_coro_handle)) m_coro_handle.destroy(); }
void execute_destroy() noexcept { auto coro_handle = exchange(m_coro_handle, {}); coro_handle(); } //有活就干,干完就删
void operator()() noexcept { execute_destroy(); }
}; //class coroutine_handle_functor
template<class type>
class awaitable_base : public suspend_always //awaitable、resolve_awaitable 的基类
{
protected:
consumer_result_state_ptr<type> m_state; //释放会调用 complete_consumer
public:
awaitable_base(consumer_result_state_ptr<type> state) noexcept : m_state(move(state)) {}
awaitable_base(const awaitable_base&) = delete;
awaitable_base(awaitable_base&&) = delete; //移动构造也不允许
}; //class awaitable_base
[[noreturn]] void throw_runtime_shutdown_exception(string_view executor_name) //此函数不会返回
{
const auto error_msg = string(executor_name) + " - shutdown has been called on this executor.";
throw errors::runtime_shutdown(error_msg);
}
string make_executor_worker_name(string_view executor_name)
{
return string(executor_name) + " worker";
}
struct executor_bulk_tag {};
class when_result_helper;
} //namespace details
template<class type>
class awaitable : public details::awaitable_base<type> //result 的 co_await 操作符返回 awaitable,它定义了一系列函数
{
public:
awaitable(details::consumer_result_state_ptr<type> state) noexcept : details::awaitable_base<type>(move(state)) {}
//无 await_ready,视为返回 false,会接着调用 await_suspend;若返回 true 将跳过 await_suspend 直接 await_resume
//await_suspend 返回类型也可以是 void,此时视同返回 true,无条件挂起,交给 caller/resumer
bool await_suspend(coroutine_handle<void> caller_handle) noexcept
{
assert(static_cast<bool>(this->m_state));
return this->m_state->await(caller_handle); //句柄交给 state,返回真表示需要挂起(m_pc_state 在 idle,结果还没出来)
}
type await_resume() //协程恢复后即调用此函数,获取返回值
{
auto state = move(this->m_state); //函数返回后即销毁,期间调用 complete_consumer,因为取完值了
return state->get(); //result_state->get() 再调用 m_producer->get() 获取值
}
}; //class awaitable
template<class type>
class result; //for resolve_awaitable.await_resume
template<class type>
class resolve_awaitable : public details::awaitable_base<type> //用在 result::resolve 里
{
public:
resolve_awaitable(details::consumer_result_state_ptr<type> state) noexcept : details::awaitable_base<type>(move(state)) {}
resolve_awaitable(resolve_awaitable&&) noexcept = delete;
resolve_awaitable(const resolve_awaitable&) noexcept = delete;
bool await_suspend(coroutine_handle<void> caller_handle) noexcept
{
assert(static_cast<bool>(this->m_state));
return this->m_state->await(caller_handle);
}
result<type> await_resume() { return result<type>(move(this->m_state)); } //与 awaitable 相比,返回值多包了一个 result<>
}; //class resolve_awaitable
template<class type>
class result //将 result_state 再进行一个包装,使之能够自动调用 complete_consumer/producer
{
static constexpr auto valid_result_type_v = is_same_v<type, void> || is_nothrow_move_constructible_v<type>;
static_assert(valid_result_type_v, "result<type> - <<type>> should be no-throw-move constructable or void.");
friend class details::when_result_helper; //前面有一个前置声明,定义在后面
private:
details::consumer_result_state_ptr<type> m_state;
void throw_if_empty(const char* message) const { if (static_cast<bool>(!m_state)) throw errors::empty_result(message); }
public:
result() noexcept = default;
result(result&& rhs) noexcept = default;
result(details::consumer_result_state_ptr<type> state) noexcept : m_state(move(state)) {}
result(details::result_state<type>* state) noexcept : m_state(state) {} //直接从原生指针构造,带自定义释放器
result(const result& rhs) = delete;
result& operator=(const result& rhs) = delete;
result& operator=(result&& rhs) noexcept { if (this != &rhs) m_state = move(rhs.m_state); return *this; }
explicit operator bool() const noexcept { return static_cast<bool>(m_state); }
details::result_status status() const { throw_if_empty("result::status() - result is empty."); return m_state->status(); }
void wait() const { throw_if_empty("result::wait() - result is empty."); m_state->wait(); }
template<class duration_type, class ratio_type>
details::result_status wait_for(chrono::duration<duration_type, ratio_type> duration) const
{
throw_if_empty("result::wait_for() - result is empty.");
return m_state->wait_for(duration);
}
template<class clock_type, class duration_type>
details::result_status wait_until(chrono::time_point<clock_type, duration_type> timeout_time) const
{
throw_if_empty("result::wait_until() - result is empty.");
return m_state->wait_until(timeout_time);
}
//取值方式支持 wait、await 模式
type get()
{
throw_if_empty("result::get() - result is empty.");
auto state = move(m_state); //移动交给一个临时变量,返回即自动释放并调用 complete_consumer
state->wait(); //调用 wait 函数,无需等待就直接返回,需要就等到结果生成
return state->get(); //最后由 m_producer.get 来取得结果
}
auto operator co_await() //重载运算符,使用 awaitable<type>,否则自己要定义 await_* 系列函数
{
throw_if_empty("result::operator co_await() - result is empty.");
return awaitable<type>{ move(m_state) }; //会随后调用此类型的 await_* 系列函数
//如果结果已有,立即恢复;否则挂起等待结果生成,生成时会通知已经挂起的协程恢复,去取值或者抛出异常
}
auto resolve()
{
throw_if_empty("result::resolve() - result is empty.");
return resolve_awaitable<type>{ move(m_state) };
}
}; //class result
template<class type>
class lazy_result;
namespace details
{
struct lazy_final_awaiter : public suspend_always
{
template<class promise_type>
coroutine_handle<void> await_suspend(coroutine_handle<promise_type> handle) noexcept
{
return handle.promise().resume_caller(); //调用的应当是下面的函数,编译期会检查
}
};
class lazy_result_state_base //作为 promise 使用
{
protected:
coroutine_handle<void> m_caller_handle;
public:
coroutine_handle<void> resume_caller() const noexcept { return m_caller_handle; }
coroutine_handle<void> await(coroutine_handle<void> caller_handle) noexcept
{
m_caller_handle = caller_handle;
return coroutine_handle<lazy_result_state_base>::from_promise(*this);
}
};
template<class type>
class lazy_result_state : public lazy_result_state_base
{
private:
producer_context<type> m_producer;
public:
lazy_result<type> get_return_object() noexcept
{
const auto self_handle = coroutine_handle<lazy_result_state>::from_promise(*this);
return lazy_result<type>(self_handle);
}
void unhandled_exception() noexcept { m_producer.build_exception(current_exception()); }
suspend_always initial_suspend() const noexcept { return {}; } //lazy 总是挂起
lazy_final_awaiter final_suspend() const noexcept { return {}; } //重定义 await_suspend,返回协程句柄
template<class... argument_types>
void set_result(argument_types&&... arguments) noexcept(noexcept(type(forward<argument_types>(arguments)...)))
{
m_producer.build_result(forward<argument_types>(arguments)...);
}
result_status status() const noexcept { return m_producer.status(); }
type get() { return m_producer.get(); }
};
} //namespace details
template<class type>
class lazy_awaitable
{
private:
const coroutine_handle<details::lazy_result_state<type>> m_state;
public:
lazy_awaitable(coroutine_handle<details::lazy_result_state<type>> state) noexcept
: m_state(state)
{
assert(static_cast<bool>(state));
}
lazy_awaitable(const lazy_awaitable&) = delete;
lazy_awaitable(lazy_awaitable&&) = delete;
~lazy_awaitable() noexcept { auto state = m_state; state.destroy(); } //m_state 是 const,不能改
bool await_ready() const noexcept { return m_state.done(); }
coroutine_handle<void> await_suspend(coroutine_handle<void> caller_handle) noexcept
{
return m_state.promise().await(caller_handle); //调用上面的 await,返回协程句柄。有效即挂起
}
type await_resume() { return m_state.promise().get(); }
}; //class lazy_awaitable
template<class type>
class lazy_resolve_awaitable
{
private:
coroutine_handle<details::lazy_result_state<type>> m_state;
public:
lazy_resolve_awaitable(coroutine_handle<details::lazy_result_state<type>> state) noexcept
: m_state(state)
{
assert(static_cast<bool>(state));
}
lazy_resolve_awaitable(const lazy_resolve_awaitable&) = delete;
lazy_resolve_awaitable(lazy_resolve_awaitable&&) = delete;
~lazy_resolve_awaitable() noexcept { if (static_cast<bool>(m_state)) m_state.destroy(); }
bool await_ready() const noexcept { return m_state.done(); }
coroutine_handle<void> await_suspend(coroutine_handle<void> caller_handle) noexcept
{
return m_state.promise().await(caller_handle);
}
lazy_result<type> await_resume() { return { exchange(m_state, {}) }; } //m_state 非 const,可以交换
}; //class lazy_resolve_awaitable
template<class type>
class lazy_result //启动关联延迟任务,并将其值交给调用者
{
private:
coroutine_handle<details::lazy_result_state<type>> m_state;
void throw_if_empty(const char* err_msg) const
{
if (!static_cast<bool>(m_state)) throw errors::empty_result(err_msg);
}
result<type> run_impl() //TODO:咋变成 result<type> 的?
{
lazy_result self(move(*this)); //比直接用 co_return/await *this,销毁更快
co_return co_await self; //co_wait 返回 lazy_awaitable,随后 co_return 使用 lazy_result_state
}
public:
lazy_result() noexcept = default;
lazy_result(lazy_result&& rhs) noexcept : m_state(exchange(rhs.m_state, {})) {}
lazy_result(coroutine_handle<details::lazy_result_state<type>> state) noexcept : m_state(state) {}
~lazy_result() noexcept { if (static_cast<bool>(m_state)) m_state.destroy(); }
lazy_result& operator=(lazy_result&& rhs) noexcept
{
if (&rhs == this) return *this;
if (static_cast<bool>(m_state)) m_state.destroy();
m_state = exchange(rhs.m_state, {});
return *this;
}
explicit operator bool() const noexcept { return static_cast<bool>(m_state); }
details::result_status status() const { throw_if_empty("."); return m_state.promise().status(); } //归于生产者状态
auto operator co_await() { throw_if_empty("."); return lazy_awaitable<type>{ exchange(m_state, {}) }; }
auto resolve() { throw_if_empty("."); return lazy_resolve_awaitable<type>{ exchange(m_state, {}) }; }
result<type> run() { throw_if_empty("."); return run_impl(); } //内联运行关联任务(转换成迫切任务)
}; //class lazy_result
namespace details
{
//return_value_struct 其实是一个包装,底类型 type/void,再包装成 derived_type,分别定义了两个 return,以支持 co_return
template<class derived_type, class type>
struct return_value_struct
{
template<class return_type>
void return_value(return_type&& value)
{
auto self = static_cast<derived_type*>(this);
self->set_result(forward<return_type>(value)); //从这里调用 result_coro_promise->set_result
}
};
template<class derived_type>
struct return_value_struct<derived_type, void>
{
void return_void() noexcept
{
auto self = static_cast<derived_type*>(this);
self->set_result();
}
};
struct initialy_resumed_promise
{
suspend_never initial_suspend() const noexcept { return {}; } //除 lazy_result 外都是立即执行不挂起
};
//result_publisher 只用在 result_coro_promise 作为 final_suspend 的返回值
struct result_publisher : public suspend_always //挂起,caller/resumer 销毁生产者后再恢复
{
template<class promise_type> //带 promise_type 是因为要使用 handle.promise()
void await_suspend(coroutine_handle<promise_type> handle) const noexcept //void 视同 true
{
handle.promise().complete_producer(handle); //promise 类型是 result_coro_promise
}
};
template<class type> //result_coro_promise<type> 正好继承自 type,这种类型容易绕晕
struct result_coro_promise : public return_value_struct<result_coro_promise<type>, type>
{
private:
result_state<type> m_result_state;
public:
template<class... argument_types> //从 return_value/void 函数调用过来
void set_result(argument_types&&... arguments) noexcept(noexcept(type(forward<argument_types>(arguments)...)))
{
this->m_result_state.set_result(forward<argument_types>(arguments)...); //最后交给 m_producer.build_result
}
void unhandled_exception() noexcept { this->m_result_state.set_exception(current_exception()); }
result<type> get_return_object() noexcept { return { &m_result_state }; } //result<type> 都是从这里构建而得
//从 final_suspend 返回的 result_publisher.await_suspend 调用过来,确保生产者销毁完毕
void complete_producer(coroutine_handle<void> done_handle) noexcept
{
this->m_result_state.complete_producer(&m_result_state, done_handle);
}
result_publisher final_suspend() const noexcept { return {}; } //协程结束释放前是否要挂起
};
//initialy_resumed_promise 实现了 initial_suspend
//result_coro_promise 实现了 get_return_object、final_suspend、unhandled_exception
template<class return_type>
struct initialy_resumed_result_promise : public initialy_resumed_promise, public result_coro_promise<return_type> {};
//return_value_struct 基类实现了 return_value/return_void
//lazy_result_state 实现了 get_return_object、initial_suspend、final_suspend、unhandled_exception
template<class type>
struct lazy_promise : lazy_result_state<type>, public return_value_struct<lazy_promise<type>, type> {};
} //namespace details
template<class type>
class shared_result;
namespace details
{
struct shared_await_context
{
shared_await_context* next = nullptr;
coroutine_handle<void> caller_handle;
};
class shared_result_state_base
{
protected:
atomic_bool m_ready{ false };
mutable mutex m_lock;
shared_await_context* m_awaiters = nullptr; //多个等待者排成单链
optional<condition_variable> m_condition;
void await_impl(unique_lock<mutex>& lock, shared_await_context& awaiter) noexcept
{
assert(lock.owns_lock());
if (m_awaiters == nullptr) { m_awaiters = &awaiter; return; }
awaiter.next = m_awaiters;
m_awaiters = &awaiter;
}
void wait_impl(unique_lock<mutex>& lock) noexcept
{
assert(lock.owns_lock());
if (!m_condition.has_value()) m_condition.emplace();
m_condition.value().wait(lock, [this] { return m_ready.load(memory_order_relaxed); });
}
bool wait_for_impl(unique_lock<mutex>& lock, chrono::milliseconds ms) noexcept
{
assert(lock.owns_lock());
if (!m_condition.has_value()) m_condition.emplace();
return m_condition.value().wait_for(lock, ms, [this] { return m_ready.load(memory_order_relaxed); });
}
public:
void complete_producer() noexcept
{
shared_await_context* awaiters;
{
unique_lock<mutex> lock(m_lock);
awaiters = exchange(m_awaiters, nullptr); //先把异步等待序列提出来
m_ready.store(true, memory_order_release); //确保完成此前所有写操作
if (m_condition.has_value()) m_condition.value().notify_all(); //通知所有 wait
}
while (awaiters != nullptr) {
const auto next = awaiters->next;
awaiters->caller_handle(); //恢复所有 await
awaiters = next;
}
}
bool await(shared_await_context& awaiter) noexcept
{
if (m_ready.load(memory_order_acquire)) return false; //已经就绪,不用等了
{
unique_lock<mutex> lock(m_lock);
if (m_ready.load(memory_order_acquire)) return false;
await_impl(lock, awaiter);
}
return true;
}
void wait() noexcept
{
if (m_ready.load(memory_order_acquire)) return;
{
unique_lock<mutex> lock(m_lock);
if (m_ready.load(memory_order_acquire)) return;
wait_impl(lock);
}
}
}; //class shared_result_state_base
template<class type>
class shared_result_state final : public shared_result_state_base
{
private:
producer_context<type> m_producer;
void assert_done() const noexcept
{
assert(m_ready.load(memory_order_acquire));
assert(m_producer.status() != result_status::idle);
}
public:
result_status status() const noexcept
{
if (!m_ready.load(memory_order_acquire)) return result_status::idle;
return m_producer.status();
}
template<class duration_unit, class ratio>
result_status wait_for(chrono::duration<duration_unit, ratio> duration) noexcept
{
if (m_ready.load(memory_order_acquire)) return m_producer.status();
const auto ms = chrono::duration_cast<chrono::milliseconds>(duration) + chrono::milliseconds(1);
unique_lock<mutex> lock(m_lock);
if (m_ready.load(memory_order_acquire)) return m_producer.status();
const auto ready = wait_for_impl(lock, ms);
if (ready) { assert_done(); return m_producer.status(); } //结束等待,未超时
lock.unlock();
return result_status::idle; //超时了
}
template<class clock, class duration>
result_status wait_until(const chrono::time_point<clock, duration>& timeout_time) noexcept
{
const auto now = clock::now();
if (timeout_time <= now) return status();
const auto diff = timeout_time - now;
return wait_for(diff);
}
template<class... argument_types>
void set_result(argument_types&&... arguments) noexcept(noexcept(type(forward<argument_types>(arguments)...)))
{
m_producer.build_result(forward<argument_types>(arguments)...);
}
add_lvalue_reference_t<type> get() { return m_producer.get_ref(); } //共享结果,返回引用
void unhandled_exception() noexcept { m_producer.build_exception(current_exception()); }
}; //class shared_result_state
struct shared_result_publisher : public suspend_always
{
template<class promise_type>
bool await_suspend(coroutine_handle<promise_type> handle) const noexcept
{
handle.promise().complete_producer(); //shared_result_promise
return false; //共享结果不再挂起(非共享结果要挂起等待取出)
}
};
template<class type>
class shared_result_promise : public return_value_struct<shared_result_promise<type>, type>
{
private:
const shared_ptr<shared_result_state<type>> m_state = make_shared<shared_result_state<type>>();
public:
template<class... argument_types>
void set_result(argument_types&&... arguments) noexcept(noexcept(type(forward<argument_types>(arguments)...)))
{
m_state->set_result(forward<argument_types>(arguments)...);
}
void unhandled_exception() noexcept { m_state->unhandled_exception(); }
shared_result<type> get_return_object() noexcept { return shared_result<type> {m_state}; }
suspend_never initial_suspend() const noexcept { return {}; }
shared_result_publisher final_suspend() const noexcept { return {}; }
void complete_producer() noexcept { m_state->complete_producer(); }
};
template<class type>
class shared_awaitable_base : public suspend_always
{
protected:
shared_ptr<shared_result_state<type>> m_state;
public:
shared_awaitable_base(const shared_ptr<shared_result_state<type>>& state) noexcept : m_state(state) {}
shared_awaitable_base(const shared_awaitable_base&) = delete;
shared_awaitable_base(shared_awaitable_base&&) = delete;
};
struct shared_result_tag {};
} //namespace details
template<class type>
class shared_awaitable : public details::shared_awaitable_base<type>
{
private:
details::shared_await_context m_await_ctx;
public:
shared_awaitable(const shared_ptr<details::shared_result_state<type>>& state) noexcept
: details::shared_awaitable_base<type>(state) {}
bool await_suspend(coroutine_handle<void> caller_handle) noexcept
{
assert(static_cast<bool>(this->m_state));
this->m_await_ctx.caller_handle = caller_handle;
return this->m_state->await(m_await_ctx);
}
add_lvalue_reference_t<type> await_resume() { return this->m_state->get(); }
}; //class shared_awaitable
template<class type>
class shared_resolve_awaitable : public details::shared_awaitable_base<type>
{
private:
details::shared_await_context m_await_ctx;
public:
shared_resolve_awaitable(const shared_ptr<details::shared_result_state<type>>& state) noexcept
: details::shared_awaitable_base<type>(state) {}
bool await_suspend(coroutine_handle<void> caller_handle) noexcept
{
assert(static_cast<bool>(this->m_state));
this->m_await_ctx.caller_handle = caller_handle;
return this->m_state->await(m_await_ctx);
}
shared_result<type> await_resume() { return shared_result<type>(move(this->m_state)); }
}; //class shared_resolve_awaitable
template<class type>
class shared_result //允许多个消费者取同一个结果
{
private:
shared_ptr<details::shared_result_state<type>> m_state; //共享指针
static shared_result<type> make_shared_result(details::shared_result_tag, result<type> result)
{
co_return co_await result; //co_await 返回 type,co_return 从 shared_result_promise 将其转换成 shared_result
}
void throw_if_empty(const char* message) const
{
if (!static_cast<bool>(m_state)) throw errors::empty_result(message);
}
public:
shared_result() noexcept = default;
~shared_result() noexcept = default;
shared_result(shared_ptr<details::shared_result_state<type>> state) noexcept : m_state(move(state)) {}
shared_result(result<type> rhs)
{
if (!static_cast<bool>(rhs)) return;
*this = make_shared_result({}, move(rhs));
}
shared_result(const shared_result& rhs) noexcept = default;
shared_result(shared_result&& rhs) noexcept = default;
shared_result& operator=(const shared_result& rhs) noexcept
{
if (this != &rhs && m_state != rhs.m_state) m_state = rhs.m_state;
return *this;
}
shared_result& operator=(shared_result&& rhs) noexcept
{
if (this != &rhs && m_state != rhs.m_state) m_state = move(rhs.m_state);
return *this;
}
operator bool() const noexcept { return static_cast<bool>(m_state.get()); }
details::result_status status() const { throw_if_empty("."); return m_state->status(); }
void wait() { throw_if_empty("."); m_state->wait(); }
template<class duration_type, class ratio_type>
details::result_status wait_for(chrono::duration<duration_type, ratio_type> duration)
{
throw_if_empty(".");
return m_state->wait_for(duration);
}
template<class clock_type, class duration_type>
details::result_status wait_until(chrono::time_point<clock_type, duration_type> timeout_time)
{
throw_if_empty(".");
return m_state->wait_until(timeout_time);
}
add_lvalue_reference_t<type> get()
{
throw_if_empty(".");
m_state->wait();
return m_state->get();
}
auto operator co_await() { throw_if_empty("."); return shared_awaitable<type> {m_state}; }
auto resolve() { throw_if_empty("."); return shared_resolve_awaitable<type> {m_state}; }
}; //class shared_result
} //namespace concurrencpp
//返回类型指定 promise_type,才能在 executor 中使用 co_return
template<class type, class... arguments>
struct coroutine_traits<concurrencpp::result<type>, arguments...>
{
using promise_type = concurrencpp::details::initialy_resumed_result_promise<type>;
};
template<class type, class... arguments>
struct coroutine_traits<::concurrencpp::lazy_result<type>, arguments...>
{
using promise_type = concurrencpp::details::lazy_promise<type>;
};
template<class type>
struct coroutine_traits<::concurrencpp::shared_result<type>, concurrencpp::details::shared_result_tag, concurrencpp::result<type>>
{
using promise_type = concurrencpp::details::shared_result_promise<type>;
};
namespace concurrencpp
{
template<class type>
class result_promise //生产者的一个包装,典型应用场景是与第三方代码交互
{
static constexpr auto valid_result_type_v = is_same_v<type, void> || is_nothrow_move_constructible_v<type>;
static_assert(valid_result_type_v, "result_promise<type> - <<type>> should be now-throw-move constructable or void.");
private:
details::producer_result_state_ptr<type> m_state;
bool m_result_retrieved;
void throw_if_empty(const char* message) const { if (!static_cast<bool>(m_state)) throw errors::empty_result_promise(message); }
void break_task_if_needed() noexcept //已取结果却还没销毁,则抛出一个中断异常
{
if (!static_cast<bool>(m_state)) return; //还没出结果
if (!m_result_retrieved) return; //有结果但还没取。取出就应该销毁/置空 m_state
auto exception_ptr = make_exception_ptr(errors::broken_task("result - Associated task was interrupted abnormally"));
m_state->set_exception(exception_ptr);
m_state.reset(); //重置指针。原有指针通过 deleter 销毁,并调用 complete_producer
}
public:
result_promise() : m_state(new details::result_state<type>()), m_result_retrieved(false) {}
result_promise(result_promise&& rhs) noexcept : m_state(move(rhs.m_state)), m_result_retrieved(rhs.m_result_retrieved) {}
result_promise(const result_promise&) = delete;
result_promise& operator=(const result_promise&) = delete;
~result_promise() noexcept { break_task_if_needed(); }
result_promise& operator=(result_promise&& rhs) noexcept
{
if (this != &rhs) {
break_task_if_needed();
m_state = move(rhs.m_state);
m_result_retrieved = rhs.m_result_retrieved;
}
return *this;
}
explicit operator bool() const noexcept { return static_cast<bool>(m_state); }
template<class... argument_types>
void set_result(argument_types&&... arguments)
{
constexpr auto is_constructable = is_constructible_v<type, argument_types...> || is_same_v<void, type>;
static_assert(is_constructable, "result_promise::set_result() - <<type>> is not constructable from <<arguments...>>");
throw_if_empty("result_promise::set_result() - empty result_promise.");
m_state->set_result(forward<argument_types>(arguments)...);
m_state.reset(); //这里会调用 complete_producer
}
void set_exception(exception_ptr exception_ptr)
{
throw_if_empty("result_promise::set_exception() - empty result_promise.");
if (!static_cast<bool>(exception_ptr))
throw invalid_argument("result_promise::set_exception() - exception pointer is null.");
m_state->set_exception(exception_ptr);
m_state.reset();
}
template<class callable_type, class... argument_types>
void set_from_function(callable_type&& callable, argument_types&&... args) noexcept
{
constexpr auto is_invokable = is_invocable_r_v<type, callable_type, argument_types...>;
static_assert(is_invokable, "result_promise::set_from_function() - function(args...) is not invokable or its return type can't be used to construct <<type>>");
throw_if_empty("result_promise::set_from_function() - empty result_promise.");
m_state->from_callable(details::bind(forward<callable_type>(callable), forward<argument_types>(args)...));
m_state.reset();
}
result<type> get_result()
{
throw_if_empty("result::get() - result is empty.");
if (m_result_retrieved)
throw errors::result_already_retrieved("result_promise::get_result() - result was already retrieved.");
m_result_retrieved = true;
return result<type>(m_state.get());
}
}; //class result_promise
namespace details
{
struct vtable //用在 task 类中
{
void (*move_destroy_fn)(void* src, void* dst) noexcept;
void (*execute_destroy_fn)(void* target);
void (*destroy_fn)(void* target) noexcept;
vtable(const vtable&) noexcept = default;
constexpr vtable() noexcept : move_destroy_fn(nullptr), execute_destroy_fn(nullptr), destroy_fn(nullptr) {}
constexpr vtable(decltype(move_destroy_fn) move_destroy_fn, decltype(execute_destroy_fn) execute_destroy_fn,
decltype(destroy_fn) destroy_fn) noexcept : move_destroy_fn(move_destroy_fn),
execute_destroy_fn(execute_destroy_fn), destroy_fn(destroy_fn)
{
}
//普通可拷贝可析构:没有自定义移动销毁函数
static constexpr bool trivially_copiable_destructible(decltype(move_destroy_fn) move_fn) noexcept { return move_fn == nullptr; }
//普通可析构:没有自定义销毁函数
static constexpr bool trivially_destructable(decltype(destroy_fn) destroy_fn) noexcept { return destroy_fn == nullptr; }
}; //struct vtable
template<class callable_type>
class callable_vtable //用在 task 类中
{
private:
static callable_type* inline_ptr(void* src) noexcept { return static_cast<callable_type*>(src); } //单层转换
static callable_type* allocated_ptr(void* src) noexcept { return *static_cast<callable_type**>(src); } //双层脱壳
static callable_type*& allocated_ref_ptr(void* src) noexcept { return *static_cast<callable_type**>(src); } //双层脱壳引用
static void move_destroy_inline(void* src, void* dst) noexcept //单层指针
{
auto callable_ptr = inline_ptr(src);
new (dst) callable_type(move(*callable_ptr)); //根据指针所指对象移动构造
callable_ptr->~callable_type(); //再调用指针所指对象的析构函数
}
static void move_destroy_allocated(void* src, void* dst) noexcept //双层指针
{
auto callable_ptr = exchange(allocated_ref_ptr(src), nullptr); //src交换为空,callable_ptr指向实际目标
new (dst) callable_type* (callable_ptr); //直接填入实际目标指针
}
static void execute_destroy_inline(void* target)
{
auto callable_ptr = inline_ptr(target);
(*callable_ptr)();
callable_ptr->~callable_type(); //调用刚刚执行完的执行体析构函数
}
static void execute_destroy_allocated(void* target)
{
auto callable_ptr = allocated_ptr(target); //双层脱壳为单层指针
(*callable_ptr)();
delete callable_ptr; //第二层指针直接销毁释放
}
static void destroy_inline(void* target) noexcept
{
auto callable_ptr = inline_ptr(target);
callable_ptr->~callable_type();
}
static void destroy_allocated(void* target) noexcept
{
auto callable_ptr = allocated_ptr(target);
delete callable_ptr;
}
static constexpr vtable make_vtable() noexcept
{
void (*move_destroy_fn)(void* src, void* dst) noexcept = nullptr;
void (*destroy_fn)(void* target) noexcept = nullptr;
if constexpr (is_trivially_copy_constructible_v<callable_type> && is_trivially_destructible_v<callable_type>)
move_destroy_fn = nullptr;
else //有自定义的移动或析构函数
move_destroy_fn = move_destroy;
if constexpr (is_trivially_destructible_v<callable_type>)
destroy_fn = nullptr;
else //有自定义析构函数
destroy_fn = destroy;
return vtable(move_destroy_fn, execute_destroy, destroy_fn);
}
template<class passed_callable_type>
static void build_inlinable(void* dst, passed_callable_type&& callable)
{
new (dst) callable_type(forward<passed_callable_type>(callable)); //直接构造对象
}
template<class passed_callable_type>
static void build_allocated(void* dst, passed_callable_type&& callable)
{
auto new_ptr = new callable_type(forward<passed_callable_type>(callable));
new (dst) callable_type* (new_ptr); //先分配构造好对象,再分配其指针,将对象指针填入
}
public:
static constexpr bool is_inlinable() noexcept //1.移动构造无异常;2.双层指针内存占用不会越界
{
return is_nothrow_move_constructible_v<callable_type> && sizeof(callable_type) <= 64 - sizeof(void*);
}
template<class passed_callable_type>
static void build(void* dst, passed_callable_type&& callable)
{
if (is_inlinable()) return build_inlinable(dst, forward<passed_callable_type>(callable));
build_allocated(dst, forward<passed_callable_type>(callable));
}
static void move_destroy(void* src, void* dst) noexcept
{
assert(src != nullptr && dst != nullptr);
if (is_inlinable()) return move_destroy_inline(src, dst);
return move_destroy_allocated(src, dst);
}
static void execute_destroy(void* target)
{
assert(target != nullptr);
if (is_inlinable()) return execute_destroy_inline(target);
return execute_destroy_allocated(target);
}
static void destroy(void* target) noexcept
{
assert(target != nullptr);
if (is_inlinable()) return destroy_inline(target);
return destroy_allocated(target);
}
static constexpr callable_type* as(void* src) noexcept
{
if (is_inlinable()) return inline_ptr(src);
return allocated_ptr(src);
}
static constexpr inline vtable s_vtable = make_vtable();
}; //class callable_vtable
} //namespace details
class task
{
private:
alignas(max_align_t) std::byte m_buffer[64 - sizeof(void*)];
const details::vtable* m_vtable;
void build(task&& rhs) noexcept
{
m_vtable = exchange(rhs.m_vtable, nullptr); //先把 vtable 交换过来,再对 buffer 区别操作
if (m_vtable == nullptr) return;
//coroutine_handle_functor 和 await_via_functor 有自定义移动和析构函数,所以要使用 move_destroy
if (contains<details::coroutine_handle_functor>(m_vtable))
return details::callable_vtable<details::coroutine_handle_functor>::move_destroy(rhs.m_buffer, m_buffer);
if (contains<details::await_via_functor>(m_vtable))
return details::callable_vtable<details::await_via_functor>::move_destroy(rhs.m_buffer, m_buffer);
const auto move_destroy_fn = m_vtable->move_destroy_fn;
if (details::vtable::trivially_copiable_destructible(move_destroy_fn)) { //若没有定义 move_destroy_fn
memcpy(m_buffer, rhs.m_buffer, 64 - sizeof(void*)); //直接进行移动操作
return;
}
move_destroy_fn(rhs.m_buffer, m_buffer); //否则使用自定义函数进行移动
}
void build(coroutine_handle<void> coro_handle) noexcept
{
build(details::coroutine_handle_functor{ coro_handle }); //协程句柄包装成函数子,作为 callable_type
}
template<class callable_type>
void build(callable_type&& callable)
{
using decayed_type = typename std::decay_t<callable_type>;
details::callable_vtable<decayed_type>::build(m_buffer, forward<callable_type>(callable)); //buffer 对应的 dst
m_vtable = &details::callable_vtable<decayed_type>::s_vtable;
}
template<class callable_type>
static bool contains(const details::vtable* const vtable) noexcept
{
return vtable == &details::callable_vtable<callable_type>::s_vtable; //是否包含/实现了 callable_type
}
public:
task() noexcept : m_buffer(), m_vtable(nullptr) {}
task(task&& rhs) noexcept { build(move(rhs)); }
task(const task& rhs) = delete;
task& operator=(const task&& rhs) = delete;
template<class callable_type>
task(callable_type&& callable) { build(forward<callable_type>(callable)); }
~task() noexcept { clear(); }
void operator()()
{
const auto vtable = exchange(m_vtable, nullptr); //交换出来,执行后返回都会销毁
if (vtable == nullptr) return;
if (contains<details::coroutine_handle_functor>(vtable)) //这两种执行子,就调 execute_destroy 函数
return details::callable_vtable<details::coroutine_handle_functor>::execute_destroy(m_buffer);
if (contains<details::await_via_functor>(vtable))
return details::callable_vtable<details::await_via_functor>::execute_destroy(m_buffer);
vtable->execute_destroy_fn(m_buffer); //都不是,调用函数。这个函数不可能空
}
task& operator=(task&& rhs) noexcept
{
if (this == &rhs) return *this;
clear();
build(move(rhs));
return *this;
}
void clear() noexcept
{
if (m_vtable == nullptr) return;
const auto vtable = exchange(m_vtable, nullptr);
if (contains<details::coroutine_handle_functor>(vtable))
return details::callable_vtable<details::coroutine_handle_functor>::destroy(m_buffer);
if (contains<details::await_via_functor>(vtable))
return details::callable_vtable<details::await_via_functor>::destroy(m_buffer);
auto destroy_fn = vtable->destroy_fn;
if (details::vtable::trivially_destructable(destroy_fn)) return; //可能为空,是则直接返回
destroy_fn(m_buffer);
}
explicit operator bool() const noexcept { return m_vtable != nullptr; }
template<class callable_type>
bool contains() const noexcept
{
using decayed_type = typename std::decay_t<callable_type>;
if constexpr (is_same_v<decayed_type, coroutine_handle<void>>)
return contains<details::coroutine_handle_functor>();
return m_vtable == &details::callable_vtable<decayed_type>::s_vtable;
}
}; //class task
struct executor_tag {};
class executor
{
private:
template<class return_type, class executor_type, class callable_type, class... argument_types>
static result<return_type> submit_bridge(executor_tag, executor_type&, callable_type callable, argument_types... arguments)
{
co_return callable(arguments...);
}
template<class callable_type, typename return_type = invoke_result_t<callable_type>>
static result<return_type> bulk_submit_bridge(details::executor_bulk_tag, vector<task>& accumulator, callable_type callable)
{
co_return callable(); //没用到 accumulator
}
protected:
template<class executor_type, class callable_type, class... argument_types>
static void do_post(executor_type& executor_ref, callable_type&& callable, argument_types&&... arguments)
{
static_assert(is_invocable_v<callable_type, argument_types...>,
"executor::post - <<callable_type>> is not invokable with <<argument_types...>>");
executor_ref.enqueue(details::bind_with_try_catch(forward<callable_type>(callable), forward<argument_types>(arguments)...));
}
template<class executor_type, class callable_type, class... argument_types>
static auto do_submit(executor_type& executor_ref, callable_type&& callable, argument_types&&... arguments)
{
static_assert(is_invocable_v<callable_type, argument_types...>,
"executor::submit - <<callable_type>> is not invokable with <<argument_types...>>");
using return_type = typename std::invoke_result_t<callable_type, argument_types...>;
return submit_bridge<return_type>({}, executor_ref, forward<callable_type>(callable), forward<argument_types>(arguments)...);
}
template<class executor_type, class callable_type>
static void do_bulk_post(executor_type& executor_ref, span<callable_type> callable_list)
{
assert(!callable_list.empty());
vector<task> tasks;
tasks.reserve(callable_list.size());
for (auto& callable : callable_list) tasks.emplace_back(details::bind_with_try_catch(move(callable)));
span<task> span = tasks;
executor_ref.enqueue(span);
}
template<class executor_type, class callable_type, class return_type = invoke_result_t<callable_type>>
static vector<result<return_type>> do_bulk_submit(executor_type& executor_ref, span<callable_type> callable_list)
{
vector<task> accumulator;
accumulator.reserve(callable_list.size());
vector<result<return_type>> results;
results.reserve(callable_list.size());
for (auto& callable : callable_list) results.emplace_back(bulk_submit_bridge<callable_type>({}, accumulator, move(callable)));
assert(!accumulator.empty()); //TODO:上面的任务是怎么加到这里来的?bulk_submit_bridge 似乎没用到 accumulator
span<task> span = accumulator;
executor_ref.enqueue(span);
return results;
}
public:
executor(string_view name) : name(name) {}
virtual ~executor() noexcept = default;
const string name;
virtual void enqueue(task task) = 0;
virtual void enqueue(span<task> tasks) = 0;
virtual int max_concurrency_level() const noexcept = 0;
virtual bool shutdown_requested() const noexcept = 0;
virtual void shutdown() noexcept = 0;
template<class callable_type, class... argument_types>
void post(callable_type&& callable, argument_types&&... arguments) //不带返回值的任务,enqueue
{
return do_post(*this, forward<callable_type>(callable), forward<argument_types>(arguments)...);
}
template<class callable_type, class... argument_types>
auto submit(callable_type&& callable, argument_types&&... arguments) //带返回值的任务,submit_bridge
{
return do_submit(*this, forward<callable_type>(callable), forward<argument_types>(arguments)...);
}
template<class callable_type>
void bulk_post(span<callable_type> callable_list) { return do_bulk_post(*this, callable_list); }
template<class callable_type, class return_type = invoke_result_t<callable_type>>
vector<result<return_type>> bulk_submit(span<callable_type> callable_list) { return do_bulk_submit(*this, callable_list); }
}; //class executor
template<class sequence_type>
struct when_any_result
{
size_t index;
sequence_type results;
when_any_result() noexcept : index(static_cast<size_t>(-1)) {}
template<class... result_types>
when_any_result(size_t index, result_types&&... results) noexcept :
index(index), results(forward<result_types>(results)...)
{
}
when_any_result(when_any_result&&) noexcept = default;
when_any_result& operator=(when_any_result&&) noexcept = default;
};
namespace details
{
class when_result_helper
{
private:
template<class type>
static void throw_if_empty_single(const char* error_message, const result<type>& result)
{
if (!static_cast<bool>(result)) throw errors::empty_result(error_message);
}
static void throw_if_empty_impl(const char* error_message) noexcept { (void)error_message; }
template<class type, class... result_types>
static void throw_if_empty_impl(const char* error_message, const result<type>& result, result_types&&... results)
{
throw_if_empty_single(error_message, result); //模板递归,依次扔出异常
throw_if_empty_impl(error_message, forward<result_types>(results)...);
}
template<class type>
static result_state_base* get_state_base(result<type>& result) noexcept { return result.m_state.get(); }
template<size_t... is, typename tuple_type>
static result_state_base* at_impl(index_sequence<is...>, tuple_type& tuple, size_t n) noexcept
{
result_state_base* bases[] = { get_state_base(get<is>(tuple))... }; //构造 result_state_base 数组
return bases[n]; //返回其中第 n 项
}
public:
template<typename tuple_type>
static result_state_base* at(tuple_type& tuple, size_t n) noexcept //获取 tuple 中的第 n 项
{
auto seq = make_index_sequence<tuple_size<tuple_type>::value>();
return at_impl(seq, tuple, n);
}
template<class... result_types>
static void throw_if_empty_tuple(const char* error_message, result_types&&... results) //遍历查空,有则异常
{
throw_if_empty_impl(error_message, forward<result_types>(results)...);
}
template<class iterator_type>
static void throw_if_empty_range(const char* error_message, iterator_type begin, iterator_type end)
{
for (; begin != end; ++begin) throw_if_empty_single(error_message, *begin);
}
class when_all_awaitable
{
private:
result_state_base& m_state;
public:
when_all_awaitable(result_state_base& state) noexcept : m_state(state) {}
bool await_ready() const noexcept { return false; }
bool await_suspend(coroutine_handle<void> coro_handle) noexcept { return m_state.await(coro_handle); }
void await_resume() const noexcept {}
};
template<class result_types>
class when_any_awaitable
{
private:
shared_ptr<when_any_context> m_promise;
result_types& m_results;
template<class type> //这里几个 type 类型应该是 result<type>
static result_state_base* get_at(vector<type>& vector, size_t i) noexcept { return get_state_base(vector[i]); }
template<class type>
static size_t size(const vector<type>& vector) noexcept { return vector.size(); }
template<class... types>
static result_state_base* get_at(tuple<types...>& tuple, size_t i) noexcept { return at(tuple, i); }
template<class... types>
static size_t size(tuple<types...>& tuple) noexcept { return tuple_size_v<std::tuple<types...>>; }
public:
when_any_awaitable(result_types& results) noexcept : m_results(results) {}
bool await_ready() const noexcept { return false; }
void await_suspend(coroutine_handle<void> coro_handle) //返回 void,一定挂起
{
m_promise = make_shared<when_any_context>(coro_handle);
const auto range_length = size(m_results);
for (size_t i = 0; i < range_length; i++) {
if (m_promise->fulfilled()) return; //有一个达成了,就返回
auto state_ptr = get_at(m_results, i);
const auto status = state_ptr->when_any(m_promise); //result_state_base::when_any,会设置 consumer_set
if (status == result_state_base::pc_state::producer_done) {
m_promise->try_resume(state_ptr);
return; //完成生产,尝试恢复后返回
}
}
}
size_t await_resume() noexcept
{
#undef max
const auto completed_result_state = m_promise->completed_result();
auto completed_result_index = numeric_limits<size_t>::max();
const auto range_length = size(m_results);
for (size_t i = 0; i < range_length; i++) {
auto state_ptr = get_at(m_results, i);
state_ptr->try_rewind_consumer(); //尝试将所有 consumer_set 倒回 idle,注意下一行没有 break
if (completed_result_state == state_ptr) completed_result_index = i; //找到完成的
}
assert(completed_result_index != numeric_limits<size_t>::max());
return completed_result_index; //返回的是完成者的索引号
}
}; //class when_any_awaitable
}; //class when_result_helper
template<class... result_types>
result<tuple<typename decay<result_types>::type...>> when_all_impl(result_types&&... results)
{
tuple<typename decay<result_types>::type...> tuple = make_tuple(forward<result_types>(results)...);
for (size_t i = 0; i < tuple_size_v<decltype(tuple)>; i++) {
auto state_ptr = when_result_helper::at(tuple, i);
co_await when_result_helper::when_all_awaitable{ *state_ptr };
}
co_return move(tuple);
}
template<class iterator_type>
result<vector<typename iterator_traits<iterator_type>::value_type>> when_all_impl(iterator_type begin, iterator_type end)
{
using type = typename iterator_traits<iterator_type>::value_type;
if (begin == end)
co_return vector<type> {};
vector<type> vector{ make_move_iterator(begin), make_move_iterator(end) };
for (auto& result : vector)
result = co_await result.resolve();
co_return move(vector);
}
template<class... result_types>
result<when_any_result<tuple<result_types...>>> when_any_impl(result_types&&... results)
{
using tuple_type = tuple<result_types...>;
tuple_type tuple = make_tuple(forward<result_types>(results)...);
const auto completed_index = co_await when_result_helper::when_any_awaitable<tuple_type> {tuple};
co_return when_any_result<tuple_type> {completed_index, move(tuple)};
}
template<class iterator_type>
result<when_any_result<vector<typename iterator_traits<iterator_type>::value_type>>>
when_any_impl(iterator_type begin, iterator_type end)
{
using type = typename iterator_traits<iterator_type>::value_type;
vector<type> vector{ make_move_iterator(begin), make_move_iterator(end) };
const auto completed_index = co_await when_result_helper::when_any_awaitable{ vector };
co_return when_any_result<std::vector<type>> {completed_index, move(vector)};
}
}; //namespace details
template<class type, class... argument_types>
result<type> make_ready_result(argument_types&&... arguments)
{
static_assert(is_constructible_v<type, argument_types...> || is_same_v<type, void>,
"concurrencpp::make_ready_result - <<type>> is not constructible from <<argument_types...>");
static_assert(is_same_v<type, void> ? (sizeof...(argument_types) == 0) : true,
"concurrencpp::make_ready_result<void> - this overload does not accept any argument.");
details::producer_result_state_ptr<type> promise(new details::result_state<type>());
details::consumer_result_state_ptr<type> state_ptr(promise.get()); //保证生产者和消费者基于同一个 state
promise->set_result(forward<argument_types>(arguments)...);
promise.reset(); //TODO:跟踪验证一下
return { move(state_ptr) };
}
template<class type>
result<type> make_exceptional_result(exception_ptr exception_ptr)
{
if (!static_cast<bool>(exception_ptr))
throw invalid_argument("make_exception_result() - given exception_ptr is null.");
details::producer_result_state_ptr<type> promise(new details::result_state<type>());
details::consumer_result_state_ptr<type> state_ptr(promise.get());
promise->set_exception(exception_ptr);
promise.reset();
return { move(state_ptr) };
}
template<class type, class exception_type>
result<type> make_exceptional_result(exception_type exception)
{
return make_exceptional_result<type>(make_exception_ptr(exception));
}
template<class... result_types>
result<tuple<typename decay<result_types>::type...>> when_all(result_types&&... results)
{
details::when_result_helper::throw_if_empty_tuple("when_all() - one of the result objects is empty.",
forward<result_types>(results)...);
return details::when_all_impl(forward<result_types>(results)...);
}
template<class iterator_type>
result<vector<typename iterator_traits<iterator_type>::value_type>> when_all(iterator_type begin, iterator_type end)
{
details::when_result_helper::throw_if_empty_range("when_all() - one of the result objects is empty.", begin, end);
return details::when_all_impl(begin, end);
}
inline result<tuple<>> when_all() { return make_ready_result<tuple<>>(); }
template<class... result_types>
result<when_any_result<tuple<result_types...>>> when_any(result_types&&... results)
{
static_assert(sizeof...(result_types) != 0, "when_any() - must accept at least one result object.");
details::when_result_helper::throw_if_empty_tuple(".", forward<result_types>(results)...);
return details::when_any_impl(forward<result_types>(results)...);
}
template<class iterator_type>
result<when_any_result<vector<typename iterator_traits<iterator_type>::value_type>>>
when_any(iterator_type begin, iterator_type end)
{
details::when_result_helper::throw_if_empty_range(".", begin, end);
if (begin == end) throw invalid_argument("when_any() - given range contains no elements.");
return details::when_any_impl(begin, end);
}
namespace details
{
uintptr_t generate_thread_id() noexcept
{
static atomic_uintptr_t s_id_seed = 1;
//默认 memory_order_seq_cst,线程 ID 只要确保唯一即可,故用 relaxed 提高效率
return s_id_seed.fetch_add(1, memory_order_relaxed);
}
struct thread_per_thread_data
{
const uintptr_t id = generate_thread_id();
};
static thread_local thread_per_thread_data s_tl_thread_per_data; //每个线程独有的虚拟 ID 号
class thread
{
private:
std::thread m_thread;
static void set_name(string_view name) noexcept
{
const wstring utf16_name(name.begin(), name.end()); //注意转码用法
::SetThreadDescription(::GetCurrentThread(), utf16_name.data());
}
public:
thread() noexcept = default;
thread(thread&&) noexcept = default;
template<class callable_type>
thread(string name, callable_type&& callable)
{
m_thread = std::thread([name = move(name), callable = forward<callable_type>(callable)]() mutable {
set_name(name);
callable();
});
}
thread& operator=(thread&& rhs) noexcept = default;
std::thread::id get_id() const noexcept { return m_thread.get_id(); }
static uintptr_t get_current_virtual_id() noexcept { return s_tl_thread_per_data.id; }
bool joinable() const noexcept { return m_thread.joinable(); }
void join() { m_thread.join(); }
static size_t hardware_concurrency() noexcept
{
const auto hc = std::thread::hardware_concurrency();
return (hc != 0) ? hc : 8; //consts::k_default_number_of_cores:默认八核
}
};
size_t default_max_cpu_workers() noexcept
{
return static_cast<size_t>(thread::hardware_concurrency() * 1); //consts::k_cpu_threadpool_worker_count_factor
}
size_t default_max_background_workers() noexcept
{
return static_cast<size_t>(thread::hardware_concurrency() * 4); //consts::k_background_threadpool_worker_count_factor
}
} //namespace details
class inline_executor final : public executor //任务立即顺序执行,无并行
{
private:
atomic_bool m_abort;
void throw_if_aborted() const
{
if (m_abort.load(memory_order_relaxed)) details::throw_runtime_shutdown_exception(name);
}
public:
inline_executor() noexcept : executor("inline_executor"), m_abort(false) {}
void enqueue(task task) override { throw_if_aborted(); task(); }
void enqueue(span<task> tasks) override { throw_if_aborted(); for (auto& task : tasks) task(); }
int max_concurrency_level() const noexcept override { return 0; } //details::consts::k_inline_executor_max_concurrency_level
void shutdown() noexcept override { m_abort.store(true, memory_order_relaxed); }
bool shutdown_requested() const noexcept override { return m_abort.load(memory_order_relaxed); }
}; //class inline_executor
template<class concrete_executor_type>
class derivable_executor : public executor //使用 CRTP 技术提升性能,便于编译器优化。此类中用 self() 获得子类指针
{
private:
concrete_executor_type& self() noexcept { return *static_cast<concrete_executor_type*>(this); }
public:
derivable_executor(string_view name) : executor(name) {}
template<class callable_type, class... argument_types>
void post(callable_type&& callable, argument_types&&... arguments)
{
return do_post(self(), forward<callable_type>(callable), forward<argument_types>(arguments)...);
}
template<class callable_type, class... argument_types>
auto submit(callable_type&& callable, argument_types&&... arguments)
{
return do_submit(self(), forward<callable_type>(callable), forward<argument_types>(arguments)...);
}
template<class callable_type>
void bulk_post(span<callable_type> callable_list) { return do_bulk_post(self(), callable_list); }
template<class callable_type, class return_type = invoke_result_t<callable_type>>
vector<result<return_type>> bulk_submit(span<callable_type> callable_list)
{
return do_bulk_submit(self(), callable_list);
}
}; //class derivable_executor
class thread_pool_executor;
namespace details
{
class thread_pool_worker;
struct thread_pool_per_thread_data
{
thread_pool_worker* this_worker;
size_t this_thread_index;
const size_t this_thread_hashed_id;
thread_pool_per_thread_data() noexcept : this_worker(nullptr),
this_thread_index(static_cast<size_t>(-1)), this_thread_hashed_id(calculate_hashed_id())
{
}
static size_t calculate_hashed_id() noexcept
{
const auto this_thread_id = thread::get_current_virtual_id();
hash<size_t> hash;
return hash(this_thread_id);
}
};
static thread_local thread_pool_per_thread_data s_tl_thread_pool_data;
class idle_worker_set
{
enum class status { active, idle };
struct alignas(64) padded_flag { atomic<status> flag{ status::active }; }; //默认活动状态
private:
atomic_intptr_t m_approx_size; //空闲数量
const unique_ptr<padded_flag[]> m_idle_flags; //数组大小等于 size,在构造时确保
const size_t m_size;
bool try_acquire_flag(size_t index) noexcept
{
const auto worker_status = m_idle_flags[index].flag.load(memory_order_relaxed);
if (worker_status == status::active) return false;
const auto before = m_idle_flags[index].flag.exchange(status::active, memory_order_relaxed);
const auto swapped = (before == status::idle); //有可能在上一个 return 之后转为 active,所以还得判断一下
if (swapped) m_approx_size.fetch_sub(1, memory_order_relaxed);
return swapped;
}
public:
idle_worker_set(size_t size) : m_approx_size(0), m_idle_flags(make_unique<padded_flag[]>(size)), m_size(size) {}
void set_idle(size_t idle_thread) noexcept
{
const auto before = m_idle_flags[idle_thread].flag.exchange(status::idle, memory_order_relaxed);
if (before == status::idle) return;
m_approx_size.fetch_add(1, memory_order_release);
}
void set_active(size_t idle_thread) noexcept
{
const auto before = m_idle_flags[idle_thread].flag.exchange(status::active, memory_order_relaxed);
if (before == status::active) return;
m_approx_size.fetch_sub(1, memory_order_release);
}
size_t find_idle_worker(size_t caller_index) noexcept
{
if (m_approx_size.load(memory_order_relaxed) <= 0) return static_cast<size_t>(-1);
const auto starting_pos = (caller_index != static_cast<size_t>(-1)) ? caller_index
: (s_tl_thread_pool_data.this_thread_hashed_id % m_size); //根据哈希值计算一个特定的位置
for (size_t i = 0; i < m_size; i++) {
const auto index = (starting_pos + i) % m_size;
if (index == caller_index) continue;
if (try_acquire_flag(index)) return index; //找到一个空闲的就返回
}
return static_cast<size_t>(-1);
}
void find_idle_workers(size_t caller_index, vector<size_t>& result_buffer, size_t max_count) noexcept
{
assert(result_buffer.capacity() >= max_count); //后面使用的时候,已经确保了它一定是空的
const auto approx_size = m_approx_size.load(memory_order_relaxed);
if (approx_size <= 0) return;
assert(caller_index >= 0 && caller_index < m_size);
assert(caller_index == s_tl_thread_pool_data.this_thread_index); //确保调用者跟自己在同一个线程
size_t count = 0;
#undef min
const auto max_waiters = min(static_cast<size_t>(approx_size), max_count); //最多可用数
for (size_t i = 0; (i < m_size) && (count < max_waiters); i++) {
const auto index = (caller_index + i) % m_size;
if (index == caller_index) continue; //其实就排除第一个吧
if (try_acquire_flag(index)) { result_buffer.emplace_back(index); ++count; }
}
}
}; //class idle_worker_set
class executor_collection
{
private:
mutex m_lock;
vector<shared_ptr<executor>> m_executors;
public:
void register_executor(shared_ptr<executor> executor)
{
assert(static_cast<bool>(executor));
unique_lock<decltype(m_lock)> lock(m_lock); //类型已知,可以不用 decltype
assert(find(m_executors.begin(), m_executors.end(), executor) == m_executors.end());
m_executors.emplace_back(move(executor));
}
void shutdown_all() noexcept
{
unique_lock<decltype(m_lock)> lock(m_lock);
for (auto& executor : m_executors) { assert(static_cast<bool>(executor)); executor->shutdown(); }
m_executors = {}; //共享指针会自动销毁
}
}; //class executor_collection
class alignas(64) thread_pool_worker
{
private:
deque<task> m_private_queue; //本地任务队列
vector<size_t> m_idle_worker_list; //池中空闲工人索引列表
atomic_bool m_atomic_abort;
thread_pool_executor& m_parent_pool;
const size_t m_index; //当前工人在线程池中的序号
const size_t m_pool_size; //线程池大小
const chrono::milliseconds m_max_idle_time;
const string m_worker_name; //线程池执行者名称 + worker
alignas(64) mutex m_lock;
deque<task> m_public_queue; //外来任务队列
binary_semaphore m_semaphore; //二元信号量,0或1,指示有任务到来
bool m_idle; //默认 true,修改此状态变量需要加锁
bool m_abort; //修改也需要加锁
atomic_bool m_event_found; //任务已接收
thread m_thread;
void balance_work();
bool wait_for_task(unique_lock<mutex>& lock) noexcept;
bool drain_queue_impl()
{
auto aborted = false;
while (!m_private_queue.empty()) { //只要本地还有任务
balance_work(); //先做平衡
if (m_atomic_abort.load(memory_order_relaxed)) { aborted = true; break; }
assert(!m_private_queue.empty()); //本地至少还会留一项任务
auto&& task = move(m_private_queue.back()); //最后一项拿出来执行
m_private_queue.pop_back();
task();
}
if (aborted) {
unique_lock<mutex> lock(m_lock);
m_idle = true;
return false;
}
return true;
}
bool drain_queue() //等外来任务,放进私有队列后依次执行
{
unique_lock<mutex> lock(m_lock);
if (!wait_for_task(lock)) return false; //没等到任务,或者要退出了
assert(lock.owns_lock());
assert(!m_public_queue.empty() || m_abort);
m_event_found.store(false, memory_order_relaxed);
if (m_abort) { m_idle = true; return false; } //如果退出,置己为空闲
assert(m_private_queue.empty()); //私有任务一定是空的,要不然不会抽干队列
swap(m_private_queue, m_public_queue); //直接交换,效率最高
lock.unlock();
return drain_queue_impl();
}
void work_loop() noexcept //填充线程变量,启动执行抽干
{
s_tl_thread_pool_data.this_worker = this;
s_tl_thread_pool_data.this_thread_index = m_index;
while (true) { if (!drain_queue()) return; }
}
void ensure_worker_active(bool first_enqueuer, unique_lock<mutex>& lock)
{
assert(lock.owns_lock());
if (!m_idle) { //已在活动中,解锁,若首次入队则释放信号量,随后返回
lock.unlock();
if (first_enqueuer) m_semaphore.release();
return;
}
auto&& stale_worker = move(m_thread); //原有线程腾出来
m_thread = thread(m_worker_name, [this] { work_loop(); }); //启动新线程
m_idle = false;
lock.unlock();
if (stale_worker.joinable()) stale_worker.join(); //等待原有线程结束
}
public:
thread_pool_worker(thread_pool_executor& parent_pool, size_t index, size_t pool_size, chrono::milliseconds max_idle_time);
thread_pool_worker(thread_pool_worker&& rhs) noexcept :
m_parent_pool(rhs.m_parent_pool), m_index(rhs.m_index), m_pool_size(rhs.m_pool_size), m_max_idle_time(rhs.m_max_idle_time),
m_semaphore(0), m_idle(true), m_abort(true)
{
abort();
} //为什么不直接用删除声明?
~thread_pool_worker() noexcept { assert(m_idle); assert(!m_thread.joinable()); } //仅当空闲及线程非活动时
void enqueue_foreign(task& task);
void enqueue_foreign(span<task> tasks);
void enqueue_foreign(deque<task>::iterator begin, deque<task>::iterator end);
void enqueue_foreign(span<task>::iterator begin, span<task>::iterator end);
void enqueue_local(task& task);
void enqueue_local(span<task> tasks);
void shutdown() noexcept
{
assert(m_atomic_abort.load(memory_order_relaxed) == false);
m_atomic_abort.store(true, memory_order_relaxed);
{
unique_lock<mutex> lock(m_lock);
m_abort = true;
}
m_event_found.store(true, memory_order_release); //确保保存完毕后,再通知到工人
m_semaphore.release();
if (m_thread.joinable()) m_thread.join();
decltype(m_public_queue) public_queue; //临时导出的两个变量,以缩短因队列变化导致加锁的时间
decltype(m_private_queue) private_queue;
{
unique_lock<mutex> lock(m_lock); //只对 move 加锁,减少资源占用
public_queue = move(m_public_queue);
private_queue = move(m_private_queue);
}
public_queue.clear();
private_queue.clear();
}
chrono::milliseconds max_worker_idle_time() const noexcept { return m_max_idle_time; }
bool appears_empty() const noexcept
{
return m_private_queue.empty() && !m_event_found.load(memory_order_relaxed);
}
}; //class thread_pool_worker
} //namespace details
//适合非阻塞的短小任务,会自动进行线程注入和负载平衡。除默认类外,背景类适合那些小阻塞任务,如数据库和文件读写
class alignas(64) thread_pool_executor final : public derivable_executor<thread_pool_executor>
{
friend class details::thread_pool_worker;
private:
vector<details::thread_pool_worker> m_workers;
alignas(64) atomic_size_t m_round_robin_cursor; //轮询调度游标
alignas(64) details::idle_worker_set m_idle_workers;
alignas(64) atomic_bool m_abort;
void mark_worker_idle(size_t index) noexcept { assert(index < m_workers.size()); m_idle_workers.set_idle(index); }
void mark_worker_active(size_t index) noexcept { assert(index < m_workers.size()); m_idle_workers.set_active(index); }
details::thread_pool_worker& worker_at(size_t index) noexcept { assert(index <= m_workers.size()); return m_workers[index]; }
void find_idle_workers(size_t caller_index, vector<size_t>& buffer, size_t max_count) noexcept
{
m_idle_workers.find_idle_workers(caller_index, buffer, max_count); //找到的索引放到 buffer
}
public:
thread_pool_executor(string_view pool_name, size_t pool_size, chrono::milliseconds max_idle_time) :
derivable_executor<thread_pool_executor>(pool_name), m_round_robin_cursor(0), m_idle_workers(pool_size), m_abort(false)
{
m_workers.reserve(pool_size); //按线程池大小构建工人,分配好序号,再全部置为空闲
for (size_t i = 0; i < pool_size; i++) m_workers.emplace_back(*this, i, pool_size, max_idle_time);
for (size_t i = 0; i < pool_size; i++) m_idle_workers.set_idle(i);
}
void enqueue(task task) override
{
const auto this_worker = details::s_tl_thread_pool_data.this_worker; //工人进入工作循环后,这两个被赋值
const auto this_worker_index = details::s_tl_thread_pool_data.this_thread_index;
if (this_worker != nullptr && this_worker->appears_empty()) //本线程有工人且闲置,交给它
return this_worker->enqueue_local(task);
const auto idle_worker_pos = m_idle_workers.find_idle_worker(this_worker_index); //否则找一个闲置工人
if (idle_worker_pos != static_cast<size_t>(-1))
return m_workers[idle_worker_pos].enqueue_foreign(task); //交给找到的闲置工人
if (this_worker != nullptr) //所有工人都在忙,但本线程有工人,仍然交给它
return this_worker->enqueue_local(task);
const auto next_worker = m_round_robin_cursor.fetch_add(1, memory_order_relaxed) % m_workers.size();
m_workers[next_worker].enqueue_foreign(task); //本线程无工人,轮询调度下一个工人,交给它
}
void enqueue(span<task> tasks) override
{
if (details::s_tl_thread_pool_data.this_worker != nullptr) //本线程有工人,全交给它入队
return details::s_tl_thread_pool_data.this_worker->enqueue_local(tasks);
//任务数量少于工人数量,就去找闲置工人,或者轮询入队
if (tasks.size() < m_workers.size()) { for (auto& task : tasks) enqueue(move(task)); return; }
const auto task_count = tasks.size();
const auto total_worker_count = m_workers.size();
const auto donation_count = task_count / total_worker_count; //每个工人至少几项任务
auto extra = task_count - donation_count * total_worker_count; //没有平均分完的还有几项
size_t begin = 0;
size_t end = donation_count;
for (size_t i = 0; i < total_worker_count; i++) {
assert(begin < task_count);
if (extra != 0) { end++; extra--; } //还有多余任务,当前工人多分派一项
assert(end <= task_count);
auto tasks_begin_it = tasks.begin() + begin;
auto tasks_end_it = tasks.begin() + end;
assert(tasks_begin_it < tasks.end());
assert(tasks_end_it <= tasks.end());
m_workers[i].enqueue_foreign(tasks_begin_it, tasks_end_it); //划好分片交给当前工人
begin = end;
end += donation_count;
}
}
bool shutdown_requested() const noexcept override { return m_abort.load(memory_order_relaxed); }
void shutdown() noexcept override
{
const auto abort = m_abort.exchange(true, memory_order_relaxed);
if (abort) return; //已经关闭过了,直接返回。否则遍历所有工人逐一关闭
for (auto& worker : m_workers) worker.shutdown();
}
int max_concurrency_level() const noexcept override { return static_cast<int>(m_workers.size()); }
chrono::milliseconds max_worker_idle_time() const noexcept { return m_workers[0].max_worker_idle_time(); }
}; //class thread_pool_executor
namespace details
{
thread_pool_worker::thread_pool_worker(thread_pool_executor& parent_pool, size_t index, size_t pool_size, chrono::milliseconds max_idle_time) :
m_atomic_abort(false), m_parent_pool(parent_pool), m_index(index), m_pool_size(pool_size), m_max_idle_time(max_idle_time),
m_worker_name(details::make_executor_worker_name(parent_pool.name)), m_semaphore(0), m_idle(true), m_abort(false), m_event_found(false)
{
m_idle_worker_list.reserve(pool_size);
}
void thread_pool_worker::balance_work()
{
const auto task_count = m_private_queue.size();
if (task_count < 2) return; //最多就一项任务,没必要平衡
const auto max_idle_worker_count = min(m_pool_size - 1, task_count - 1); //假定其他线程都空闲,给自己也留一项任务
if (max_idle_worker_count == 0) return; //单线程
m_parent_pool.find_idle_workers(m_index, m_idle_worker_list, max_idle_worker_count); //空闲列表只在此函数内使用
const auto idle_count = m_idle_worker_list.size();
if (idle_count == 0) return;
assert(idle_count <= task_count); //TODO:这个断言怎么来的?
const auto total_worker_count = (idle_count + 1); //工人总数,包括自己。不然所有任务都会捐出去
const auto donation_count = task_count / total_worker_count; //每个工人平均任务数
auto extra = task_count - donation_count * total_worker_count; //剩下的余数
size_t begin = 0;
size_t end = donation_count;
for (const auto idle_worker_index : m_idle_worker_list) { //这里跟 executor 一样的分配策略
assert(idle_worker_index != m_index);
assert(idle_worker_index < m_pool_size);
assert(begin < task_count);
if (extra != 0) { end++; extra--; }
assert(end <= task_count);
auto donation_begin_it = m_private_queue.begin() + begin;
auto donation_end_it = m_private_queue.begin() + end;
assert(donation_begin_it < m_private_queue.end());
assert(donation_end_it <= m_private_queue.end());
m_parent_pool.worker_at(idle_worker_index).enqueue_foreign(donation_begin_it, donation_end_it);
begin = end;
end += donation_count;
}
assert(m_private_queue.size() == task_count); //TODO:此断言的必要性是?
//一直到 +begin 是已经捐出去的部分,从私有队列中删除。剩下的是留给自己的任务
assert(all_of(m_private_queue.begin(), m_private_queue.begin() + begin, [](auto& task) { return !static_cast<bool>(task); }));
assert(all_of(m_private_queue.begin() + begin, m_private_queue.end(), [](auto& task) { return static_cast<bool>(task); }));
m_private_queue.erase(m_private_queue.begin(), m_private_queue.begin() + begin);
assert(!m_private_queue.empty());
m_idle_worker_list.clear(); //清空空闲列表,下次查找时一定是空的
}
bool thread_pool_worker::wait_for_task(unique_lock<mutex>& lock) noexcept
{
assert(lock.owns_lock()); //此函数只在抽干队列时调用,预先已加锁
if (!m_public_queue.empty() || m_abort) return true; //有外来任务或者中断了
lock.unlock(); //先解锁,外来任务队列和退出变量暂时不用了
m_parent_pool.mark_worker_idle(m_index);
auto event_found = false;
const auto deadline = chrono::steady_clock::now() + m_max_idle_time;
while (true) {
if (!m_semaphore.try_acquire_until(deadline)) { //没等到 release,时间没到就继续等,否则结束等待
if (chrono::steady_clock::now() <= deadline) continue;
else break;
}
if (!m_event_found.load(memory_order_relaxed)) continue; //没发现事件,继续等
lock.lock(); //要访问外来队列和退出标志了,加锁
if (m_public_queue.empty() && !m_abort) { lock.unlock(); continue; } //其他任务是空的,也没中断,就继续等
event_found = true;
break;
}
if (!lock.owns_lock()) lock.lock();
if (!event_found || m_abort) { m_idle = true; lock.unlock(); return false; } //没发现事件,或者中断了
assert(!m_public_queue.empty());
m_parent_pool.mark_worker_active(m_index);
return true;
}
void thread_pool_worker::enqueue_foreign(task& task)
{
unique_lock<mutex> lock(m_lock);
if (m_abort) throw_runtime_shutdown_exception(m_parent_pool.name);
m_event_found.store(true, memory_order_relaxed); //其他工人派任务来了
const auto is_empty = m_public_queue.empty();
m_public_queue.emplace_back(move(task)); //任务压到公用队列
ensure_worker_active(is_empty, lock); //首参数为是否队列中第一项任务
}
void thread_pool_worker::enqueue_foreign(span<task> tasks)
{
unique_lock<mutex> lock(m_lock);
if (m_abort) throw_runtime_shutdown_exception(m_parent_pool.name);
m_event_found.store(true, memory_order_relaxed);
const auto is_empty = m_public_queue.empty();
m_public_queue.insert(m_public_queue.end(), make_move_iterator(tasks.begin()), make_move_iterator(tasks.end()));
ensure_worker_active(is_empty, lock); //注意上面的 make_move_iterator,高效地完成批量移动式插入
}
void thread_pool_worker::enqueue_foreign(deque<task>::iterator begin, deque<task>::iterator end)
{
unique_lock<mutex> lock(m_lock);
if (m_abort) throw_runtime_shutdown_exception(m_parent_pool.name);
m_event_found.store(true, memory_order_relaxed);
const auto is_empty = m_public_queue.empty();
m_public_queue.insert(m_public_queue.end(), make_move_iterator(begin), make_move_iterator(end));
ensure_worker_active(is_empty, lock);
}
void thread_pool_worker::enqueue_foreign(span<task>::iterator begin, span<task>::iterator end)
{
unique_lock<mutex> lock(m_lock);
if (m_abort) throw_runtime_shutdown_exception(m_parent_pool.name);
m_event_found.store(true, memory_order_relaxed);
const auto is_empty = m_public_queue.empty();
m_public_queue.insert(m_public_queue.end(), make_move_iterator(begin), make_move_iterator(end));
ensure_worker_active(is_empty, lock);
}
void thread_pool_worker::enqueue_local(task& task) //自己给自己派的任务
{
if (m_atomic_abort.load(memory_order_relaxed)) throw_runtime_shutdown_exception(m_parent_pool.name);
m_private_queue.emplace_back(move(task));
}
void thread_pool_worker::enqueue_local(span<task> tasks)
{
if (m_atomic_abort.load(memory_order_relaxed)) throw_runtime_shutdown_exception(m_parent_pool.name);
m_private_queue.insert(m_private_queue.end(), make_move_iterator(tasks.begin()), make_move_iterator(tasks.end()));
}
} //namepsace details
//每个任务开一个线程,线程不可重用。适合长期运行的任务,比如工作循环,或者长时间阻塞任务
class alignas(64) thread_executor final : public derivable_executor<thread_executor>
{
private:
mutex m_lock;
list<details::thread> m_workers;
condition_variable m_condition;
list<details::thread> m_last_retired;
bool m_abort;
atomic_bool m_atomic_abort;
void enqueue_impl(unique_lock<mutex>& lock, task& task)
{
assert(lock.owns_lock()); //enqueue 一定会先加锁
auto& new_thread = m_workers.emplace_front(); //先从前面入列(可通过 begin 访问),下面再赋值启动
new_thread = details::thread(details::make_executor_worker_name(name),
[this, self_it = m_workers.begin(), task = move(task)]() mutable { task(); retire_worker(self_it); });
}
void retire_worker(list<details::thread>::iterator it) //it 指向的就是新入列的任务
{
unique_lock<decltype(m_lock)> lock(m_lock);
auto&& last_retired = move(m_last_retired); //全部移到临时列表,应该也至多只有一个
m_last_retired.splice(m_last_retired.begin(), m_workers, it); //新来的任务放进队列
lock.unlock();
m_condition.notify_one(); //workers 少了一个,所以要通知一次
if (last_retired.empty()) return;
assert(last_retired.size() == 1);
last_retired.front().join(); //原来的任务 join
}
public:
thread_executor() : derivable_executor<thread_executor>("thread_executor"), m_abort(false), m_atomic_abort(false) {}
~thread_executor() noexcept { assert(m_workers.empty()); assert(m_last_retired.empty()); }
void enqueue(task task) override
{
unique_lock<decltype(m_lock)> lock(m_lock);
if (m_abort) details::throw_runtime_shutdown_exception(name);
enqueue_impl(lock, task);
}
void enqueue(span<task> tasks) override
{
unique_lock<decltype(m_lock)> lock(m_lock);
if (m_abort) details::throw_runtime_shutdown_exception(name);
for (auto& task : tasks) enqueue_impl(lock, task); //加一个任务,就开启一个线程
}
int max_concurrency_level() const noexcept override
{
#undef max
return numeric_limits<int>::max(); //details::consts::k_thread_executor_max_concurrency_level
}
bool shutdown_requested() const noexcept override { return m_atomic_abort.load(memory_order_relaxed); }
void shutdown() noexcept override
{
const auto abort = m_atomic_abort.exchange(true, memory_order_relaxed);
if (abort) return; //调用过 shutdown 了
unique_lock<decltype(m_lock)> lock(m_lock);
m_abort = true;
m_condition.wait(lock, [this] { return m_workers.empty(); }); //如果任务队列还没结束,就等待
if (m_last_retired.empty()) return;
assert(m_last_retired.size() == 1); //最后一项当前任务 join 后就全部关闭
m_last_retired.front().join();
m_last_retired.clear();
}
}; //class thread_executor
class worker_thread_executor;
namespace details { static thread_local worker_thread_executor* s_tl_this_worker = nullptr; }
//单任务单线程,适合执行多项关联任务的复杂线程
class alignas(64) worker_thread_executor final : public derivable_executor<worker_thread_executor>
{
private:
deque<task> m_private_queue;
atomic_bool m_private_atomic_abort;
details::thread m_thread;
alignas(64) mutex m_lock;
deque<task> m_public_queue;
binary_semaphore m_semaphore;
atomic_bool m_atomic_abort;
bool m_abort;
bool drain_queue_impl()
{
while (!m_private_queue.empty()) { //只要私有队列中还有任务
auto&& task = move(m_private_queue.front()); //从队列中抽出来执行
m_private_queue.pop_front();
if (m_private_atomic_abort.load(memory_order_relaxed)) return false;
task();
}
return true;
}
bool drain_queue()
{
unique_lock<decltype(m_lock)> lock(m_lock);
wait_for_task(lock);
assert(lock.owns_lock());
assert(!m_public_queue.empty() || m_abort); //要么外部有任务,要么取消了
if (m_abort) return false;
assert(m_private_queue.empty());
swap(m_private_queue, m_public_queue); //直接交换,效率最高
lock.unlock();
return drain_queue_impl();
}
void wait_for_task(unique_lock<mutex>& lock)
{
assert(lock.owns_lock());
if (!m_public_queue.empty() || m_abort) return; //有外部任务,或者取消
while (true) {
lock.unlock();
m_semaphore.acquire();
lock.lock();
if (!m_public_queue.empty() || m_abort) break;
}
}
void work_loop() noexcept
{
details::s_tl_this_worker = this;
while (true) if (!drain_queue()) return; //不断地等待任务并执行,直到 abort
}
void enqueue_local(task& task)
{
if (m_private_atomic_abort.load(memory_order_relaxed)) details::throw_runtime_shutdown_exception(name);
m_private_queue.emplace_back(move(task));
}
void enqueue_local(span<task> tasks)
{
if (m_private_atomic_abort.load(memory_order_relaxed)) details::throw_runtime_shutdown_exception(name);
m_private_queue.insert(m_private_queue.end(), make_move_iterator(tasks.begin()), make_move_iterator(tasks.end()));
}
void enqueue_foreign(task& task)
{
unique_lock<decltype(m_lock)> lock(m_lock);
if (m_abort) details::throw_runtime_shutdown_exception(name);
const auto is_empty = m_public_queue.empty();
m_public_queue.emplace_back(move(task));
lock.unlock();
if (is_empty) m_semaphore.release(); //原本是空的,就要发一个信号,唤醒等待线程
}
void enqueue_foreign(span<task> tasks)
{
unique_lock<decltype(m_lock)> lock(m_lock);
if (m_abort) details::throw_runtime_shutdown_exception(name);
const auto is_empty = m_public_queue.empty();
m_public_queue.insert(m_public_queue.end(), make_move_iterator(tasks.begin()), make_move_iterator(tasks.end()));
lock.unlock();
if (is_empty) m_semaphore.release();
}
public:
worker_thread_executor() : derivable_executor<worker_thread_executor>("worker_thread_executor"),
m_private_atomic_abort(false), m_abort(false), m_semaphore(0), m_atomic_abort(false)
{
m_thread = details::thread(details::make_executor_worker_name(name), [this] { work_loop(); });
}
void enqueue(task task) override
{
if (details::s_tl_this_worker == this) return enqueue_local(task);
enqueue_foreign(task);
}
void enqueue(span<task> tasks) override
{
if (details::s_tl_this_worker == this) return enqueue_local(tasks);
enqueue_foreign(tasks);
}
int max_concurrency_level() const noexcept override { return 1; } //details::consts::k_worker_thread_max_concurrency_level
bool shutdown_requested() const noexcept override { return m_atomic_abort.load(memory_order_relaxed); }
void shutdown() noexcept override
{
const auto abort = m_atomic_abort.exchange(true, memory_order_relaxed);
if (abort) return;
{
unique_lock<mutex> lock(m_lock);
m_abort = true;
}
m_semaphore.release();
if (m_thread.joinable()) m_thread.join();
decltype(m_private_queue) private_queue;
decltype(m_public_queue) public_queue;
{
unique_lock<mutex> lock(m_lock);
private_queue = move(m_private_queue);
public_queue = move(m_public_queue);
}
private_queue.clear();
public_queue.clear();
}
}; //class worker_thread_executor
//自己不执行任务,由调用者从外部启动任务
class alignas(64) manual_executor final : public derivable_executor<manual_executor>
{
private:
mutable mutex m_lock;
deque<task> m_tasks;
condition_variable m_condition;
bool m_abort;
atomic_bool m_atomic_abort;
template<class clock_type, class duration_type>
static chrono::system_clock::time_point to_system_time_point(chrono::time_point<clock_type, duration_type> time_point)
{
const auto src_now = clock_type::now();
const auto dst_now = chrono::system_clock::now();
return dst_now + chrono::duration_cast<chrono::milliseconds>(time_point - src_now);
}
static chrono::system_clock::time_point time_point_from_now(chrono::milliseconds ms)
{
return chrono::system_clock::now() + ms;
}
size_t loop_impl(size_t max_count)
{
if (max_count == 0) return 0;
size_t executed = 0;
while (true) {
if (executed == max_count) break;
unique_lock<decltype(m_lock)> lock(m_lock);
if (m_abort) break; //中断循环之后,锁自动解开
if (m_tasks.empty()) break;
auto&& task = move(m_tasks.front());
m_tasks.pop_front();
lock.unlock();
task();
++executed;
}
if (shutdown_requested()) details::throw_runtime_shutdown_exception(name);
return executed;
}
size_t loop_until_impl(size_t max_count, chrono::time_point<chrono::system_clock> deadline)
{
if (max_count == 0) return 0;
size_t executed = 0;
deadline += chrono::milliseconds(1);
while (true) {
if (executed == max_count) break;
const auto now = chrono::system_clock::now();
if (now >= deadline) break; //总数到了,或者终止时间到了,都会中断
unique_lock<decltype(m_lock)> lock(m_lock);
const auto found_task = m_condition.wait_until(lock, deadline, [this] { return !m_tasks.empty() || m_abort; });
if (m_abort) break;
if (!found_task) break;
assert(!m_tasks.empty()); //找到了,就一定不空
auto&& task = move(m_tasks.front());
m_tasks.pop_front();
lock.unlock();
task();
++executed;
}
if (shutdown_requested()) details::throw_runtime_shutdown_exception(name);
return executed;
}
void wait_for_tasks_impl(size_t count) //等够多项任务
{
if (count == 0) {
if (shutdown_requested()) details::throw_runtime_shutdown_exception(name);
return;
}
unique_lock<decltype(m_lock)> lock(m_lock);
m_condition.wait(lock, [this, count] { return (m_tasks.size() >= count) || m_abort; });
if (m_abort) details::throw_runtime_shutdown_exception(name);
assert(m_tasks.size() >= count);
}
size_t wait_for_tasks_impl(size_t count, chrono::time_point<chrono::system_clock> deadline)
{
deadline += chrono::milliseconds(1);
unique_lock<decltype(m_lock)> lock(m_lock);
m_condition.wait_until(lock, deadline, [this, count] { return (m_tasks.size() >= count) || m_abort; });
if (m_abort) details::throw_runtime_shutdown_exception(name);
return m_tasks.size();
}
public:
manual_executor() : derivable_executor<manual_executor>("manual_executor"), m_abort(false), m_atomic_abort(false) {}
void enqueue(task task) override
{
unique_lock<decltype(m_lock)> lock(m_lock);
if (m_abort) details::throw_runtime_shutdown_exception(name);
m_tasks.emplace_back(move(task));
lock.unlock();
m_condition.notify_all(); //有任务入队,唤醒等待线程
}
void enqueue(span<task> tasks) override
{
unique_lock<decltype(m_lock)> lock(m_lock);
if (m_abort) details::throw_runtime_shutdown_exception(name);
m_tasks.insert(m_tasks.end(), make_move_iterator(tasks.begin()), make_move_iterator(tasks.end()));
lock.unlock();
m_condition.notify_all();
}
int max_concurrency_level() const noexcept override { return numeric_limits<int>::max(); } //k_manual_executor_max_concurrency_level
bool shutdown_requested() const noexcept override { return m_atomic_abort.load(memory_order_relaxed); }
void shutdown() noexcept override
{
const auto abort = m_atomic_abort.exchange(true, memory_order_relaxed);
if (abort) return;
decltype(m_tasks) tasks;
{
unique_lock<decltype(m_lock)> lock(m_lock);
m_abort = true;
tasks = move(m_tasks);
}
m_condition.notify_all();
tasks.clear();
}
size_t size() const noexcept
{
unique_lock<decltype(m_lock)> lock(m_lock);
return m_tasks.size();
}
bool empty() const noexcept { return size() == 0; }
size_t clear()
{
unique_lock<decltype(m_lock)> lock(m_lock);
if (m_abort) details::throw_runtime_shutdown_exception(name);
const auto&& tasks = move(m_tasks);
lock.unlock();
return tasks.size();
}
bool loop_once() { return loop_impl(1) != 0; }
bool loop_once_for(chrono::milliseconds max_waiting_time) //执行结束或超时后返回
{
if (max_waiting_time == chrono::milliseconds(0)) return loop_impl(1) != 0;
return loop_until_impl(1, time_point_from_now(max_waiting_time));
}
template<class clock_type, class duration_type>
bool loop_once_until(chrono::time_point<clock_type, duration_type> timeout_time)
{
return loop_until_impl(1, to_system_time_point(timeout_time));
}
size_t loop(size_t max_count) { return loop_impl(max_count); }
size_t loop_for(size_t max_count, chrono::milliseconds max_waiting_time)
{
if (max_count == 0) return 0;
if (max_waiting_time == chrono::milliseconds(0)) return loop_impl(max_count);
return loop_until_impl(max_count, time_point_from_now(max_waiting_time));
}
template<class clock_type, class duration_type>
size_t loop_until(size_t max_count, chrono::time_point<clock_type, duration_type> timeout_time)
{
return loop_until_impl(max_count, to_system_time_point(timeout_time));
}
void wait_for_task() { wait_for_tasks_impl(1); }
bool wait_for_task_for(chrono::milliseconds max_waiting_time)
{
return wait_for_tasks_impl(1, time_point_from_now(max_waiting_time)) == 1;
}
template<class clock_type, class duration_type>
bool wait_for_task_until(chrono::time_point<clock_type, duration_type> timeout_time)
{
return wait_for_tasks_impl(1, to_system_time_point(timeout_time)) == 1;
}
void wait_for_tasks(size_t count) { wait_for_tasks_impl(count); }
size_t wait_for_tasks_for(size_t count, chrono::milliseconds max_waiting_time)
{
return wait_for_tasks_impl(count, time_point_from_now(max_waiting_time));
}
template<class clock_type, class duration_type>
size_t wait_for_tasks_until(size_t count, chrono::time_point<clock_type, duration_type> timeout_time)
{
return wait_for_tasks_impl(count, to_system_time_point(timeout_time));
}
}; //class manual_executor
namespace details
{
template<class executor_type>
class resume_on_awaitable : public suspend_always
{
private:
await_context m_await_ctx;
executor_type& m_executor;
public:
resume_on_awaitable(executor_type& executor) noexcept : m_executor(executor) {}
resume_on_awaitable(const resume_on_awaitable&) = delete;
resume_on_awaitable(resume_on_awaitable&&) = delete;
resume_on_awaitable& operator=(const resume_on_awaitable&) = delete;
resume_on_awaitable& operator=(resume_on_awaitable&&) = delete;
void await_suspend(coroutine_handle<void> handle)
{
m_await_ctx.set_coro_handle(handle);
try { //协程内部有可能抛出异常
m_executor.template post<await_via_functor>(&m_await_ctx); //换一个执行器去跑恢复代码
} catch (...) { //异常会使队列任务打破并以中断方式恢复。此处无需处理
}
}
void await_resume() const { m_await_ctx.throw_if_interrupted(); }
};
}
template<class executor_type>
auto resume_on(shared_ptr<executor_type> executor)
{
static_assert(is_base_of_v<concurrencpp::executor, executor_type>,
"resume_on() - given executor does not derive from concurrencpp::executor");
if (!static_cast<bool>(executor))
throw invalid_argument("resume_on - given executor is null.");
return details::resume_on_awaitable<executor_type>(*executor); //返回挂起协程的 awaitable,在指定 executor 中恢复
}
template<class executor_type>
auto resume_on(executor_type& executor) noexcept
{
return details::resume_on_awaitable<executor_type>(executor);
}
struct runtime_options
{
size_t max_cpu_threads; //CPU核心数,默认八核
chrono::milliseconds max_thread_pool_executor_waiting_time; //默认120秒
size_t max_background_threads; //CPU核心数的四倍
chrono::milliseconds max_background_executor_waiting_time; //默认120秒
runtime_options() noexcept :
max_cpu_threads(details::default_max_cpu_workers()),
max_thread_pool_executor_waiting_time(chrono::seconds(120)),
max_background_threads(details::default_max_background_workers()),
max_background_executor_waiting_time(chrono::seconds(120))
{
}
runtime_options(const runtime_options&) noexcept = default;
runtime_options& operator=(const runtime_options&) noexcept = default;
}; //struct runtime_options
class runtime
{
private:
shared_ptr<inline_executor> m_inline_executor;
shared_ptr<thread_pool_executor> m_thread_pool_executor;
shared_ptr<thread_pool_executor> m_background_executor;
shared_ptr<thread_executor> m_thread_executor;
details::executor_collection m_registered_executors;
public:
runtime() : runtime(runtime_options()) {}
runtime(const runtime_options& options)
{
m_inline_executor = make_shared<concurrencpp::inline_executor>();
m_registered_executors.register_executor(m_inline_executor);
m_thread_pool_executor = make_shared<concurrencpp::thread_pool_executor>("concurrencpp::thread_pool_executor",
options.max_cpu_threads, options.max_thread_pool_executor_waiting_time);
m_registered_executors.register_executor(m_thread_pool_executor);
m_background_executor = make_shared<concurrencpp::thread_pool_executor>("concurrencpp::background_executor",
options.max_background_threads, options.max_background_executor_waiting_time);
m_registered_executors.register_executor(m_background_executor);
m_thread_executor = make_shared<concurrencpp::thread_executor>();
m_registered_executors.register_executor(m_thread_executor);
}
~runtime() noexcept
{
m_registered_executors.shutdown_all();
}
shared_ptr<inline_executor> inline_executor() const noexcept { return m_inline_executor; }
shared_ptr<thread_pool_executor> background_executor() const noexcept { return m_background_executor; }
shared_ptr<thread_pool_executor> thread_pool_executor() const noexcept { return m_thread_pool_executor; }
shared_ptr<thread_executor> thread_executor() const noexcept { return m_thread_executor; }
shared_ptr<worker_thread_executor> make_worker_thread_executor()
{
auto executor = make_shared<worker_thread_executor>();
m_registered_executors.register_executor(executor);
return executor;
}
shared_ptr<manual_executor> make_manual_executor()
{
auto executor = make_shared<manual_executor>();
m_registered_executors.register_executor(executor);
return executor;
}
//details::consts::k_concurrencpp_version_major, k_concurrencpp_version_minor, k_concurrencpp_version_revision
static tuple<unsigned int, unsigned int, unsigned int> version() noexcept { return { 0, 1, 3 }; }
template<class executor_type, class... argument_types>
shared_ptr<executor_type> make_executor(argument_types&&... arguments)
{
static_assert(
is_base_of_v<executor, executor_type>,
"runtime::make_executor - <<executor_type>> is not a derived class of executor.");
static_assert(is_constructible_v<executor_type, argument_types...>,
"runtime::make_executor - can not build <<executor_type>> from <<argument_types...>>.");
static_assert(!is_abstract_v<executor_type>,
"runtime::make_executor - <<executor_type>> is an abstract class.");
auto executor = make_shared<executor_type>(forward<argument_types>(arguments)...);
m_registered_executors.register_executor(executor);
return executor;
}
}; //class runtime
} //namespace concurrencpp
using namespace concurrencpp;
namespace test01
{
vector<int> make_random_vector()
{
vector<int> vec(64 * 1'024);
srand(uint32_t(time(nullptr)));
for (auto& i : vec) i = ::rand();
return vec;
}
//result<T> 使用 result_coro_promise,它继承自 return_value_struct
//其 return_value 函数,通过成员 result_state 调用 set_result 得到结果
//随后通过 final_suspend 使用 result_publisher,从其 await_suspend 调用 complete_producer
//取值时,co_await 操作符基于 result_state 构造 awaitable
//从其 await_resume 调 state->producer->get 拿到结果。若 co_return 则直接从 get 拿结果
//之后 state 自动释放,期间调用 complete_consumer 对 state 进行自释放
result<size_t> count_even(shared_ptr<thread_pool_executor> tpe, const vector<int>& vector)
{
const auto vecor_size = vector.size();
const auto concurrency_level = tpe->max_concurrency_level();
const auto chunk_size = vecor_size / concurrency_level;
std::vector<result<size_t>> chunk_count;
for (auto i = 0; i < concurrency_level; i++) {
const auto chunk_begin = i * chunk_size;
const auto chunk_end = chunk_begin + chunk_size;
auto result = tpe->submit([&vector, chunk_begin, chunk_end]() -> size_t {
return count_if(vector.begin() + chunk_begin, vector.begin() + chunk_end, [](auto i) {
return i % 2 == 0;
});
});
chunk_count.emplace_back(move(result));
}
size_t total_count = 0;
for (auto& result : chunk_count)
total_count += co_await result;
co_return total_count;
}
void test()
{
runtime runtime;
const auto vector = make_random_vector();
auto result = count_even(runtime.thread_pool_executor(), vector);
const auto total_count = result.get();
cout << "there are " << total_count << " even numbers in the vector" << endl;
}
}
namespace test02
{
void test()
{
result_promise<string> promise; //外部使用一般用它
auto result = promise.get_result(); //类型为 result<string>
std::thread my_3_party_executor([promise = move(promise)]() mutable {
this_thread::sleep_for(chrono::seconds(1));
promise.set_result("hello world");
});
//消费者请求时,生成者还未完成。状态 consumer_set
//result_state_base 里 wait 时,消费者会用 wait_context 等待
//出结果后,complete_producer 会唤醒恢复消费者,取到结果
auto asynchronous_string = result.get();
cout << "result promise returned string: " << asynchronous_string << endl;
my_3_party_executor.join();
}
}
namespace test03
{
result<void> consume_shared_result(shared_result<int> shared_result, shared_ptr<executor> resume_executor)
{
cout << "Awaiting shared_result to have a value" << endl;
const auto& async_value = co_await shared_result;
resume_on(resume_executor); //原本在背景执行器,恢复到线程池执行器
cout << "In thread id " << this_thread::get_id() << ", got: " << async_value
<< ", memory address: " << &async_value << endl;
}
void test()
{
runtime runtime;
//Win 下都在同一个线程中执行,线程间不会重复等待
auto result = runtime.background_executor()->submit([] {
this_thread::sleep_for(chrono::seconds(1));
return 100;
});
shared_result<int> shared_result(move(result));
//shared complete_producer 只执行一次,但 result complete_consumer 会执行八次
concurrencpp::result<void> results[8];
for (size_t i = 0; i < 8; i++)
results[i] = consume_shared_result(shared_result, runtime.thread_pool_executor());
cout << "Main thread waiting for all consumers to finish" << endl;
auto all_consumed = when_all(begin(results), end(results));
all_consumed.get();
cout << "All consumers are done, exiting" << endl;
}
}
namespace test04
{
class logging_executor : public derivable_executor<logging_executor>
{
private:
mutable mutex _lock;
queue<task> _queue;
condition_variable _condition;
bool _shutdown_requested;
std::thread _thread;
const string _prefix;
void work_loop()
{
while (true) {
unique_lock<mutex> lock(_lock);
if (_shutdown_requested) return;
if (!_queue.empty()) {
auto&& task = move(_queue.front());
_queue.pop();
lock.unlock();
cout << _prefix << " A task is being executed" << endl;
task();
continue;
}
_condition.wait(lock, [this] { return !_queue.empty() || _shutdown_requested; });
}
}
public:
logging_executor(string_view prefix) : derivable_executor<logging_executor>("logging_executor"),
_shutdown_requested(false), _prefix(prefix)
{
_thread = std::thread([this] { work_loop(); });
}
void enqueue(task task) override
{
cout << _prefix << " A task is being enqueued!" << endl;
unique_lock<mutex> lock(_lock);
if (_shutdown_requested)
throw errors::runtime_shutdown("logging executor - executor was shutdown.");
_queue.emplace(move(task));
_condition.notify_one();
}
void enqueue(span<task> tasks) override
{
cout << _prefix << tasks.size() << " tasks are being enqueued!" << endl;
unique_lock<mutex> lock(_lock);
if (_shutdown_requested)
throw errors::runtime_shutdown("logging executor - executor was shutdown.");
for (auto& task : tasks)
_queue.emplace(move(task));
_condition.notify_one();
}
int max_concurrency_level() const noexcept override { return 1; }
bool shutdown_requested() const noexcept override
{
unique_lock<mutex> lock(_lock);
return _shutdown_requested;
}
void shutdown() noexcept override
{
cout << _prefix << " shutdown requested" << endl;
unique_lock<mutex> lock(_lock);
if (_shutdown_requested) return; //nothing to do.
_shutdown_requested = true;
lock.unlock();
_condition.notify_one();
_thread.join();
}
};
void test()
{
runtime runtime;
auto logging_ex = runtime.make_executor<logging_executor>("Session #1234");
for (size_t i = 0; i < 10; i++)
logging_ex->post([] { cout << "hello world" << endl; });
}
}
int main()
{
test01::test();
test02::test();
test03::test();
test04::test();
return 0;
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。