mirror of
https://github.com/drogonframework/drogon.git
synced 2025-07-04 00:00:46 -04:00
Compare commits
3 Commits
b52ab5f4f1
...
debdc27c4d
Author | SHA1 | Date | |
---|---|---|---|
|
debdc27c4d | ||
|
3221c43857 | ||
|
2beaf0a901 |
@ -718,6 +718,7 @@ install(FILES ${NOSQL_HEADERS} DESTINATION ${INSTALL_INCLUDE_DIR}/drogon/nosql)
|
||||
|
||||
set(DROGON_UTIL_HEADERS
|
||||
lib/inc/drogon/utils/coroutine.h
|
||||
lib/inc/drogon/utils/Http11ClientPool.h
|
||||
lib/inc/drogon/utils/FunctionTraits.h
|
||||
lib/inc/drogon/utils/HttpConstraint.h
|
||||
lib/inc/drogon/utils/OStringStream.h
|
||||
|
@ -1,43 +1,112 @@
|
||||
#include <drogon/drogon.h>
|
||||
|
||||
#include <future>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <drogon/HttpTypes.h>
|
||||
#include <trantor/utils/Logger.h>
|
||||
|
||||
#ifdef __linux__
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/tcp.h>
|
||||
#endif
|
||||
|
||||
#include <drogon/utils/Http11ClientPool.h>
|
||||
using namespace drogon;
|
||||
|
||||
int nth_resp = 0;
|
||||
|
||||
int main()
|
||||
{
|
||||
auto func = [](int fd) {
|
||||
std::cout << "setSockOptCallback:" << fd << std::endl;
|
||||
#ifdef __linux__
|
||||
int optval = 10;
|
||||
::setsockopt(fd,
|
||||
SOL_TCP,
|
||||
TCP_KEEPCNT,
|
||||
&optval,
|
||||
static_cast<socklen_t>(sizeof optval));
|
||||
::setsockopt(fd,
|
||||
SOL_TCP,
|
||||
TCP_KEEPIDLE,
|
||||
&optval,
|
||||
static_cast<socklen_t>(sizeof optval));
|
||||
::setsockopt(fd,
|
||||
SOL_TCP,
|
||||
TCP_KEEPINTVL,
|
||||
&optval,
|
||||
static_cast<socklen_t>(sizeof optval));
|
||||
#endif
|
||||
};
|
||||
trantor::Logger::setLogLevel(trantor::Logger::kTrace);
|
||||
#ifdef __cpp_impl_coroutine
|
||||
Http11ClientPoolConfig cfg{
|
||||
.hostString = "http://www.baidu.com",
|
||||
.useOldTLS = false,
|
||||
.validateCert = false,
|
||||
.size = 10,
|
||||
.setCallback =
|
||||
[func](auto &client) {
|
||||
LOG_INFO << "setCallback";
|
||||
client->setSockOptCallback(func);
|
||||
},
|
||||
.numOfThreads = 4,
|
||||
.keepaliveRequests = 1000,
|
||||
.idleTimeout = std::chrono::seconds(10),
|
||||
.maxLifeTime = std::chrono::seconds(300),
|
||||
.checkInterval = std::chrono::seconds(10),
|
||||
};
|
||||
auto pool = std::make_unique<Http11ClientPool>(cfg);
|
||||
auto req = HttpRequest::newHttpRequest();
|
||||
req->setMethod(drogon::Get);
|
||||
req->setPath("/s");
|
||||
req->setParameter("wd", "wx");
|
||||
req->setParameter("oq", "wx");
|
||||
|
||||
for (int i = 0; i < 1; i++)
|
||||
{
|
||||
[](auto req, auto &pool) -> drogon::AsyncTask {
|
||||
{
|
||||
auto [result, resp] = co_await pool->sendRequestCoro(req, 10);
|
||||
if (result == ReqResult::Ok)
|
||||
LOG_INFO << "1:" << resp->getStatusCode();
|
||||
}
|
||||
{
|
||||
auto [result, resp] = co_await pool->sendRequestCoro(req, 10);
|
||||
if (result == ReqResult::Ok)
|
||||
LOG_INFO << "2:" << resp->getStatusCode();
|
||||
}
|
||||
{
|
||||
auto [result, resp] = co_await pool->sendRequestCoro(req, 10);
|
||||
if (result == ReqResult::Ok)
|
||||
LOG_INFO << "3:" << resp->getStatusCode();
|
||||
}
|
||||
co_return;
|
||||
}(req, pool);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 10; i++)
|
||||
{
|
||||
pool->sendRequest(
|
||||
req,
|
||||
[](ReqResult result, const HttpResponsePtr &response) {
|
||||
if (result != ReqResult::Ok)
|
||||
{
|
||||
LOG_ERROR
|
||||
<< "error while sending request to server! result: "
|
||||
<< result;
|
||||
return;
|
||||
}
|
||||
LOG_INFO << "callback:" << response->getStatusCode();
|
||||
},
|
||||
10);
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(30));
|
||||
#else
|
||||
{
|
||||
auto client = HttpClient::newHttpClient("http://www.baidu.com");
|
||||
client->setSockOptCallback([](int fd) {
|
||||
std::cout << "setSockOptCallback:" << fd << std::endl;
|
||||
#ifdef __linux__
|
||||
int optval = 10;
|
||||
::setsockopt(fd,
|
||||
SOL_TCP,
|
||||
TCP_KEEPCNT,
|
||||
&optval,
|
||||
static_cast<socklen_t>(sizeof optval));
|
||||
::setsockopt(fd,
|
||||
SOL_TCP,
|
||||
TCP_KEEPIDLE,
|
||||
&optval,
|
||||
static_cast<socklen_t>(sizeof optval));
|
||||
::setsockopt(fd,
|
||||
SOL_TCP,
|
||||
TCP_KEEPINTVL,
|
||||
&optval,
|
||||
static_cast<socklen_t>(sizeof optval));
|
||||
#endif
|
||||
});
|
||||
client->setSockOptCallback(func);
|
||||
|
||||
auto req = HttpRequest::newHttpRequest();
|
||||
req->setMethod(drogon::Get);
|
||||
@ -77,4 +146,5 @@ int main()
|
||||
}
|
||||
|
||||
app().run();
|
||||
#endif
|
||||
}
|
||||
|
312
lib/inc/drogon/utils/Http11ClientPool.h
Normal file
312
lib/inc/drogon/utils/Http11ClientPool.h
Normal file
@ -0,0 +1,312 @@
|
||||
/**
|
||||
*
|
||||
* @file Http11ClientPool.h
|
||||
* @author fantasy-peak
|
||||
*
|
||||
* Copyright 2024, fantasy-peak. All rights reserved.
|
||||
* https://github.com/an-tao/drogon
|
||||
* Use of this source code is governed by a MIT license
|
||||
* that can be found in the License file.
|
||||
*
|
||||
* Drogon
|
||||
*
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include <trantor/utils/Logger.h>
|
||||
#include <trantor/net/EventLoopThreadPool.h>
|
||||
#ifdef __cpp_impl_coroutine
|
||||
#include <drogon/utils/coroutine.h>
|
||||
#endif
|
||||
#include <drogon/drogon.h>
|
||||
|
||||
namespace drogon
|
||||
{
|
||||
|
||||
struct Http11ClientPoolConfig
|
||||
{
|
||||
std::string hostString;
|
||||
bool useOldTLS{false};
|
||||
bool validateCert{false};
|
||||
std::size_t size{100};
|
||||
std::function<void(HttpClientPtr &)> setCallback;
|
||||
std::size_t numOfThreads{std::thread::hardware_concurrency()};
|
||||
std::optional<std::size_t> keepaliveRequests;
|
||||
std::optional<std::chrono::seconds> idleTimeout;
|
||||
std::optional<std::chrono::seconds> maxLifeTime;
|
||||
std::optional<std::chrono::seconds> checkInterval;
|
||||
};
|
||||
|
||||
class Http11ClientPool final
|
||||
{
|
||||
public:
|
||||
Http11ClientPool(
|
||||
Http11ClientPoolConfig cfg,
|
||||
std::shared_ptr<trantor::EventLoopThreadPool> loopPool = nullptr)
|
||||
: cfg_(std::move(cfg)), loopPool_(std::move(loopPool))
|
||||
{
|
||||
if (loopPool_ == nullptr)
|
||||
{
|
||||
loopPool_ = std::make_shared<trantor::EventLoopThreadPool>(
|
||||
cfg_.numOfThreads);
|
||||
loopPool_->start();
|
||||
isSelfThreadPool_ = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
isSelfThreadPool_ = false;
|
||||
}
|
||||
loopPtr_ = loopPool_->getNextLoop();
|
||||
|
||||
for (std::size_t i = 0; i < cfg_.size; i++)
|
||||
{
|
||||
auto loopPtr = loopPool_->getNextLoop();
|
||||
auto func = [this, loopPtr]() mutable {
|
||||
auto client = HttpClient::newHttpClient(cfg_.hostString,
|
||||
loopPtr,
|
||||
cfg_.useOldTLS,
|
||||
cfg_.validateCert);
|
||||
if (cfg_.setCallback)
|
||||
cfg_.setCallback(client);
|
||||
return client;
|
||||
};
|
||||
httpClients_.emplace(std::make_shared<Connection>(func, cfg_));
|
||||
}
|
||||
LOG_DEBUG << "httpClients_ size:" << httpClients_.size();
|
||||
|
||||
if (cfg_.idleTimeout.has_value() && cfg_.checkInterval.has_value())
|
||||
{
|
||||
timerId_ = loopPtr_->runEvery(cfg_.checkInterval.value(), [this] {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (httpClients_.empty())
|
||||
return;
|
||||
std::queue<std::shared_ptr<Connection>> clients;
|
||||
while (!httpClients_.empty())
|
||||
{
|
||||
auto connPtr = std::move(httpClients_.front());
|
||||
httpClients_.pop();
|
||||
if (connPtr->reachIdleTimeout())
|
||||
{
|
||||
// close tcp connection
|
||||
connPtr->resetClientPtr();
|
||||
}
|
||||
clients.emplace(std::move(connPtr));
|
||||
}
|
||||
httpClients_ = std::move(clients);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
~Http11ClientPool()
|
||||
{
|
||||
if (timerId_)
|
||||
{
|
||||
std::promise<void> done;
|
||||
loopPtr_->runInLoop([&] {
|
||||
loopPtr_->invalidateTimer(timerId_.value());
|
||||
done.set_value();
|
||||
});
|
||||
done.get_future().wait();
|
||||
}
|
||||
if (isSelfThreadPool_)
|
||||
{
|
||||
for (auto &ptr : loopPool_->getLoops())
|
||||
ptr->runInLoop([=] { ptr->quit(); });
|
||||
loopPool_->wait();
|
||||
loopPool_.reset();
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
std::queue<std::shared_ptr<Connection>> tmp;
|
||||
httpClients_.swap(tmp);
|
||||
}
|
||||
|
||||
Http11ClientPool(const Http11ClientPool &) = delete;
|
||||
Http11ClientPool &operator=(const Http11ClientPool &) = delete;
|
||||
Http11ClientPool(Http11ClientPool &&) = delete;
|
||||
Http11ClientPool &operator=(Http11ClientPool &&) = delete;
|
||||
|
||||
void sendRequest(const HttpRequestPtr &req,
|
||||
std::function<void(ReqResult, const HttpResponsePtr &)> cb,
|
||||
double timeout = 0)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (httpClients_.empty())
|
||||
{
|
||||
httpRequest_.emplace(req, std::move(cb));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto connPtr = std::move(httpClients_.front());
|
||||
httpClients_.pop();
|
||||
lock.unlock();
|
||||
send(std::move(connPtr), req, std::move(cb), timeout);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef __cpp_impl_coroutine
|
||||
|
||||
auto sendRequestCoro(HttpRequestPtr req, double timeout = 0)
|
||||
{
|
||||
struct Awaiter
|
||||
: public CallbackAwaiter<std::tuple<ReqResult, HttpResponsePtr>>
|
||||
{
|
||||
Awaiter(Http11ClientPool *pool, HttpRequestPtr req, double timeout)
|
||||
: pool_(pool), req_(std::move(req)), timeout_(timeout)
|
||||
{
|
||||
}
|
||||
|
||||
void await_suspend(std::coroutine_handle<> handle)
|
||||
{
|
||||
pool_->sendRequest(
|
||||
req_,
|
||||
[this, handle](ReqResult result,
|
||||
const HttpResponsePtr &ptr) {
|
||||
setValue(std::make_tuple(result, ptr));
|
||||
handle.resume();
|
||||
},
|
||||
timeout_);
|
||||
}
|
||||
|
||||
private:
|
||||
Http11ClientPool *pool_;
|
||||
HttpRequestPtr req_;
|
||||
double timeout_;
|
||||
};
|
||||
|
||||
return Awaiter{this, std::move(req), timeout};
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
private:
|
||||
struct Connection
|
||||
{
|
||||
Connection(std::function<HttpClientPtr()> cb,
|
||||
const Http11ClientPoolConfig &cfg)
|
||||
: createHttpClientFunc_(std::move(cb)), cfg_(cfg)
|
||||
{
|
||||
init();
|
||||
}
|
||||
|
||||
void init()
|
||||
{
|
||||
clientPtr_ = createHttpClientFunc_();
|
||||
auto now = std::chrono::system_clock::now();
|
||||
timePoint_ = now;
|
||||
startTimePoint_ = now;
|
||||
counter_ = 0;
|
||||
}
|
||||
|
||||
void send(const HttpRequestPtr &req,
|
||||
std::function<void(ReqResult, const HttpResponsePtr &)> cb,
|
||||
double timeout)
|
||||
{
|
||||
if (isInvalid())
|
||||
{
|
||||
init();
|
||||
}
|
||||
assert(clientPtr_ != nullptr);
|
||||
clientPtr_->sendRequest(req, std::move(cb), timeout);
|
||||
++counter_;
|
||||
auto now = std::chrono::system_clock::now();
|
||||
timePoint_ = now;
|
||||
}
|
||||
|
||||
bool isInvalid()
|
||||
{
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto idleDut = now - timePoint_;
|
||||
auto dut = now - startTimePoint_;
|
||||
if ((clientPtr_ == nullptr) ||
|
||||
(cfg_.keepaliveRequests.has_value() &&
|
||||
counter_ >= cfg_.keepaliveRequests.value()) ||
|
||||
(cfg_.idleTimeout.has_value() &&
|
||||
idleDut >= cfg_.idleTimeout.value()) ||
|
||||
(cfg_.maxLifeTime.has_value() &&
|
||||
dut >= cfg_.maxLifeTime.value()))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool reachIdleTimeout()
|
||||
{
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto idleDut = now - timePoint_;
|
||||
if (idleDut >= cfg_.idleTimeout.value())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void resetClientPtr()
|
||||
{
|
||||
clientPtr_ = nullptr;
|
||||
}
|
||||
|
||||
Http11ClientPoolConfig cfg_;
|
||||
std::function<HttpClientPtr()> createHttpClientFunc_;
|
||||
HttpClientPtr clientPtr_;
|
||||
std::chrono::time_point<std::chrono::system_clock> timePoint_;
|
||||
std::chrono::time_point<std::chrono::system_clock> startTimePoint_;
|
||||
std::size_t counter_{0};
|
||||
};
|
||||
|
||||
void send(std::shared_ptr<Connection> connPtr,
|
||||
const HttpRequestPtr &req,
|
||||
std::function<void(ReqResult, const HttpResponsePtr &)> cb,
|
||||
double timeout)
|
||||
{
|
||||
connPtr->send(
|
||||
req,
|
||||
[connPtr, this, cb = std::move(cb), timeout](
|
||||
ReqResult result, const HttpResponsePtr &ptr) mutable {
|
||||
cb(result, ptr);
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (httpRequest_.empty())
|
||||
{
|
||||
httpClients_.emplace(std::move(connPtr));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto op = std::move(httpRequest_.front());
|
||||
httpRequest_.pop();
|
||||
lock.unlock();
|
||||
auto &[req, cb] = op;
|
||||
send(std::move(connPtr), req, std::move(cb), timeout);
|
||||
}
|
||||
return;
|
||||
},
|
||||
timeout);
|
||||
}
|
||||
|
||||
Http11ClientPoolConfig cfg_;
|
||||
std::shared_ptr<trantor::EventLoopThreadPool> loopPool_;
|
||||
bool isSelfThreadPool_;
|
||||
trantor::EventLoop *loopPtr_;
|
||||
std::mutex mutex_;
|
||||
std::queue<std::shared_ptr<Connection>> httpClients_;
|
||||
std::queue<
|
||||
std::tuple<HttpRequestPtr,
|
||||
std::function<void(ReqResult, const HttpResponsePtr &)>>>
|
||||
httpRequest_;
|
||||
std::optional<trantor::TimerId> timerId_;
|
||||
};
|
||||
|
||||
} // namespace drogon
|
@ -55,6 +55,10 @@ auto getAwaiter(T &&value) noexcept(
|
||||
return getAwaiterImpl(static_cast<T &&>(value));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using void_to_false_t =
|
||||
std::conditional_t<std::is_same_v<T, void>, std::false_type, T>;
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
template <typename T>
|
||||
@ -420,6 +424,11 @@ struct CallbackAwaiter : public trantor::NonCopyable
|
||||
return false;
|
||||
}
|
||||
|
||||
bool hasException() const noexcept
|
||||
{
|
||||
return exception_ != nullptr;
|
||||
}
|
||||
|
||||
const T &await_resume() const noexcept(false)
|
||||
{
|
||||
// await_resume() should always be called after co_await
|
||||
@ -470,6 +479,11 @@ struct CallbackAwaiter<void> : public trantor::NonCopyable
|
||||
std::rethrow_exception(exception_);
|
||||
}
|
||||
|
||||
bool hasException() const noexcept
|
||||
{
|
||||
return exception_ != nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
std::exception_ptr exception_{nullptr};
|
||||
|
||||
@ -798,6 +812,180 @@ struct [[nodiscard]] EventLoopAwaiter : public drogon::CallbackAwaiter<T>
|
||||
std::function<T()> task_;
|
||||
trantor::EventLoop *loop_;
|
||||
};
|
||||
|
||||
template <typename... Tasks>
|
||||
struct WhenAllAwaiter
|
||||
: public CallbackAwaiter<
|
||||
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...>>
|
||||
{
|
||||
WhenAllAwaiter(Tasks... tasks)
|
||||
: tasks_(std::forward<Tasks>(tasks)...), counter_(sizeof...(tasks))
|
||||
{
|
||||
}
|
||||
|
||||
void await_suspend(std::coroutine_handle<> handle)
|
||||
{
|
||||
if (counter_ == 0)
|
||||
{
|
||||
handle.resume();
|
||||
return;
|
||||
}
|
||||
|
||||
await_suspend_impl(handle, std::index_sequence_for<Tasks...>{});
|
||||
}
|
||||
|
||||
private:
|
||||
std::tuple<Tasks...> tasks_;
|
||||
std::atomic<size_t> counter_;
|
||||
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...> results_;
|
||||
std::atomic_flag exceptionFlag_;
|
||||
|
||||
template <size_t Idx>
|
||||
void launch_task(std::coroutine_handle<> handle)
|
||||
{
|
||||
using Self = WhenAllAwaiter<Tasks...>;
|
||||
[](Self *self, std::coroutine_handle<> handle) -> AsyncTask {
|
||||
try
|
||||
{
|
||||
using TaskType = std::tuple_element_t<
|
||||
Idx,
|
||||
std::remove_cvref_t<decltype(results_)>>;
|
||||
if constexpr (std::is_same_v<TaskType, std::false_type>)
|
||||
{
|
||||
co_await std::get<Idx>(self->tasks_);
|
||||
std::get<Idx>(self->results_) = std::false_type{};
|
||||
}
|
||||
else
|
||||
{
|
||||
std::get<Idx>(self->results_) =
|
||||
co_await std::get<Idx>(self->tasks_);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (self->exceptionFlag_.test_and_set() == false)
|
||||
self->setException(std::current_exception());
|
||||
}
|
||||
|
||||
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
|
||||
{
|
||||
if (!self->hasException())
|
||||
self->setValue(std::move(self->results_));
|
||||
handle.resume();
|
||||
}
|
||||
}(this, handle);
|
||||
}
|
||||
|
||||
template <size_t... Is>
|
||||
void await_suspend_impl(std::coroutine_handle<> handle,
|
||||
std::index_sequence<Is...>)
|
||||
{
|
||||
((launch_task<Is>(handle)), ...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct WhenAllAwaiter<std::vector<Task<T>>>
|
||||
: public CallbackAwaiter<std::vector<T>>
|
||||
{
|
||||
WhenAllAwaiter(std::vector<Task<T>> tasks)
|
||||
: tasks_(std::move(tasks)),
|
||||
counter_(tasks_.size()),
|
||||
results_(tasks_.size())
|
||||
{
|
||||
}
|
||||
|
||||
void await_suspend(std::coroutine_handle<> handle)
|
||||
{
|
||||
if (tasks_.empty())
|
||||
{
|
||||
this->setValue(std::vector<T>{});
|
||||
handle.resume();
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t count = tasks_.size();
|
||||
for (size_t i = 0; i < count; ++i)
|
||||
{
|
||||
[](WhenAllAwaiter *self,
|
||||
std::coroutine_handle<> handle,
|
||||
Task<T> task,
|
||||
size_t index) -> AsyncTask {
|
||||
try
|
||||
{
|
||||
auto result = co_await task;
|
||||
self->results_[index] = std::move(result);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (self->exceptionFlag_.test_and_set() == false)
|
||||
self->setException(std::current_exception());
|
||||
}
|
||||
|
||||
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
|
||||
{
|
||||
if (!self->hasException())
|
||||
{
|
||||
self->setValue(std::move(self->results_));
|
||||
}
|
||||
handle.resume();
|
||||
}
|
||||
}(this, handle, std::move(tasks_[i]), i);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Task<T>> tasks_;
|
||||
std::atomic<size_t> counter_;
|
||||
std::vector<T> results_;
|
||||
std::atomic_flag exceptionFlag_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WhenAllAwaiter<std::vector<Task<void>>> : public CallbackAwaiter<void>
|
||||
{
|
||||
WhenAllAwaiter(std::vector<Task<void>> &&t)
|
||||
: tasks_(std::move(t)), counter_(tasks_.size())
|
||||
{
|
||||
}
|
||||
|
||||
void await_suspend(std::coroutine_handle<> handle)
|
||||
{
|
||||
if (tasks_.empty())
|
||||
{
|
||||
handle.resume();
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t count =
|
||||
tasks_
|
||||
.size(); // capture the size fist (see lifetime comment beflow)
|
||||
for (size_t i = 0; i < count; ++i)
|
||||
{
|
||||
[](WhenAllAwaiter *self,
|
||||
std::coroutine_handle<> handle,
|
||||
Task<> task) -> AsyncTask {
|
||||
try
|
||||
{
|
||||
co_await task;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (self->exceptionFlag_.test_and_set() == false)
|
||||
self->setException(std::current_exception());
|
||||
}
|
||||
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
|
||||
// This line CAN delete `this` at last iteration. We MUST
|
||||
// NOT depend on this after last iteration
|
||||
handle.resume();
|
||||
}(this, handle, std::move(tasks_[i]));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Task<void>> tasks_;
|
||||
std::atomic<size_t> counter_;
|
||||
std::atomic_flag exceptionFlag_;
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
/**
|
||||
@ -987,4 +1175,23 @@ class Mutex final
|
||||
CoroMutexAwaiter *waiters_;
|
||||
};
|
||||
|
||||
template <typename... Tasks>
|
||||
internal::WhenAllAwaiter<Tasks...> when_all(Tasks... tasks)
|
||||
{
|
||||
return internal::WhenAllAwaiter<Tasks...>(std::move(tasks)...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
internal::WhenAllAwaiter<std::vector<Task<T>>> when_all(
|
||||
std::vector<Task<T>> tasks)
|
||||
{
|
||||
return internal::WhenAllAwaiter(std::move(tasks));
|
||||
}
|
||||
|
||||
inline internal::WhenAllAwaiter<std::vector<Task<void>>> when_all(
|
||||
std::vector<Task<void>> tasks)
|
||||
{
|
||||
return internal::WhenAllAwaiter(std::move(tasks));
|
||||
}
|
||||
|
||||
} // namespace drogon
|
||||
|
@ -3,9 +3,14 @@
|
||||
#include <drogon/HttpAppFramework.h>
|
||||
#include <trantor/net/EventLoopThread.h>
|
||||
#include <trantor/net/EventLoopThreadPool.h>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <exception>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
using namespace drogon;
|
||||
@ -245,3 +250,81 @@ DROGON_TEST(Mutex)
|
||||
pool.getLoop(i)->quit();
|
||||
pool.wait();
|
||||
}
|
||||
|
||||
DROGON_TEST(WhenAll)
|
||||
{
|
||||
using TestCtx = std::shared_ptr<drogon::test::Case>;
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
size_t counter = 0;
|
||||
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
|
||||
co_await drogon::sleepCoro(app().getLoop(), 0.2);
|
||||
(*counter)++;
|
||||
}(TEST_CTX, &counter);
|
||||
auto t2 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
|
||||
co_await drogon::sleepCoro(app().getLoop(), 0.1);
|
||||
(*counter)++;
|
||||
}(TEST_CTX, &counter);
|
||||
std::vector<Task<void>> tasks;
|
||||
tasks.emplace_back(std::move(t1));
|
||||
tasks.emplace_back(std::move(t2));
|
||||
|
||||
co_await when_all(std::move(tasks));
|
||||
CHECK(counter == 2);
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
std::vector<Task<void>> tasks;
|
||||
co_await when_all(std::move(tasks));
|
||||
SUCCESS();
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
|
||||
auto t2 = [](TestCtx TEST_CTX) -> Task<int> { co_return 2; }(TEST_CTX);
|
||||
std::vector<Task<int>> tasks;
|
||||
tasks.emplace_back(std::move(t1));
|
||||
tasks.emplace_back(std::move(t2));
|
||||
|
||||
auto res = co_await when_all(std::move(tasks));
|
||||
CO_REQUIRE(res.size() == 2);
|
||||
CHECK(res[0] == 1);
|
||||
CHECK(res[1] == 2);
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
|
||||
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
|
||||
co_return "Hello";
|
||||
}(TEST_CTX);
|
||||
auto [num, str] = co_await when_all(std::move(t1), std::move(t2));
|
||||
CHECK(num == 1);
|
||||
CHECK(str == "Hello");
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
size_t counter = 0;
|
||||
// Even on corutine throws, other coroutins run to completion
|
||||
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<int> {
|
||||
co_await drogon::sleepCoro(app().getLoop(), 0.2);
|
||||
(*counter)++;
|
||||
co_return 1;
|
||||
}(TEST_CTX, &counter);
|
||||
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
|
||||
co_await drogon::sleepCoro(app().getLoop(), 0.1);
|
||||
throw std::runtime_error("Test exception");
|
||||
}(TEST_CTX);
|
||||
CO_REQUIRE_THROWS(co_await when_all(std::move(t1), std::move(t2)));
|
||||
CHECK(counter == 1);
|
||||
}(TEST_CTX);
|
||||
|
||||
[](TestCtx TEST_CTX) -> AsyncTask {
|
||||
size_t counter = 0;
|
||||
// void retuens gets mapped to std::false_type in the tuple API
|
||||
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
|
||||
(*counter)++;
|
||||
co_return;
|
||||
}(TEST_CTX, &counter);
|
||||
auto [res] = co_await when_all(std::move(t1));
|
||||
CHECK(counter == 1);
|
||||
}(TEST_CTX);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user