Compare commits

...

3 Commits

Author SHA1 Message Date
fantasy-peak
debdc27c4d
Merge 2beaf0a901b7bfc95c7da48bb2e9ca750cbd1a34 into 3221c4385714859ce20348306fb4c0f170e778dc 2025-07-02 18:04:15 +05:30
Martin Chang
3221c43857
implement when_all coroutine gate (#2342) 2025-07-02 10:32:45 +08:00
fantasy-peak
2beaf0a901
Add HttpClient pool 2024-11-07 15:47:09 +08:00
5 changed files with 696 additions and 23 deletions

View File

@@ -718,6 +718,7 @@ install(FILES ${NOSQL_HEADERS} DESTINATION ${INSTALL_INCLUDE_DIR}/drogon/nosql)
set(DROGON_UTIL_HEADERS
lib/inc/drogon/utils/coroutine.h
lib/inc/drogon/utils/Http11ClientPool.h
lib/inc/drogon/utils/FunctionTraits.h
lib/inc/drogon/utils/HttpConstraint.h
lib/inc/drogon/utils/OStringStream.h

View File

@@ -1,43 +1,112 @@
#include <drogon/drogon.h>
#include <future>
#include <chrono>
#include <iostream>
#include <memory>
#include <thread>
#include <drogon/HttpTypes.h>
#include <trantor/utils/Logger.h>
#ifdef __linux__
#include <sys/socket.h>
#include <netinet/tcp.h>
#endif
#include <drogon/utils/Http11ClientPool.h>
using namespace drogon;
int nth_resp = 0;
int main()
{
auto func = [](int fd) {
std::cout << "setSockOptCallback:" << fd << std::endl;
#ifdef __linux__
int optval = 10;
::setsockopt(fd,
SOL_TCP,
TCP_KEEPCNT,
&optval,
static_cast<socklen_t>(sizeof optval));
::setsockopt(fd,
SOL_TCP,
TCP_KEEPIDLE,
&optval,
static_cast<socklen_t>(sizeof optval));
::setsockopt(fd,
SOL_TCP,
TCP_KEEPINTVL,
&optval,
static_cast<socklen_t>(sizeof optval));
#endif
};
trantor::Logger::setLogLevel(trantor::Logger::kTrace);
#ifdef __cpp_impl_coroutine
Http11ClientPoolConfig cfg{
.hostString = "http://www.baidu.com",
.useOldTLS = false,
.validateCert = false,
.size = 10,
.setCallback =
[func](auto &client) {
LOG_INFO << "setCallback";
client->setSockOptCallback(func);
},
.numOfThreads = 4,
.keepaliveRequests = 1000,
.idleTimeout = std::chrono::seconds(10),
.maxLifeTime = std::chrono::seconds(300),
.checkInterval = std::chrono::seconds(10),
};
auto pool = std::make_unique<Http11ClientPool>(cfg);
auto req = HttpRequest::newHttpRequest();
req->setMethod(drogon::Get);
req->setPath("/s");
req->setParameter("wd", "wx");
req->setParameter("oq", "wx");
for (int i = 0; i < 1; i++)
{
[](auto req, auto &pool) -> drogon::AsyncTask {
{
auto [result, resp] = co_await pool->sendRequestCoro(req, 10);
if (result == ReqResult::Ok)
LOG_INFO << "1:" << resp->getStatusCode();
}
{
auto [result, resp] = co_await pool->sendRequestCoro(req, 10);
if (result == ReqResult::Ok)
LOG_INFO << "2:" << resp->getStatusCode();
}
{
auto [result, resp] = co_await pool->sendRequestCoro(req, 10);
if (result == ReqResult::Ok)
LOG_INFO << "3:" << resp->getStatusCode();
}
co_return;
}(req, pool);
}
for (int i = 0; i < 10; i++)
{
pool->sendRequest(
req,
[](ReqResult result, const HttpResponsePtr &response) {
if (result != ReqResult::Ok)
{
LOG_ERROR
<< "error while sending request to server! result: "
<< result;
return;
}
LOG_INFO << "callback:" << response->getStatusCode();
},
10);
}
std::this_thread::sleep_for(std::chrono::seconds(30));
#else
{
auto client = HttpClient::newHttpClient("http://www.baidu.com");
client->setSockOptCallback([](int fd) {
std::cout << "setSockOptCallback:" << fd << std::endl;
#ifdef __linux__
int optval = 10;
::setsockopt(fd,
SOL_TCP,
TCP_KEEPCNT,
&optval,
static_cast<socklen_t>(sizeof optval));
::setsockopt(fd,
SOL_TCP,
TCP_KEEPIDLE,
&optval,
static_cast<socklen_t>(sizeof optval));
::setsockopt(fd,
SOL_TCP,
TCP_KEEPINTVL,
&optval,
static_cast<socklen_t>(sizeof optval));
#endif
});
client->setSockOptCallback(func);
auto req = HttpRequest::newHttpRequest();
req->setMethod(drogon::Get);
@@ -77,4 +146,5 @@ int main()
}
app().run();
#endif
}

View File

@@ -0,0 +1,312 @@
/**
*
* @file Http11ClientPool.h
* @author fantasy-peak
*
* Copyright 2024, fantasy-peak. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
#include <string>
#include <thread>  // std::thread::hardware_concurrency
#include <tuple>
#include <utility>
#include <trantor/utils/Logger.h>
#include <trantor/net/EventLoopThreadPool.h>
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>
#endif
#include <drogon/drogon.h>
namespace drogon
{
struct Http11ClientPoolConfig
{
std::string hostString;
bool useOldTLS{false};
bool validateCert{false};
std::size_t size{100};
std::function<void(HttpClientPtr &)> setCallback;
std::size_t numOfThreads{std::thread::hardware_concurrency()};
std::optional<std::size_t> keepaliveRequests;
std::optional<std::chrono::seconds> idleTimeout;
std::optional<std::chrono::seconds> maxLifeTime;
std::optional<std::chrono::seconds> checkInterval;
};
class Http11ClientPool final
{
public:
Http11ClientPool(
Http11ClientPoolConfig cfg,
std::shared_ptr<trantor::EventLoopThreadPool> loopPool = nullptr)
: cfg_(std::move(cfg)), loopPool_(std::move(loopPool))
{
if (loopPool_ == nullptr)
{
loopPool_ = std::make_shared<trantor::EventLoopThreadPool>(
cfg_.numOfThreads);
loopPool_->start();
isSelfThreadPool_ = true;
}
else
{
isSelfThreadPool_ = false;
}
loopPtr_ = loopPool_->getNextLoop();
for (std::size_t i = 0; i < cfg_.size; i++)
{
auto loopPtr = loopPool_->getNextLoop();
auto func = [this, loopPtr]() mutable {
auto client = HttpClient::newHttpClient(cfg_.hostString,
loopPtr,
cfg_.useOldTLS,
cfg_.validateCert);
if (cfg_.setCallback)
cfg_.setCallback(client);
return client;
};
httpClients_.emplace(std::make_shared<Connection>(func, cfg_));
}
LOG_DEBUG << "httpClients_ size:" << httpClients_.size();
if (cfg_.idleTimeout.has_value() && cfg_.checkInterval.has_value())
{
timerId_ = loopPtr_->runEvery(cfg_.checkInterval.value(), [this] {
std::unique_lock<std::mutex> lock(mutex_);
if (httpClients_.empty())
return;
std::queue<std::shared_ptr<Connection>> clients;
while (!httpClients_.empty())
{
auto connPtr = std::move(httpClients_.front());
httpClients_.pop();
if (connPtr->reachIdleTimeout())
{
// close tcp connection
connPtr->resetClientPtr();
}
clients.emplace(std::move(connPtr));
}
httpClients_ = std::move(clients);
});
}
}
~Http11ClientPool()
{
if (timerId_)
{
std::promise<void> done;
loopPtr_->runInLoop([&] {
loopPtr_->invalidateTimer(timerId_.value());
done.set_value();
});
done.get_future().wait();
}
if (isSelfThreadPool_)
{
for (auto &ptr : loopPool_->getLoops())
ptr->runInLoop([=] { ptr->quit(); });
loopPool_->wait();
loopPool_.reset();
}
std::unique_lock<std::mutex> lock(mutex_);
std::queue<std::shared_ptr<Connection>> tmp;
httpClients_.swap(tmp);
}
Http11ClientPool(const Http11ClientPool &) = delete;
Http11ClientPool &operator=(const Http11ClientPool &) = delete;
Http11ClientPool(Http11ClientPool &&) = delete;
Http11ClientPool &operator=(Http11ClientPool &&) = delete;
void sendRequest(const HttpRequestPtr &req,
std::function<void(ReqResult, const HttpResponsePtr &)> cb,
double timeout = 0)
{
std::unique_lock<std::mutex> lock(mutex_);
if (httpClients_.empty())
{
httpRequest_.emplace(req, std::move(cb));
}
else
{
auto connPtr = std::move(httpClients_.front());
httpClients_.pop();
lock.unlock();
send(std::move(connPtr), req, std::move(cb), timeout);
}
return;
}
#ifdef __cpp_impl_coroutine
auto sendRequestCoro(HttpRequestPtr req, double timeout = 0)
{
struct Awaiter
: public CallbackAwaiter<std::tuple<ReqResult, HttpResponsePtr>>
{
Awaiter(Http11ClientPool *pool, HttpRequestPtr req, double timeout)
: pool_(pool), req_(std::move(req)), timeout_(timeout)
{
}
void await_suspend(std::coroutine_handle<> handle)
{
pool_->sendRequest(
req_,
[this, handle](ReqResult result,
const HttpResponsePtr &ptr) {
setValue(std::make_tuple(result, ptr));
handle.resume();
},
timeout_);
}
private:
Http11ClientPool *pool_;
HttpRequestPtr req_;
double timeout_;
};
return Awaiter{this, std::move(req), timeout};
}
#endif
private:
struct Connection
{
Connection(std::function<HttpClientPtr()> cb,
const Http11ClientPoolConfig &cfg)
: createHttpClientFunc_(std::move(cb)), cfg_(cfg)
{
init();
}
void init()
{
clientPtr_ = createHttpClientFunc_();
auto now = std::chrono::system_clock::now();
timePoint_ = now;
startTimePoint_ = now;
counter_ = 0;
}
void send(const HttpRequestPtr &req,
std::function<void(ReqResult, const HttpResponsePtr &)> cb,
double timeout)
{
if (isInvalid())
{
init();
}
assert(clientPtr_ != nullptr);
clientPtr_->sendRequest(req, std::move(cb), timeout);
++counter_;
auto now = std::chrono::system_clock::now();
timePoint_ = now;
}
bool isInvalid()
{
auto now = std::chrono::system_clock::now();
auto idleDut = now - timePoint_;
auto dut = now - startTimePoint_;
if ((clientPtr_ == nullptr) ||
(cfg_.keepaliveRequests.has_value() &&
counter_ >= cfg_.keepaliveRequests.value()) ||
(cfg_.idleTimeout.has_value() &&
idleDut >= cfg_.idleTimeout.value()) ||
(cfg_.maxLifeTime.has_value() &&
dut >= cfg_.maxLifeTime.value()))
{
return true;
}
return false;
}
bool reachIdleTimeout()
{
auto now = std::chrono::system_clock::now();
auto idleDut = now - timePoint_;
if (idleDut >= cfg_.idleTimeout.value())
{
return true;
}
return false;
}
void resetClientPtr()
{
clientPtr_ = nullptr;
}
Http11ClientPoolConfig cfg_;
std::function<HttpClientPtr()> createHttpClientFunc_;
HttpClientPtr clientPtr_;
std::chrono::time_point<std::chrono::system_clock> timePoint_;
std::chrono::time_point<std::chrono::system_clock> startTimePoint_;
std::size_t counter_{0};
};
void send(std::shared_ptr<Connection> connPtr,
const HttpRequestPtr &req,
std::function<void(ReqResult, const HttpResponsePtr &)> cb,
double timeout)
{
connPtr->send(
req,
[connPtr, this, cb = std::move(cb), timeout](
ReqResult result, const HttpResponsePtr &ptr) mutable {
cb(result, ptr);
std::unique_lock<std::mutex> lock(mutex_);
if (httpRequest_.empty())
{
httpClients_.emplace(std::move(connPtr));
}
else
{
auto op = std::move(httpRequest_.front());
httpRequest_.pop();
lock.unlock();
auto &[req, cb] = op;
send(std::move(connPtr), req, std::move(cb), timeout);
}
return;
},
timeout);
}
Http11ClientPoolConfig cfg_;
std::shared_ptr<trantor::EventLoopThreadPool> loopPool_;
bool isSelfThreadPool_;
trantor::EventLoop *loopPtr_;
std::mutex mutex_;
std::queue<std::shared_ptr<Connection>> httpClients_;
std::queue<
std::tuple<HttpRequestPtr,
std::function<void(ReqResult, const HttpResponsePtr &)>>>
httpRequest_;
std::optional<trantor::TimerId> timerId_;
};
} // namespace drogon
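
For orientation, a minimal usage sketch of the pool declared above (not part of the diff). It mirrors the example main.cc earlier in this compare; the host URL, pool sizing, timeouts and the helper name fetchOnce are illustrative placeholders, and the coroutine path assumes a compiler that defines __cpp_impl_coroutine.

// Minimal sketch only; placeholder host, sizes and timeouts.
#include <chrono>
#include <memory>
#include <thread>
#include <drogon/drogon.h>
#include <drogon/utils/Http11ClientPool.h>
#include <trantor/utils/Logger.h>

static drogon::AsyncTask fetchOnce(drogon::Http11ClientPool &pool)
{
    auto req = drogon::HttpRequest::newHttpRequest();
    req->setMethod(drogon::Get);
    req->setPath("/");
    // Coroutine path: co_await yields a (ReqResult, HttpResponsePtr) tuple.
    auto [result, resp] = co_await pool.sendRequestCoro(req, /*timeout=*/10);
    if (result == drogon::ReqResult::Ok)
        LOG_INFO << "coro: " << resp->getStatusCode();
    // Callback path: the same pool also exposes the classic callback API.
    pool.sendRequest(
        req,
        [](drogon::ReqResult r, const drogon::HttpResponsePtr &p) {
            if (r == drogon::ReqResult::Ok)
                LOG_INFO << "callback: " << p->getStatusCode();
        },
        /*timeout=*/10);
    co_return;
}

int main()
{
    drogon::Http11ClientPoolConfig cfg{
        .hostString = "http://example.com",  // placeholder host
        .size = 4,
        .numOfThreads = 2,
        .idleTimeout = std::chrono::seconds(10),
        .checkInterval = std::chrono::seconds(10),
    };
    auto pool = std::make_unique<drogon::Http11ClientPool>(cfg);
    fetchOnce(*pool);
    // Crude wait so the detached coroutine can finish before the pool is destroyed.
    std::this_thread::sleep_for(std::chrono::seconds(5));
}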

View File

@@ -55,6 +55,10 @@ auto getAwaiter(T &&value) noexcept(
return getAwaiterImpl(static_cast<T &&>(value));
}
template <typename T>
using void_to_false_t =
std::conditional_t<std::is_same_v<T, void>, std::false_type, T>;
} // end namespace internal
template <typename T>
@@ -420,6 +424,11 @@ struct CallbackAwaiter : public trantor::NonCopyable
return false;
}
bool hasException() const noexcept
{
return exception_ != nullptr;
}
const T &await_resume() const noexcept(false)
{
// await_resume() should always be called after co_await
@@ -470,6 +479,11 @@ struct CallbackAwaiter<void> : public trantor::NonCopyable
std::rethrow_exception(exception_);
}
bool hasException() const noexcept
{
return exception_ != nullptr;
}
private:
std::exception_ptr exception_{nullptr};
@@ -798,6 +812,180 @@ struct [[nodiscard]] EventLoopAwaiter : public drogon::CallbackAwaiter<T>
std::function<T()> task_;
trantor::EventLoop *loop_;
};
template <typename... Tasks>
struct WhenAllAwaiter
: public CallbackAwaiter<
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...>>
{
WhenAllAwaiter(Tasks... tasks)
: tasks_(std::forward<Tasks>(tasks)...), counter_(sizeof...(tasks))
{
}
void await_suspend(std::coroutine_handle<> handle)
{
if (counter_ == 0)
{
handle.resume();
return;
}
await_suspend_impl(handle, std::index_sequence_for<Tasks...>{});
}
private:
std::tuple<Tasks...> tasks_;
std::atomic<size_t> counter_;
std::tuple<internal::void_to_false_t<await_result_t<Tasks>>...> results_;
std::atomic_flag exceptionFlag_;
template <size_t Idx>
void launch_task(std::coroutine_handle<> handle)
{
using Self = WhenAllAwaiter<Tasks...>;
[](Self *self, std::coroutine_handle<> handle) -> AsyncTask {
try
{
using TaskType = std::tuple_element_t<
Idx,
std::remove_cvref_t<decltype(results_)>>;
if constexpr (std::is_same_v<TaskType, std::false_type>)
{
co_await std::get<Idx>(self->tasks_);
std::get<Idx>(self->results_) = std::false_type{};
}
else
{
std::get<Idx>(self->results_) =
co_await std::get<Idx>(self->tasks_);
}
}
catch (...)
{
if (self->exceptionFlag_.test_and_set() == false)
self->setException(std::current_exception());
}
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
{
if (!self->hasException())
self->setValue(std::move(self->results_));
handle.resume();
}
}(this, handle);
}
template <size_t... Is>
void await_suspend_impl(std::coroutine_handle<> handle,
std::index_sequence<Is...>)
{
((launch_task<Is>(handle)), ...);
}
};
template <typename T>
struct WhenAllAwaiter<std::vector<Task<T>>>
: public CallbackAwaiter<std::vector<T>>
{
WhenAllAwaiter(std::vector<Task<T>> tasks)
: tasks_(std::move(tasks)),
counter_(tasks_.size()),
results_(tasks_.size())
{
}
void await_suspend(std::coroutine_handle<> handle)
{
if (tasks_.empty())
{
this->setValue(std::vector<T>{});
handle.resume();
return;
}
const size_t count = tasks_.size();
for (size_t i = 0; i < count; ++i)
{
[](WhenAllAwaiter *self,
std::coroutine_handle<> handle,
Task<T> task,
size_t index) -> AsyncTask {
try
{
auto result = co_await task;
self->results_[index] = std::move(result);
}
catch (...)
{
if (self->exceptionFlag_.test_and_set() == false)
self->setException(std::current_exception());
}
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
{
if (!self->hasException())
{
self->setValue(std::move(self->results_));
}
handle.resume();
}
}(this, handle, std::move(tasks_[i]), i);
}
}
private:
std::vector<Task<T>> tasks_;
std::atomic<size_t> counter_;
std::vector<T> results_;
std::atomic_flag exceptionFlag_;
};
template <>
struct WhenAllAwaiter<std::vector<Task<void>>> : public CallbackAwaiter<void>
{
WhenAllAwaiter(std::vector<Task<void>> &&t)
: tasks_(std::move(t)), counter_(tasks_.size())
{
}
void await_suspend(std::coroutine_handle<> handle)
{
if (tasks_.empty())
{
handle.resume();
return;
}
const size_t count = tasks_.size();  // capture the size first (see lifetime comment below)
for (size_t i = 0; i < count; ++i)
{
[](WhenAllAwaiter *self,
std::coroutine_handle<> handle,
Task<> task) -> AsyncTask {
try
{
co_await task;
}
catch (...)
{
if (self->exceptionFlag_.test_and_set() == false)
self->setException(std::current_exception());
}
if (self->counter_.fetch_sub(1, std::memory_order_acq_rel) == 1)
// This resume CAN delete `this` on the last iteration. We MUST
// NOT touch `this` after the last iteration.
handle.resume();
}(this, handle, std::move(tasks_[i]));
}
}
std::vector<Task<void>> tasks_;
std::atomic<size_t> counter_;
std::atomic_flag exceptionFlag_;
};
} // namespace internal
/**
@@ -987,4 +1175,23 @@ class Mutex final
CoroMutexAwaiter *waiters_;
};
template <typename... Tasks>
internal::WhenAllAwaiter<Tasks...> when_all(Tasks... tasks)
{
return internal::WhenAllAwaiter<Tasks...>(std::move(tasks)...);
}
template <typename T>
internal::WhenAllAwaiter<std::vector<Task<T>>> when_all(
std::vector<Task<T>> tasks)
{
return internal::WhenAllAwaiter(std::move(tasks));
}
inline internal::WhenAllAwaiter<std::vector<Task<void>>> when_all(
std::vector<Task<void>> tasks)
{
return internal::WhenAllAwaiter(std::move(tasks));
}
} // namespace drogon
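
For orientation, a short sketch of how the new when_all overloads compose (the unit tests added below exercise the same paths). The helper coroutines answer/greet/side and the use of sync_wait as a driver are illustrative only.

// Illustrative only; relies on the when_all overloads added above.
#include <string>
#include <vector>
#include <drogon/utils/coroutine.h>

using namespace drogon;

static Task<int> answer()
{
    co_return 42;
}

static Task<std::string> greet()
{
    co_return "hi";
}

static Task<void> side()
{
    co_return;
}

static Task<> demo()
{
    // Variadic form: results come back as a tuple.
    auto [n, s] = co_await when_all(answer(), greet());  // n == 42, s == "hi"

    // A Task<void> element yields std::false_type in the result tuple
    // (via internal::void_to_false_t).
    auto [flag] = co_await when_all(side());

    // Homogeneous form: std::vector<Task<T>> yields std::vector<T>;
    // std::vector<Task<void>> just completes with no value.
    std::vector<Task<int>> tasks;
    tasks.emplace_back(answer());
    tasks.emplace_back(answer());
    auto results = co_await when_all(std::move(tasks));  // {42, 42}
    (void)n; (void)s; (void)flag; (void)results;
}

int main()
{
    sync_wait(demo());  // drive the example coroutine to completion
}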

View File

@@ -3,9 +3,14 @@
#include <drogon/HttpAppFramework.h>
#include <trantor/net/EventLoopThread.h>
#include <trantor/net/EventLoopThreadPool.h>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <exception>
#include <future>
#include <memory>
#include <mutex>
#include <optional>
#include <type_traits>
using namespace drogon;
@@ -245,3 +250,81 @@ DROGON_TEST(Mutex)
pool.getLoop(i)->quit();
pool.wait();
}
DROGON_TEST(WhenAll)
{
using TestCtx = std::shared_ptr<drogon::test::Case>;
[](TestCtx TEST_CTX) -> AsyncTask {
size_t counter = 0;
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
co_await drogon::sleepCoro(app().getLoop(), 0.2);
(*counter)++;
}(TEST_CTX, &counter);
auto t2 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
co_await drogon::sleepCoro(app().getLoop(), 0.1);
(*counter)++;
}(TEST_CTX, &counter);
std::vector<Task<void>> tasks;
tasks.emplace_back(std::move(t1));
tasks.emplace_back(std::move(t2));
co_await when_all(std::move(tasks));
CHECK(counter == 2);
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
std::vector<Task<void>> tasks;
co_await when_all(std::move(tasks));
SUCCESS();
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
auto t2 = [](TestCtx TEST_CTX) -> Task<int> { co_return 2; }(TEST_CTX);
std::vector<Task<int>> tasks;
tasks.emplace_back(std::move(t1));
tasks.emplace_back(std::move(t2));
auto res = co_await when_all(std::move(tasks));
CO_REQUIRE(res.size() == 2);
CHECK(res[0] == 1);
CHECK(res[1] == 2);
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
auto t1 = [](TestCtx TEST_CTX) -> Task<int> { co_return 1; }(TEST_CTX);
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
co_return "Hello";
}(TEST_CTX);
auto [num, str] = co_await when_all(std::move(t1), std::move(t2));
CHECK(num == 1);
CHECK(str == "Hello");
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
size_t counter = 0;
// Even when one coroutine throws, the other coroutines run to completion
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<int> {
co_await drogon::sleepCoro(app().getLoop(), 0.2);
(*counter)++;
co_return 1;
}(TEST_CTX, &counter);
auto t2 = [](TestCtx TEST_CTX) -> Task<std::string> {
co_await drogon::sleepCoro(app().getLoop(), 0.1);
throw std::runtime_error("Test exception");
}(TEST_CTX);
CO_REQUIRE_THROWS(co_await when_all(std::move(t1), std::move(t2)));
CHECK(counter == 1);
}(TEST_CTX);
[](TestCtx TEST_CTX) -> AsyncTask {
size_t counter = 0;
// void returns get mapped to std::false_type in the tuple API
auto t1 = [](TestCtx TEST_CTX, size_t *counter) -> Task<> {
(*counter)++;
co_return;
}(TEST_CTX, &counter);
auto [res] = co_await when_all(std::move(t1));
CHECK(counter == 1);
}(TEST_CTX);
}