mirror of
https://github.com/drogonframework/drogon.git
synced 2025-07-04 00:00:46 -04:00
Compare commits
2 Commits
b52ab5f4f1
...
debdc27c4d
Author | SHA1 | Date | |
---|---|---|---|
|
debdc27c4d | ||
|
3221c43857 |
@ -55,6 +55,10 @@ auto getAwaiter(T &&value) noexcept(
|
||||
return getAwaiterImpl(static_cast<T &&>(value));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using void_to_false_t =
|
||||
std::conditional_t<std::is_same_v<T, void>, std::false_type, T>;
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
template <typename T>
|
||||
@ -420,6 +424,11 @@ struct CallbackAwaiter : public trantor::NonCopyable
|
||||
return false;
|
||||
}
|
||||
|
||||
bool hasException() const noexcept
|
||||
{
|
||||
return exception_ != nullptr;
|
||||
}
|
||||
|
||||
const T &await_resume() const noexcept(false)
|
||||
{
|
||||
// await_resume() should always be called after co_await
|
||||
@ -470,6 +479,11 @@ struct CallbackAwaiter<void> : public trantor::NonCopyable
|
||||
std::rethrow_exception(exception_);
|
||||
}
|
||||
|
||||
bool hasException() const noexcept
|
||||
{
|
||||
return exception_ != nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
std::exception_ptr exception_{nullptr};
|
||||
|
||||
@ -798,6 +812,180 @@ struct [[nodiscard]] EventLoopAwaiter : public drogon::CallbackAwaiter<T>
|
||||
std::function<T()> task_;
|
||||
trantor::EventLoop *loop_;
|
||||
};
|
||||
|
||||
template <typename... Tasks>
|
||||
struct WhenAllAwaiter
|
||||
: public CallbackAwaiter<
|
||||
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...>>
|
||||
{
|
||||
WhenAllAwaiter(Tasks... tasks)
|
||||
: tasks_(std::forward<Tasks>(tasks)...), counter_(sizeof...(tasks))
|
||||
{
|
||||
}
|
||||
|
||||
void await_suspend(std::coroutine_handle<> handle)
|
||||
{
|
||||
if (counter_ == 0)
|
||||
{
|
||||
handle.resume();
|
||||
return;
|
||||
}
|
||||
|
||||
await_suspend_impl(handle, std::index_sequence_for<Tasks...>{});
|
||||
}
|
||||
|
||||
private:
|
||||
std::tuple<Tasks...> tasks_;
|
||||
std::atomic<size_t> counter_;
|
||||
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...> results_;
|
||||
std::atomic_flag exceptionFlag_;
|
||||
|
||||
template <size_t Idx>
|
||||
void launch_task(std::coroutine_handle<> handle)
|
||||
{
|
||||
using Self = WhenAllAwaiter<Tasks...>;
|
||||
[](Self *self, std::coroutine_handle<> handle) -> AsyncTask {
|
||||
try
|
||||
{
|
||||
using TaskType = std::tuple_element_t<
|
||||
Idx,
|
||||
std::remove_cvref_t<decltype(results_)>>;
|
||||
if constexpr (std::is_same_v<TaskType, std::false_type>)
|
||||
{
|
||||
co_await std::get<Idx>(self->tasks_);
|
||||
std::get<Idx>(self->results_) = std::false_type{};
|
||||
}
|
||||
else
|
||||
{
|
||||
std::get<Idx>(self->results_) =
|
||||
co_await std::get<Idx>(self->tasks_);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (self->exceptionFlag_.test_and_set() == false)
|
||||
self->setException(std::current_exception());
|
||||
}
|
||||
|
||||
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
|
||||
{
|
||||
if (!self->hasException())
|
||||
self->setValue(std::move(self->results_));
|
||||
handle.resume();
|
||||
}
|
||||
}(this, handle);
|
||||
}
|
||||
|
||||
template <size_t... Is>
|
||||
void await_suspend_impl(std::coroutine_handle<> handle,
|
||||
std::index_sequence<Is...>)
|
||||
{
|
||||
((launch_task<Is>(handle)), ...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct WhenAllAwaiter<std::vector<Task<T>>>
|
||||
: public CallbackAwaiter<std::vector<T>>
|
||||
{
|
||||
WhenAllAwaiter(std::vector<Task<T>> tasks)
|
||||
: tasks_(std::move(tasks)),
|
||||
counter_(tasks_.size()),
|
||||
results_(tasks_.size())
|
||||
{
|
||||
}
|
||||
|
||||
void await_suspend(std::coroutine_handle<> handle)
|
||||
{
|
||||
if (tasks_.empty())
|
||||
{
|
||||
this->setValue(std::vector<T>{});
|
||||
handle.resume();
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t count = tasks_.size();
|
||||
for (size_t i = 0; i < count; ++i)
|
||||
{
|
||||
[](WhenAllAwaiter *self,
|
||||
std::coroutine_handle<> handle,
|
||||
Task<T> task,
|
||||
size_t index) -> AsyncTask {
|
||||
try
|
||||
{
|
||||
auto result = co_await task;
|
||||
self->results_[index] = std::move(result);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (self->exceptionFlag_.test_and_set() == false)
|
||||
self->setException(std::current_exception());
|
||||
}
|
||||
|
||||
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
|
||||
{
|
||||
if (!self->hasException())
|
||||
{
|
||||
self->setValue(std::move(self->results_));
|
||||
}
|
||||
handle.resume();
|
||||
}
|
||||
}(this, handle, std::move(tasks_[i]), i);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Task<T>> tasks_;
|
||||
std::atomic<size_t> counter_;
|
||||
std::vector<T> results_;
|
||||
std::atomic_flag exceptionFlag_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WhenAllAwaiter<std::vector<Task<void>>> : public CallbackAwaiter<void>
|
||||
{
|
||||
WhenAllAwaiter(std::vector<Task<void>> &&t)
|
||||
: tasks_(std::move(t)), counter_(tasks_.size())
|
||||
{
|
||||
}
|
||||
|
||||
void await_suspend(std::coroutine_handle<> handle)
|
||||
{
|
||||
if (tasks_.empty())
|
||||
{
|
||||
handle.resume();
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t count =
|
||||
tasks_
|
||||
.size(); // capture the size fist (see lifetime comment beflow)
|
||||
for (size_t i = 0; i < count; ++i)
|
||||
{
|
||||
[](WhenAllAwaiter *self,
|
||||
std::coroutine_handle<> handle,
|
||||
Task<> task) -> AsyncTask {
|
||||
try
|
||||
{
|
||||
co_await task;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (self->exceptionFlag_.test_and_set() == false)
|
||||
self->setException(std::current_exception());
|
||||
}
|
||||
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
|
||||
// This line CAN delete `this` at last iteration. We MUST
|
||||
// NOT depend on this after last iteration
|
||||
handle.resume();
|
||||
}(this, handle, std::move(tasks_[i]));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Task<void>> tasks_;
|
||||
std::atomic<size_t> counter_;
|
||||
std::atomic_flag exceptionFlag_;
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
/**
|
||||
@ -987,4 +1175,23 @@ class Mutex final
|
||||
CoroMutexAwaiter *waiters_;
|
||||
};
|
||||
|
||||
template <typename... Tasks>
|
||||
internal::WhenAllAwaiter<Tasks...> when_all(Tasks... tasks)
|
||||
{
|
||||
return internal::WhenAllAwaiter<Tasks...>(std::move(tasks)...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
internal::WhenAllAwaiter<std::vector<Task<T>>> when_all(
|
||||
std::vector<Task<T>> tasks)
|
||||
{
|
||||
return internal::WhenAllAwaiter(std::move(tasks));
|
||||
}
|
||||
|
||||
inline internal::WhenAllAwaiter<std::vector<Task<void>>> when_all(
|
||||
std::vector<Task<void>> tasks)
|
||||
{
|
||||
return internal::WhenAllAwaiter(std::move(tasks));
|
||||
}
|
||||
|
||||
} // namespace drogon
|
||||
|
@ -3,9 +3,14 @@
|
||||
#include <drogon/HttpAppFramework.h>
|
||||
#include <trantor/net/EventLoopThread.h>
|
||||
#include <trantor/net/EventLoopThreadPool.h>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <exception>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
using namespace drogon;
|
||||
@ -245,3 +250,81 @@ DROGON_TEST(Mutex)
|
||||
pool.getLoop(i)->quit();
|
||||
pool.wait();
|
||||
}
|
||||
|
||||
DROGON_TEST(WhenAll)
|
||||
{
|
||||
using TestCtx = std::shared_ptr<drogon::test::Case>;
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
size_t counter = 0;
|
||||
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
|
||||
co_await drogon::sleepCoro(app().getLoop(), 0.2);
|
||||
(*counter)++;
|
||||
}(TEST_CTX, &counter);
|
||||
auto t2 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
|
||||
co_await drogon::sleepCoro(app().getLoop(), 0.1);
|
||||
(*counter)++;
|
||||
}(TEST_CTX, &counter);
|
||||
std::vector<Task<void>> tasks;
|
||||
tasks.emplace_back(std::move(t1));
|
||||
tasks.emplace_back(std::move(t2));
|
||||
|
||||
co_await when_all(std::move(tasks));
|
||||
CHECK(counter == 2);
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
std::vector<Task<void>> tasks;
|
||||
co_await when_all(std::move(tasks));
|
||||
SUCCESS();
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
|
||||
auto t2 = [](TestCtx TEST_CTX) -> Task<int> { co_return 2; }(TEST_CTX);
|
||||
std::vector<Task<int>> tasks;
|
||||
tasks.emplace_back(std::move(t1));
|
||||
tasks.emplace_back(std::move(t2));
|
||||
|
||||
auto res = co_await when_all(std::move(tasks));
|
||||
CO_REQUIRE(res.size() == 2);
|
||||
CHECK(res[0] == 1);
|
||||
CHECK(res[1] == 2);
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
|
||||
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
|
||||
co_return "Hello";
|
||||
}(TEST_CTX);
|
||||
auto [num, str] = co_await when_all(std::move(t1), std::move(t2));
|
||||
CHECK(num == 1);
|
||||
CHECK(str == "Hello");
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
size_t counter = 0;
|
||||
// Even on corutine throws, other coroutins run to completion
|
||||
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<int> {
|
||||
co_await drogon::sleepCoro(app().getLoop(), 0.2);
|
||||
(*counter)++;
|
||||
co_return 1;
|
||||
}(TEST_CTX, &counter);
|
||||
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
|
||||
co_await drogon::sleepCoro(app().getLoop(), 0.1);
|
||||
throw std::runtime_error("Test exception");
|
||||
}(TEST_CTX);
|
||||
CO_REQUIRE_THROWS(co_await when_all(std::move(t1), std::move(t2)));
|
||||
CHECK(counter == 1);
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
size_t counter = 0;
|
||||
// void retuens gets mapped to std::false_type in the tuple API
|
||||
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
|
||||
(*counter)++;
|
||||
co_return;
|
||||
}(TEST_CTX, &counter);
|
||||
auto [res] = co_await when_all(std::move(t1));
|
||||
CHECK(counter == 1);
|
||||
}(TEST_CTX);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user