Compare commits

..

2 Commits

Author SHA1 Message Date
fantasy-peak
debdc27c4d
Merge 2beaf0a901b7bfc95c7da48bb2e9ca750cbd1a34 into 3221c4385714859ce20348306fb4c0f170e778dc 2025-07-02 18:04:15 +05:30
Martin Chang
3221c43857
implement when_all coroutine gate (#2342) 2025-07-02 10:32:45 +08:00
2 changed files with 290 additions and 0 deletions

View File

@ -55,6 +55,10 @@ auto getAwaiter(T &&value) noexcept(
return getAwaiterImpl(static_cast<T &&>(value));
}
template <typename T>
using void_to_false_t =
std::conditional_t<std::is_same_v<T, void>, std::false_type, T>;
} // end namespace internal
template <typename T>
@ -420,6 +424,11 @@ struct CallbackAwaiter : public trantor::NonCopyable
return false;
}
bool hasException() const noexcept
{
return exception_ != nullptr;
}
const T &await_resume() const noexcept(false)
{
// await_resume() should always be called after co_await
@ -470,6 +479,11 @@ struct CallbackAwaiter<void> : public trantor::NonCopyable
std::rethrow_exception(exception_);
}
bool hasException() const noexcept
{
return exception_ != nullptr;
}
private:
std::exception_ptr exception_{nullptr};
@ -798,6 +812,180 @@ struct [[nodiscard]] EventLoopAwaiter : public drogon::CallbackAwaiter<T>
std::function<T()> task_;
trantor::EventLoop *loop_;
};
template <typename... Tasks>
struct WhenAllAwaiter
: public CallbackAwaiter<
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...>>
{
WhenAllAwaiter(Tasks... tasks)
: tasks_(std::forward<Tasks>(tasks)...), counter_(sizeof...(tasks))
{
}
void await_suspend(std::coroutine_handle<> handle)
{
if (counter_ == 0)
{
handle.resume();
return;
}
await_suspend_impl(handle, std::index_sequence_for<Tasks...>{});
}
private:
std::tuple<Tasks...> tasks_;
std::atomic<size_t> counter_;
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...> results_;
std::atomic_flag exceptionFlag_;
template <size_t Idx>
void launch_task(std::coroutine_handle<> handle)
{
using Self = WhenAllAwaiter<Tasks...>;
[](Self *self, std::coroutine_handle<> handle) -> AsyncTask {
try
{
using TaskType = std::tuple_element_t<
Idx,
std::remove_cvref_t<decltype(results_)>>;
if constexpr (std::is_same_v<TaskType, std::false_type>)
{
co_await std::get<Idx>(self->tasks_);
std::get<Idx>(self->results_) = std::false_type{};
}
else
{
std::get<Idx>(self->results_) =
co_await std::get<Idx>(self->tasks_);
}
}
catch (...)
{
if (self->exceptionFlag_.test_and_set() == false)
self->setException(std::current_exception());
}
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
{
if (!self->hasException())
self->setValue(std::move(self->results_));
handle.resume();
}
}(this, handle);
}
template <size_t... Is>
void await_suspend_impl(std::coroutine_handle<> handle,
std::index_sequence<Is...>)
{
((launch_task<Is>(handle)), ...);
}
};
template <typename T>
struct WhenAllAwaiter<std::vector<Task<T>>>
: public CallbackAwaiter<std::vector<T>>
{
WhenAllAwaiter(std::vector<Task<T>> tasks)
: tasks_(std::move(tasks)),
counter_(tasks_.size()),
results_(tasks_.size())
{
}
void await_suspend(std::coroutine_handle<> handle)
{
if (tasks_.empty())
{
this->setValue(std::vector<T>{});
handle.resume();
return;
}
const size_t count = tasks_.size();
for (size_t i = 0; i < count; ++i)
{
[](WhenAllAwaiter *self,
std::coroutine_handle<> handle,
Task<T> task,
size_t index) -> AsyncTask {
try
{
auto result = co_await task;
self->results_[index] = std::move(result);
}
catch (...)
{
if (self->exceptionFlag_.test_and_set() == false)
self->setException(std::current_exception());
}
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
{
if (!self->hasException())
{
self->setValue(std::move(self->results_));
}
handle.resume();
}
}(this, handle, std::move(tasks_[i]), i);
}
}
private:
std::vector<Task<T>> tasks_;
std::atomic<size_t> counter_;
std::vector<T> results_;
std::atomic_flag exceptionFlag_;
};
template <>
struct WhenAllAwaiter<std::vector<Task<void>>> : public CallbackAwaiter<void>
{
WhenAllAwaiter(std::vector<Task<void>> &&t)
: tasks_(std::move(t)), counter_(tasks_.size())
{
}
void await_suspend(std::coroutine_handle<> handle)
{
if (tasks_.empty())
{
handle.resume();
return;
}
const size_t count =
tasks_
.size(); // capture the size fist (see lifetime comment beflow)
for (size_t i = 0; i < count; ++i)
{
[](WhenAllAwaiter *self,
std::coroutine_handle<> handle,
Task<> task) -> AsyncTask {
try
{
co_await task;
}
catch (...)
{
if (self->exceptionFlag_.test_and_set() == false)
self->setException(std::current_exception());
}
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
// This line CAN delete `this` at last iteration. We MUST
// NOT depend on this after last iteration
handle.resume();
}(this, handle, std::move(tasks_[i]));
}
}
std::vector<Task<void>> tasks_;
std::atomic<size_t> counter_;
std::atomic_flag exceptionFlag_;
};
} // namespace internal
/**
@ -987,4 +1175,23 @@ class Mutex final
CoroMutexAwaiter *waiters_;
};
template <typename... Tasks>
internal::WhenAllAwaiter<Tasks...> when_all(Tasks... tasks)
{
return internal::WhenAllAwaiter<Tasks...>(std::move(tasks)...);
}
template <typename T>
internal::WhenAllAwaiter<std::vector<Task<T>>> when_all(
std::vector<Task<T>> tasks)
{
return internal::WhenAllAwaiter(std::move(tasks));
}
inline internal::WhenAllAwaiter<std::vector<Task<void>>> when_all(
std::vector<Task<void>> tasks)
{
return internal::WhenAllAwaiter(std::move(tasks));
}
} // namespace drogon

View File

@ -3,9 +3,14 @@
#include <drogon/HttpAppFramework.h>
#include <trantor/net/EventLoopThread.h>
#include <trantor/net/EventLoopThreadPool.h>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <exception>
#include <future>
#include <memory>
#include <mutex>
#include <optional>
#include <type_traits>
using namespace drogon;
@ -245,3 +250,81 @@ DROGON_TEST(Mutex)
pool.getLoop(i)->quit();
pool.wait();
}
DROGON_TEST(WhenAll)
{
using TestCtx = std::shared_ptr<drogon::test::Case>;
[](TestCtx TEST_CTX) -> AsyncTask {
size_t counter = 0;
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
co_await drogon::sleepCoro(app().getLoop(), 0.2);
(*counter)++;
}(TEST_CTX, &counter);
auto t2 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
co_await drogon::sleepCoro(app().getLoop(), 0.1);
(*counter)++;
}(TEST_CTX, &counter);
std::vector<Task<void>> tasks;
tasks.emplace_back(std::move(t1));
tasks.emplace_back(std::move(t2));
co_await when_all(std::move(tasks));
CHECK(counter == 2);
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
std::vector<Task<void>> tasks;
co_await when_all(std::move(tasks));
SUCCESS();
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
auto t2 = [](TestCtx TEST_CTX) -> Task<int> { co_return 2; }(TEST_CTX);
std::vector<Task<int>> tasks;
tasks.emplace_back(std::move(t1));
tasks.emplace_back(std::move(t2));
auto res = co_await when_all(std::move(tasks));
CO_REQUIRE(res.size() == 2);
CHECK(res[0] == 1);
CHECK(res[1] == 2);
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
co_return "Hello";
}(TEST_CTX);
auto [num, str] = co_await when_all(std::move(t1), std::move(t2));
CHECK(num == 1);
CHECK(str == "Hello");
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
size_t counter = 0;
// Even on corutine throws, other coroutins run to completion
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<int> {
co_await drogon::sleepCoro(app().getLoop(), 0.2);
(*counter)++;
co_return 1;
}(TEST_CTX, &counter);
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
co_await drogon::sleepCoro(app().getLoop(), 0.1);
throw std::runtime_error("Test exception");
}(TEST_CTX);
CO_REQUIRE_THROWS(co_await when_all(std::move(t1), std::move(t2)));
CHECK(counter == 1);
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
size_t counter = 0;
// void retuens gets mapped to std::false_type in the tuple API
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
(*counter)++;
co_return;
}(TEST_CTX, &counter);
auto [res] = co_await when_all(std::move(t1));
CHECK(counter == 1);
}(TEST_CTX);
}