Coroutine support (#693)

* app().registerHttpHandler() accepts coroutine as handlers
* HttpController can use coroutine as handlers'
* Http request handlers with coroutine catches exception instead of crashing the entire app
* DbClient now has execSqlCoro that is awaitable
* DbClient now has newTransactionCoro that is awaitable
* HttpClient have awaitable sendRequestCoro
* WebSocketClient have awaitable connectToServerCoro
* WebSocketClient have setAsyncMessageHandler and setAsyncConnectionClosedHandler
* drogon::AsyncTask and drogon::Task<T> as our corutine types
* Related tests
* Misc

Future work
* Coroutine for WebSocket server
* Known issues

co_future() and sync_wait may crash. It looks like GCC bug but I'm not sure.
Workarround: Make an coroutine of AsyncTask. Then launch said coroutine.
Not sure why wrapping the exact same thing in function crashes things.

Co-authored-by: an-tao <antao2002@gmail.com>
This commit is contained in:
Martin Chang 2021-02-06 17:05:58 +08:00 committed by GitHub
parent 7ce5768372
commit a2142dd93e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1283 additions and 46 deletions

View File

@ -22,6 +22,9 @@ jobs:
buildname: 'ubuntu-20.04/gcc'
triplet: x64-linux
compiler: gcc_64
- os: ubuntu-20.04
buildname: 'ubuntu-20.04/gcc-10'
triplet: x64-linux
- os: ubuntu-16.04
buildname: 'ubuntu-16.04/gcc'
triplet: x64-linux
@ -53,7 +56,10 @@ jobs:
# These aren't available or don't work well in vcpkg
sudo apt install libjsoncpp-dev uuid-dev openssl libssl-dev zlib1g-dev postgresql-all libsqlite3-dev
sudo apt install libbrotli-dev
- name: (Linux) Install gcc-10
if: matrix.buildname == 'ubuntu-20.04/gcc-10'
run: |
sudo apt install gcc-10 g++-10
- name: (Linux) Install boost
if: matrix.os == 'ubuntu-16.04'
run: |
@ -75,12 +81,24 @@ jobs:
# We'll use this as our working directory for all subsequent commands
shell: bash
working-directory: ${{env.GITHUB_WORKSPACE}}
if: runner.os != 'macOS'
if: runner.os != 'macOS' && matrix.buildname != 'ubuntu-20.04/gcc-10'
run: |
mkdir build
cd build
cmake .. -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DBUILD_TESTING=on
- name: Create Build Environment & Configure Cmake (gcc-10)
# Some projects don't allow in-source building, so create a separate build directory
# We'll use this as our working directory for all subsequent commands
shell: bash
working-directory: ${{env.GITHUB_WORKSPACE}}
if: matrix.buildname == 'ubuntu-20.04/gcc-10'
run: |
mkdir build
cd build
cmake .. -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DBUILD_TESTING=on -DCMAKE_CXX_FLAGS="-fcoroutines"
env:
CC: gcc-10
CXX: g++-10
- name: Create Build Environment & Configure Cmake (MacOS)
# Some projects don't allow in-source building, so create a separate build directory
# We'll use this as our working directory for all subsequent commands

View File

@ -64,11 +64,14 @@ include(CheckIncludeFileCXX)
check_include_file_cxx(any HAS_ANY)
check_include_file_cxx(string_view HAS_STRING_VIEW)
if(HAS_ANY AND HAS_STRING_VIEW)
check_include_file_cxx(coroutine HAS_COROUTINE)
if(HAS_ANY AND HAS_STRING_VIEW AND HAS_COROUTINE)
set(DROGON_CXX_STANDARD 20)
elseif(HAS_ANY AND HAS_STRING_VIEW)
set(DROGON_CXX_STANDARD 17)
else(HAS_ANY AND HAS_STRING_VIEW)
else()
set(DROGON_CXX_STANDARD 14)
endif(HAS_ANY AND HAS_STRING_VIEW)
endif()
target_include_directories(
${PROJECT_NAME}
@ -96,16 +99,18 @@ else(NOT WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE shlwapi)
endif(NOT WIN32)
if(DROGON_CXX_STANDARD LESS 17)
if(DROGON_CXX_STANDARD EQUAL 14)
# With C++14, use boost to support any and string_view
message(STATUS "use c++14")
find_package(Boost 1.61.0 REQUIRED)
message(STATUS "boost include dir:" ${Boost_INCLUDE_DIR})
target_link_libraries(${PROJECT_NAME} PUBLIC Boost::boost)
list(APPEND INCLUDE_DIRS_FOR_DYNAMIC_VIEW ${Boost_INCLUDE_DIR})
else(DROGON_CXX_STANDARD LESS 17)
elseif(DROGON_CXX_STANDARD EQUAL 17)
message(STATUS "use c++17")
endif(DROGON_CXX_STANDARD LESS 17)
else()
message(STATUS "use c++20")
endif()
set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake_modules/)
@ -246,6 +251,12 @@ execute_process(COMMAND "git" rev-parse HEAD
configure_file("${PROJECT_SOURCE_DIR}/cmake/templates/version.h.in"
"${PROJECT_SOURCE_DIR}/lib/inc/drogon/version.h" @ONLY)
if(DROGON_CXX_STANDARD EQUAL 20)
option(USE_COROUTINE "Enable C++20 coroutine support" ON)
else(DROGON_CXX_STANDARD EQUAL 20)
option(USE_COROUTINE "Enable C++20 coroutine support" OFF)
endif(DROGON_CXX_STANDARD EQUAL 20)
if(BUILD_EXAMPLES)
add_subdirectory(examples)
endif(BUILD_EXAMPLES)
@ -427,6 +438,7 @@ set(DROGON_UTIL_HEADERS
lib/inc/drogon/utils/any.h
lib/inc/drogon/utils/string_view.h
lib/inc/drogon/utils/optional.h
lib/inc/drogon/utils/coroutine.h
lib/inc/drogon/utils/HttpConstraint.h
lib/inc/drogon/utils/OStringStream.h)
install(FILES ${DROGON_UTIL_HEADERS}

View File

@ -35,6 +35,7 @@ Drogon is a cross-platform framework, It supports Linux, macOS, FreeBSD, OpenBSD
* Provide a convenient lightweight ORM implementation that supports for regular object-to-database bidirectional mapping;
* Support plugins which can be installed by the configuration file at load time;
* Support AOP with build-in joinpoints.
* Support C++ coroutins
## A very simple example

View File

@ -36,6 +36,7 @@ Drogon是一个跨平台框架它支持Linux也支持macOS、FreeBSDOpe
* 方便的轻量级ORM实现支持常规的对象到数据库的双向映射操作
* 支持插件,可通过配置文件在加载期动态拆装;
* 支持内建插入点的AOP
* 支持C++协程
## 一个非常简单的例子

View File

@ -36,6 +36,7 @@ Drogon是一個跨平台框架它支援Linux也支援macOS、FreeBSD/OpenB
* 方便的輕量級ORM實現一般物件到資料庫的雙向映射
* 支援外掛,可通過設定文件在載入時動態載入;
* 支援內建插入點的AOP
* 支援C++ coroutine
## 一個非常簡單的例子

View File

@ -18,6 +18,13 @@ set(simple_example_sources
simple_example/DigestAuthFilter.cc
simple_example/main.cc)
if(DROGON_CXX_STANDARD GREATER_EQUAL 20)
set(simple_example_sources ${simple_example_sources}
simple_example/api_v1_CoroTest.cc)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
endif(DROGON_CXX_STANDARD GREATER_EQUAL 20)
add_executable(webapp ${simple_example_sources})
drogon_create_views(webapp ${CMAKE_CURRENT_SOURCE_DIR}/simple_example
${CMAKE_CURRENT_BINARY_DIR})
@ -61,6 +68,10 @@ set(example_targets
pipelining_test
websocket_test
multiple_ws_test)
if(DROGON_CXX_STANDARD GREATER_EQUAL 20)
add_executable(websocket_coro_test simple_example_test/WebSocketCoroTest.cc)
set(simple_example ${simple_example} websocket_coro_test)
endif(DROGON_CXX_STANDARD GREATER_EQUAL 20)
set_property(TARGET ${example_targets}
PROPERTY CXX_STANDARD ${DROGON_CXX_STANDARD})

View File

@ -0,0 +1,18 @@
#include "api_v1_CoroTest.h"
using namespace api::v1;
Task<> CoroTest::get(HttpRequestPtr req,
std::function<void(const HttpResponsePtr &)> callback)
{
auto resp = HttpResponse::newHttpResponse();
resp->setBody("DEADBEEF");
callback(resp);
co_return;
}
Task<HttpResponsePtr> CoroTest::get2(HttpRequestPtr req)
{
auto resp = HttpResponse::newHttpResponse();
resp->setBody("BADDBEEF");
co_return resp;
}

View File

@ -0,0 +1,21 @@
#pragma once
#include <drogon/HttpController.h>
using namespace drogon;
namespace api
{
namespace v1
{
class CoroTest : public drogon::HttpController<CoroTest>
{
public:
METHOD_LIST_BEGIN
METHOD_ADD(CoroTest::get, "/get", Get);
METHOD_ADD(CoroTest::get2, "/get2", Get);
METHOD_LIST_END
Task<> get(HttpRequestPtr req,
std::function<void(const HttpResponsePtr &)> callback);
Task<HttpResponsePtr> get2(HttpRequestPtr req);
};
} // namespace v1
} // namespace api

View File

@ -0,0 +1,80 @@
#include <drogon/WebSocketClient.h>
#include <drogon/HttpAppFramework.h>
#include <trantor/net/EventLoopThread.h>
#include <iostream>
using namespace drogon;
using namespace std::chrono_literals;
Task<> doTest(WebSocketClientPtr wsPtr, HttpRequestPtr req, bool continually)
{
wsPtr->setAsyncMessageHandler(
[continually](std::string&& message,
const WebSocketClientPtr wsPtr,
const WebSocketMessageType type) -> Task<> {
std::cout << "new message:" << message << std::endl;
if (type == WebSocketMessageType::Pong)
{
std::cout << "recv a pong" << std::endl;
if (!continually)
{
app().getLoop()->quit();
}
}
co_return;
});
wsPtr->setAsyncConnectionClosedHandler(
[](const WebSocketClientPtr wsPtr) -> Task<> {
std::cout << "ws closed!" << std::endl;
co_return;
});
try
{
auto resp = co_await wsPtr->connectToServerCoro(req);
}
catch (...)
{
std::cout << "ws failed!" << std::endl;
if (!continually)
{
exit(1);
}
}
std::cout << "ws connected!" << std::endl;
wsPtr->getConnection()->setPingMessage("", 2s);
wsPtr->getConnection()->send("hello!");
}
int main(int argc, char* argv[])
{
auto wsPtr = WebSocketClient::newWebSocketClient("127.0.0.1", 8848);
auto req = HttpRequest::newHttpRequest();
req->setPath("/chat");
bool continually = true;
if (argc > 1)
{
if (std::string(argv[1]) == "-t")
continually = false;
else if (std::string(argv[1]) == "-p")
{
// Connect to a public web socket server.
wsPtr =
WebSocketClient::newWebSocketClient("wss://echo.websocket.org");
req->setPath("/");
}
}
app().getLoop()->runAfter(5.0, [continually]() {
if (!continually)
{
exit(1);
}
});
app().setLogLevel(trantor::Logger::kTrace);
[=]() -> AsyncTask { co_await doTest(wsPtr, req, continually); }();
app().run();
}

View File

@ -1303,6 +1303,54 @@ void doTest(const HttpClientPtr &client,
exit(1);
}
});
#ifdef __cpp_impl_coroutine
// Test coroutine requests
[client, isHttps]() -> AsyncTask {
try
{
auto req = HttpRequest::newHttpRequest();
req->setPath("/api/v1/corotest/get");
auto resp = co_await client->sendRequestCoro(req);
if (resp->getBody() != "DEADBEEF")
{
LOG_ERROR << resp->getBody();
LOG_ERROR << "Error!";
exit(1);
}
outputGood(req, isHttps);
}
catch (const std::exception &e)
{
LOG_DEBUG << e.what();
LOG_ERROR << "Error!";
exit(1);
}
}();
// Test coroutine request with co_return
[client, isHttps]() -> AsyncTask {
try
{
auto req = HttpRequest::newHttpRequest();
req->setPath("/api/v1/corotest/get2");
auto resp = co_await client->sendRequestCoro(req);
if (resp->getBody() != "BADDBEEF")
{
LOG_ERROR << resp->getBody();
LOG_ERROR << "Error!";
exit(1);
}
outputGood(req, isHttps);
}
catch (const std::exception &e)
{
LOG_DEBUG << e.what();
LOG_ERROR << "Error!";
exit(1);
}
}();
#endif
}
void loadFileLengths()
{

View File

@ -263,28 +263,79 @@ class HttpBinder : public HttpBinderBase
std::forward<Values>(values)...,
std::move(value));
}
template <typename... Values, std::size_t Boundary = argument_count>
typename std::enable_if<(sizeof...(Values) == Boundary), void>::type run(
std::deque<std::string> &,
template <typename... Values,
std::size_t Boundary = argument_count,
bool isCoroutine = traits::isCoroutine>
typename std::enable_if<(sizeof...(Values) == Boundary) && !isCoroutine,
void>::type
run(std::deque<std::string> &,
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Values &&... values)
{
callFunction(req, std::move(callback), std::move(values)...);
}
#ifdef __cpp_impl_coroutine
template <typename... Values,
std::size_t Boundary = argument_count,
bool isCoroutine = traits::isCoroutine>
typename std::enable_if<(sizeof...(Values) == Boundary) && isCoroutine,
void>::type
run(std::deque<std::string> &,
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Values &&... values)
{
[this](HttpRequestPtr req,
std::function<void(const HttpResponsePtr &)> callback,
Values &&... values) -> AsyncTask {
try
{
if constexpr (std::is_same_v<AsyncTask,
typename traits::return_type>)
{
callFunction(req,
std::move(callback),
std::move(values)...);
}
else if constexpr (std::is_same_v<Task<>,
typename traits::return_type>)
{
co_await callFunction(req,
std::move(callback),
std::move(values)...);
}
else if constexpr (std::is_same_v<Task<HttpResponsePtr>,
typename traits::return_type>)
{
auto resp =
co_await callFunction(req, std::move(values)...);
callback(std::move(resp));
}
}
catch (const std::exception &e)
{
LOG_ERROR << "Uncaught exception in " << req->path()
<< " what(): " << e.what();
}
catch (...)
{
LOG_ERROR << "Uncaught unknown exception in " << req->path();
}
}(req, std::move(callback), std::move(values)...);
}
#endif
template <typename... Values,
bool isClassFunction = traits::isClassFunction,
bool isDrObjectClass = traits::isDrObjectClass,
bool isNormal = std::is_same<typename traits::first_param_type,
HttpRequestPtr>::value>
typename std::enable_if<isClassFunction && !isDrObjectClass && isNormal,
void>::type
callFunction(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Values &&... values)
typename traits::return_type>::type
callFunction(const HttpRequestPtr &req, Values &&... values)
{
static auto &obj = getControllerObj<typename traits::class_type>();
(obj.*func_)(req, std::move(callback), std::move(values)...);
return (obj.*func_)(req, std::move(values)...);
}
template <typename... Values,
bool isClassFunction = traits::isClassFunction,
@ -292,25 +343,22 @@ class HttpBinder : public HttpBinderBase
bool isNormal = std::is_same<typename traits::first_param_type,
HttpRequestPtr>::value>
typename std::enable_if<isClassFunction && isDrObjectClass && isNormal,
void>::type
callFunction(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Values &&... values)
typename traits::return_type>::type
callFunction(const HttpRequestPtr &req, Values &&... values)
{
static auto objPtr =
DrClassMap::getSingleInstance<typename traits::class_type>();
(*objPtr.*func_)(req, std::move(callback), std::move(values)...);
return (*objPtr.*func_)(req, std::move(values)...);
}
template <typename... Values,
bool isClassFunction = traits::isClassFunction,
bool isNormal = std::is_same<typename traits::first_param_type,
HttpRequestPtr>::value>
typename std::enable_if<!isClassFunction && isNormal, void>::type
callFunction(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Values &&... values)
typename std::enable_if<!isClassFunction && isNormal,
typename traits::return_type>::type
callFunction(const HttpRequestPtr &req, Values &&... values)
{
func_(req, std::move(callback), std::move(values)...);
return func_(req, std::move(values)...);
}
template <typename... Values,
@ -319,13 +367,11 @@ class HttpBinder : public HttpBinderBase
bool isNormal = std::is_same<typename traits::first_param_type,
HttpRequestPtr>::value>
typename std::enable_if<isClassFunction && !isDrObjectClass && !isNormal,
void>::type
callFunction(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Values &&... values)
typename traits::return_type>::type
callFunction(const HttpRequestPtr &req, Values &&... values)
{
static auto &obj = getControllerObj<typename traits::class_type>();
(obj.*func_)((*req), std::move(callback), std::move(values)...);
return (obj.*func_)((*req), std::move(values)...);
}
template <typename... Values,
bool isClassFunction = traits::isClassFunction,
@ -333,25 +379,22 @@ class HttpBinder : public HttpBinderBase
bool isNormal = std::is_same<typename traits::first_param_type,
HttpRequestPtr>::value>
typename std::enable_if<isClassFunction && isDrObjectClass && !isNormal,
void>::type
callFunction(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Values &&... values)
typename traits::return_type>::type
callFunction(const HttpRequestPtr &req, Values &&... values)
{
static auto objPtr =
DrClassMap::getSingleInstance<typename traits::class_type>();
(*objPtr.*func_)((*req), std::move(callback), std::move(values)...);
return (*objPtr.*func_)((*req), std::move(values)...);
}
template <typename... Values,
bool isClassFunction = traits::isClassFunction,
bool isNormal = std::is_same<typename traits::first_param_type,
HttpRequestPtr>::value>
typename std::enable_if<!isClassFunction && !isNormal, void>::type
callFunction(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
Values &&... values)
typename std::enable_if<!isClassFunction && !isNormal,
typename traits::return_type>::type
callFunction(const HttpRequestPtr &req, Values &&... values)
{
func_((*req), std::move(callback), std::move(values)...);
return func_((*req), std::move(values)...);
}
};

View File

@ -23,11 +23,36 @@
#include <functional>
#include <memory>
#include <future>
#include "drogon/HttpBinder.h"
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>
#endif
namespace drogon
{
class HttpClient;
using HttpClientPtr = std::shared_ptr<HttpClient>;
#ifdef __cpp_impl_coroutine
namespace internal
{
struct HttpRespAwaiter : public CallbackAwaiter<HttpResponsePtr>
{
HttpRespAwaiter(HttpClient *client, HttpRequestPtr req, double timeout)
: client_(client), req_(std::move(req)), timeout_(timeout)
{
}
void await_suspend(std::coroutine_handle<> handle);
private:
HttpClient *client_;
HttpRequestPtr req_;
double timeout_;
};
} // namespace internal
#endif
/// Asynchronous http client
/**
@ -115,6 +140,27 @@ class HttpClient : public trantor::NonCopyable
return f.get();
}
#ifdef __cpp_impl_coroutine
/**
* @brief Send a request via coroutines to the server and return the
* response.
*
* @param req
* @param timeout In seconds. If the response is not received within the
* timeout, the `ReqResult::Timeout` and an empty response is returned. The
* zero value by default disables the timeout.
*
* @return task<HttpResponsePtr>
*/
Task<HttpResponsePtr> sendRequestCoro(HttpRequestPtr req,
double timeout = 0)
{
co_return co_await internal::HttpRespAwaiter(this,
std::move(req),
timeout);
}
#endif
/// Set the pipelining depth, which is the number of requests that are not
/// responding.
/**
@ -219,4 +265,34 @@ class HttpClient : public trantor::NonCopyable
HttpClient() = default;
};
#ifdef __cpp_impl_coroutine
inline void internal::HttpRespAwaiter::await_suspend(
std::coroutine_handle<> handle)
{
client_->sendRequest(
req_,
[handle = std::move(handle), this](ReqResult result,
const HttpResponsePtr &resp) {
if (result == ReqResult::Ok)
setValue(resp);
else
{
std::string reason;
if (result == ReqResult::BadResponse)
reason = "BadResponse";
else if (result == ReqResult::NetworkFailure)
reason = "NetworkFailure";
else if (result == ReqResult::BadServerAddress)
reason = "BadServerAddress";
else if (result == ReqResult::Timeout)
reason = "Timeout";
setException(
std::make_exception_ptr(std::runtime_error(reason)));
}
handle.resume();
},
timeout_);
}
#endif
} // namespace drogon

View File

@ -18,6 +18,9 @@
#include <drogon/HttpResponse.h>
#include <drogon/WebSocketConnection.h>
#include <drogon/HttpTypes.h>
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>
#endif
#include <functional>
#include <memory>
#include <string>
@ -30,6 +33,26 @@ using WebSocketClientPtr = std::shared_ptr<WebSocketClient>;
using WebSocketRequestCallback = std::function<
void(ReqResult, const HttpResponsePtr &, const WebSocketClientPtr &)>;
#ifdef __cpp_impl_coroutine
namespace internal
{
struct WebSocketConnectionAwaiter : public CallbackAwaiter<HttpResponsePtr>
{
WebSocketConnectionAwaiter(WebSocketClient *client, HttpRequestPtr req)
: client_(client), req_(std::move(req))
{
}
void await_suspend(std::coroutine_handle<> handle);
private:
WebSocketClient *client_;
HttpRequestPtr req_;
};
} // namespace internal
#endif
/**
* @brief WebSocket client abstract class
*
@ -67,6 +90,55 @@ class WebSocketClient
virtual void connectToServer(const HttpRequestPtr &request,
const WebSocketRequestCallback &callback) = 0;
#ifdef __cpp_impl_coroutine
/**
* @brief Set messages handler. When a message is recieved from the server,
* the callback is called.
*
* @param callback The function to call when a message is received.
*/
void setAsyncMessageHandler(
const std::function<Task<>(std::string &&message,
const WebSocketClientPtr &,
const WebSocketMessageType &)> &callback)
{
setMessageHandler([callback](std::string &&message,
const WebSocketClientPtr &client,
const WebSocketMessageType &type) -> void {
[callback](std::string &&message,
const WebSocketClientPtr client,
const WebSocketMessageType type) -> AsyncTask {
co_await callback(std::move(message), client, type);
}(std::move(message), client, type);
});
}
/// Set the connection closing handler. When the connection is established
/// or closed, the @param callback is called with a bool parameter.
/**
* @brief Set the connection closing handler. When the websocket connection
* is closed, the callback is called
*
* @param callback The function to call when the connection is closed.
*/
void setAsyncConnectionClosedHandler(
const std::function<Task<>(const WebSocketClientPtr &)> &callback)
{
setConnectionClosedHandler(
[callback](const WebSocketClientPtr &client) {
[=]() -> AsyncTask { co_await callback(client); }();
});
}
/// Connect to the server.
internal::WebSocketConnectionAwaiter connectToServerCoro(
const HttpRequestPtr &request)
{
return internal::WebSocketConnectionAwaiter(this, request);
}
#endif
/// Get the event loop of the client;
virtual trantor::EventLoop *getLoop() = 0;
@ -125,4 +197,35 @@ class WebSocketClient
}
};
#ifdef __cpp_impl_coroutine
inline void internal::WebSocketConnectionAwaiter::await_suspend(
std::coroutine_handle<> handle)
{
client_->connectToServer(req_,
[this, handle](ReqResult result,
const HttpResponsePtr &resp,
const WebSocketClientPtr &) {
if (result == ReqResult::Ok)
setValue(resp);
else
{
std::string reason;
if (result == ReqResult::BadResponse)
reason = "BadResponse";
else if (result ==
ReqResult::NetworkFailure)
reason = "NetworkFailure";
else if (result ==
ReqResult::BadServerAddress)
reason = "BadServerAddress";
else if (result == ReqResult::Timeout)
reason = "Timeout";
setException(std::make_exception_ptr(
std::runtime_error(reason)));
}
handle.resume();
});
}
#endif
} // namespace drogon

View File

@ -20,6 +20,10 @@
#include <tuple>
#include <type_traits>
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>
#endif
namespace drogon
{
class HttpRequest;
@ -84,8 +88,10 @@ struct FunctionTraits<
Arguments...)> : FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isHTTPFunction = true;
static const bool isCoroutine = false;
using class_type = void;
using first_param_type = HttpRequestPtr;
using return_type = ReturnType;
};
template <typename ReturnType, typename... Arguments>
@ -98,6 +104,44 @@ struct FunctionTraits<
using class_type = void;
};
#ifdef __cpp_impl_coroutine
template <typename... Arguments>
struct FunctionTraits<
AsyncTask (*)(HttpRequestPtr req,
std::function<void(const HttpResponsePtr &)> callback,
Arguments...)> : FunctionTraits<AsyncTask (*)(Arguments...)>
{
static const bool isHTTPFunction = true;
static const bool isCoroutine = true;
using class_type = void;
using first_param_type = HttpRequestPtr;
using return_type = AsyncTask;
};
template <typename... Arguments>
struct FunctionTraits<
Task<> (*)(HttpRequestPtr req,
std::function<void(const HttpResponsePtr &)> callback,
Arguments...)> : FunctionTraits<AsyncTask (*)(Arguments...)>
{
static const bool isHTTPFunction = true;
static const bool isCoroutine = true;
using class_type = void;
using first_param_type = HttpRequestPtr;
using return_type = Task<>;
};
template <typename... Arguments>
struct FunctionTraits<Task<HttpResponsePtr> (*)(HttpRequestPtr req,
Arguments...)>
: FunctionTraits<AsyncTask (*)(Arguments...)>
{
static const bool isHTTPFunction = true;
static const bool isCoroutine = true;
using class_type = void;
using first_param_type = HttpRequestPtr;
using return_type = Task<HttpResponsePtr>;
};
#endif
template <typename ReturnType, typename... Arguments>
struct FunctionTraits<
ReturnType (*)(HttpRequestPtr &&req,
@ -116,8 +160,10 @@ struct FunctionTraits<
Arguments...)> : FunctionTraits<ReturnType (*)(Arguments...)>
{
static const bool isHTTPFunction = true;
static const bool isCoroutine = false;
using class_type = void;
using first_param_type = T;
using return_type = ReturnType;
};
// normal function
@ -132,9 +178,11 @@ struct FunctionTraits<ReturnType (*)(Arguments...)>
static const std::size_t arity = sizeof...(Arguments);
using class_type = void;
using return_type = ReturnType;
static const bool isHTTPFunction = false;
static const bool isClassFunction = false;
static const bool isDrObjectClass = false;
static const bool isCoroutine = false;
static const std::string name()
{
return std::string("Normal or Static Function");

View File

@ -0,0 +1,486 @@
/**
*
* coroutine.h
* Martin Chang
*
* Copyright 2021, Martin Chang. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <algorithm>
#include <coroutine>
#include <exception>
#include <type_traits>
#include <condition_variable>
#include <atomic>
#include <future>
#include <cassert>
#include <drogon/utils/optional.h>
namespace drogon
{
namespace internal
{
template <typename T>
auto getAwaiterImpl(T &&value) noexcept(
noexcept(static_cast<T &&>(value).operator co_await()))
-> decltype(static_cast<T &&>(value).operator co_await())
{
return static_cast<T &&>(value).operator co_await();
}
template <typename T>
auto getAwaiterImpl(T &&value) noexcept(
noexcept(operator co_await(static_cast<T &&>(value))))
-> decltype(operator co_await(static_cast<T &&>(value)))
{
return operator co_await(static_cast<T &&>(value));
}
template <typename T>
auto getAwaiter(T &&value) noexcept(
noexcept(getAwaiterImpl(static_cast<T &&>(value))))
-> decltype(getAwaiterImpl(static_cast<T &&>(value)))
{
return getAwaiterImpl(static_cast<T &&>(value));
}
} // end namespace internal
template <typename T>
struct await_result
{
using awaiter_t = decltype(internal::getAwaiter(std::declval<T>()));
using type = decltype(std::declval<awaiter_t>().await_resume());
};
template <typename T>
using await_result_t = await_result<T>::type;
template <typename T, typename = std::void_t<>>
struct is_awaitable : std::false_type
{
};
template <typename T>
struct is_awaitable<
T,
std::void_t<decltype(internal::getAwaiter(std::declval<T>()))>>
: std::true_type
{
};
template <typename T>
constexpr bool is_awaitable_v = is_awaitable<T>::value;
template <typename T>
struct final_awiter
{
bool await_ready() noexcept
{
return false;
}
auto await_suspend(std::coroutine_handle<T> handle) noexcept
{
return handle.promise().continuation_;
}
void await_resume() noexcept
{
}
};
template <typename T = void>
struct Task
{
struct promise_type;
using handle_type = std::coroutine_handle<promise_type>;
Task(handle_type h) : coro_(h)
{
}
Task(const Task &) = delete;
Task(Task &&other)
{
coro_ = other.coro_;
other.coro_ = nullptr;
}
~Task()
{
if (coro_)
coro_.destroy();
}
Task &operator=(const Task &) = delete;
Task &operator=(Task &&other)
{
coro_ = other.coro_;
other.coro_ = nullptr;
return *this;
}
struct promise_type
{
Task<T> get_return_object()
{
return Task<T>{handle_type::from_promise(*this)};
}
std::suspend_always initial_suspend()
{
return {};
}
void return_value(const T &v)
{
value = v;
}
auto final_suspend() noexcept
{
return final_awiter<promise_type>{};
}
void unhandled_exception()
{
exception_ = std::current_exception();
}
const T &result() const
{
if (exception_ != nullptr)
std::rethrow_exception(exception_);
assert(value.has_value() == true);
return value.value();
}
void setContinuation(std::coroutine_handle<> handle)
{
continuation_ = handle;
}
optional<T> value;
std::exception_ptr exception_;
std::coroutine_handle<> continuation_;
};
bool await_ready() const
{
return coro_.done();
}
std::coroutine_handle<> await_suspend(std::coroutine_handle<> awaiting)
{
coro_.promise().setContinuation(awaiting);
return coro_;
}
auto operator co_await() const noexcept
{
struct awaiter
{
public:
explicit awaiter(handle_type coro) : coro_(coro)
{
}
bool await_ready() noexcept
{
return false;
}
auto await_suspend(std::coroutine_handle<> handle) noexcept
{
coro_.promise().setContinuation(handle);
return coro_;
}
T await_resume()
{
return coro_.promise().result();
}
private:
handle_type coro_;
};
return awaiter(coro_);
}
handle_type coro_;
};
template <>
struct Task<void>
{
struct promise_type;
using handle_type = std::coroutine_handle<promise_type>;
Task(handle_type handle) : coro_(handle)
{
}
Task(const Task &) = delete;
Task(Task &&other)
{
coro_ = other.coro_;
other.coro_ = nullptr;
}
~Task()
{
if (coro_)
coro_.destroy();
}
Task &operator=(const Task &) = delete;
Task &operator=(Task &&other)
{
coro_ = other.coro_;
other.coro_ = nullptr;
return *this;
}
struct promise_type
{
Task<> get_return_object()
{
return Task<>{handle_type::from_promise(*this)};
}
std::suspend_always initial_suspend()
{
return {};
}
void return_void()
{
}
void return_value()
{
}
auto final_suspend() noexcept
{
return final_awiter<promise_type>{};
}
void unhandled_exception()
{
exception_ = std::current_exception();
}
void result()
{
if (exception_ != nullptr)
std::rethrow_exception(exception_);
}
void setContinuation(std::coroutine_handle<> handle)
{
continuation_ = handle;
}
std::exception_ptr exception_;
std::coroutine_handle<> continuation_;
};
bool await_ready()
{
return coro_.done();
}
std::coroutine_handle<> await_suspend(std::coroutine_handle<> awaiting)
{
coro_.promise().setContinuation(awaiting);
return coro_;
}
auto operator co_await() const noexcept
{
struct awaiter
{
public:
explicit awaiter(handle_type coro) : coro_(coro)
{
}
bool await_ready() noexcept
{
return false;
}
auto await_suspend(std::coroutine_handle<> handle) noexcept
{
coro_.promise().setContinuation(handle);
return coro_;
}
void await_resume()
{
coro_.promise().result();
}
private:
handle_type coro_;
};
return awaiter(coro_);
}
handle_type coro_;
};
/// Fires a coroutine and doesn't force waiting nor deallocates upon promise
/// destructs
// NOTE: AsyncTask is designed to be not awaitable. And kills the entire process
// if exception escaped.
struct AsyncTask final
{
struct promise_type final
{
auto initial_suspend() noexcept
{
return std::suspend_never{};
}
auto final_suspend() noexcept
{
return std::suspend_never{};
}
void return_void() noexcept
{
}
void unhandled_exception()
{
std::terminate();
}
promise_type *get_return_object() noexcept
{
return this;
}
void result()
{
}
};
AsyncTask(const promise_type *) noexcept
{
// the type truncates all given info about its frame
}
};
/// Helper class that provices the infrastructure for turning callback into
/// corourines
// The user is responsible to fill in `await_suspend()` and construtors.
template <typename T>
struct CallbackAwaiter
{
bool await_ready() noexcept
{
return false;
}
const T &await_resume() noexcept(false)
{
// await_resume() should always be called after co_await
// (await_suspend()) is called. Therefor the value should always be set
// (or there's an exception)
assert(result_.has_value() == true || exception_ != nullptr);
if (exception_)
std::rethrow_exception(exception_);
return result_.value();
}
private:
// HACK: Not all desired types are default contructable. But we need the
// entire struct to be constructed for awaiting. std::optional takes care of
// that.
optional<T> result_;
std::exception_ptr exception_ = nullptr;
protected:
void setException(const std::exception_ptr &e)
{
exception_ = e;
}
void setValue(const T &v)
{
result_.emplace(v);
}
void setValue(T &&v)
{
result_.emplace(std::move(v));
}
};
// An ok implementation of sync_await. This allows one to call
// coroutines and wait for the result from a function.
//
// NOTE: Not sure if this is a compiler bug. But causes use after free in some
// cases. Don't use it in production code.
template <typename AWAIT>
auto sync_wait(AWAIT &&await)
{
using value_type = typename await_result<AWAIT>::type;
std::condition_variable cv;
std::mutex mtx;
std::atomic<bool> flag = false;
std::exception_ptr exception_ptr;
if constexpr (std::is_same_v<value_type, void>)
{
[&, lk = std::unique_lock(mtx)]() -> AsyncTask {
try
{
co_await await;
}
catch (...)
{
exception_ptr = std::current_exception();
}
flag = true;
cv.notify_one();
}();
std::unique_lock lk(mtx);
cv.wait(lk, [&]() { return (bool)flag; });
if (exception_ptr)
std::rethrow_exception(exception_ptr);
}
else
{
optional<value_type> value;
[&, lk = std::unique_lock(mtx)]() -> AsyncTask {
try
{
value = co_await await;
}
catch (const std::exception &e)
{
exception_ptr = std::current_exception();
}
flag = true;
}();
std::unique_lock lk(mtx);
cv.wait(lk, [&]() { return (bool)flag; });
assert(value.has_value() == true || exception_ptr);
if (exception_ptr)
std::rethrow_exception(exception_ptr);
return value.value();
}
}
// Converts a task (or task like) promise into std::future for old-style async
// NOTE: Not sure if this is a compiler bug. But causes use after free in some
// cases. Don't use it in production code.
template <typename Await>
inline auto co_future(Await await) noexcept
-> std::future<await_result_t<Await>>
{
using Result = await_result_t<Await>;
std::promise<Result> prom;
auto fut = prom.get_future();
[](std::promise<Result> &&prom, Await &&await) -> AsyncTask {
try
{
if constexpr (std::is_void_v<Result>)
{
co_await std::move(await);
prom.set_value();
}
else
prom.set_value(co_await std::move(await));
}
catch (...)
{
prom.set_exception(std::current_exception());
}
}(std::move(prom), std::move(await));
return fut;
}
} // namespace drogon

View File

@ -22,6 +22,10 @@ set(test_targets
url_codec_test
main_loop_test
main_loop_test2)
if(DROGON_CXX_STANDARD GREATER_EQUAL 20)
add_executable(coroutine_test CoroutineTest.cc)
set(test_targets ${test_targets} coroutine_test)
endif(DROGON_CXX_STANDARD GREATER_EQUAL 20)
set_property(TARGET ${test_targets}
PROPERTY CXX_STANDARD ${DROGON_CXX_STANDARD})

View File

@ -0,0 +1,64 @@
#include <drogon/utils/coroutine.h>
#include <exception>
#include <type_traits>
#include <iostream>
using namespace drogon;
int main()
{
// Basic checks making sure coroutine works as expected
static_assert(is_awaitable_v<Task<>>);
static_assert(is_awaitable_v<Task<int>>);
static_assert(std::is_same_v<await_result_t<Task<int>>, int>);
static_assert(std::is_same_v<await_result_t<Task<>>, void>);
static_assert(is_awaitable_v<Task<>>);
static_assert(is_awaitable_v<Task<int>>);
// No, you cannot await AsyncTask. By design
static_assert(is_awaitable_v<AsyncTask> == false);
// Make sure sync_wait works
if (sync_wait([]() -> Task<int> { co_return 1; }()) != 1)
{
std::cerr << "Expected coroutine return 1. Didn't get that\n";
exit(1);
}
// co_future converts coroutine into futures
auto fut = co_future([]() -> Task<std::string> { co_return "zxc"; }());
if (fut.get() != "zxc")
{
std::cerr << "Expected future return \'zxc\'. Didn't get that\n";
exit(1);
}
// Testing that exceptions can propergate through coroutines
auto throw_in_task = []() -> Task<> {
auto f = []() -> Task<> { throw std::runtime_error("test error"); };
try
{
f();
std::cerr << "Exception should have been thrown\n";
exit(1);
}
catch (const std::exception& e)
{
if (std::string(e.what()) != "test error")
{
std::cerr << "Not the right exception\n";
exit(1);
}
}
catch (...)
{
std::cerr << "Shouldn't reach here\n";
exit(1);
}
co_return;
};
sync_wait(throw_in_task());
std::cout << "Done testing coroutines. No error." << std::endl;
}

View File

@ -27,6 +27,10 @@
#include <trantor/utils/Logger.h>
#include <trantor/utils/NonCopyable.h>
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>
#endif
namespace drogon
{
namespace orm
@ -35,6 +39,49 @@ using ResultCallback = std::function<void(const Result &)>;
using ExceptionCallback = std::function<void(const DrogonDbException &)>;
class Transaction;
class DbClient;
namespace internal
{
#ifdef __cpp_impl_coroutine
struct SqlAwaiter : public CallbackAwaiter<Result>
{
SqlAwaiter(internal::SqlBinder &&binder) : binder_(binder)
{
}
void await_suspend(std::coroutine_handle<> handle)
{
binder_ >> [handle, this](const drogon::orm::Result &result) {
setValue(result);
handle.resume();
};
binder_ >> [handle, this](const std::exception_ptr &e) {
setException(e);
handle.resume();
};
binder_.exec();
}
private:
internal::SqlBinder binder_;
};
struct TrasactionAwaiter : public CallbackAwaiter<std::shared_ptr<Transaction>>
{
TrasactionAwaiter(DbClient *client) : client_(client)
{
}
void await_suspend(std::coroutine_handle<> handle);
private:
DbClient *client_;
};
#endif
} // namespace internal
/// Database client abstract class
class DbClient : public trantor::NonCopyable
@ -150,6 +197,18 @@ class DbClient : public trantor::NonCopyable
return r;
}
#ifdef __cpp_impl_coroutine
template <typename... Arguments>
const Task<Result> execSqlCoro(const std::string sql,
Arguments... args) noexcept
{
auto binder = *this << sql;
(void)std::initializer_list<int>{
(binder << std::forward<Arguments>(args), 0)...};
co_return co_await internal::SqlAwaiter(std::move(binder));
}
#endif
/// Streaming-like method for sql execution. For more information, see the
/// wiki page.
internal::SqlBinder operator<<(const std::string &sql);
@ -183,6 +242,14 @@ class DbClient : public trantor::NonCopyable
virtual void newTransactionAsync(
const std::function<void(const std::shared_ptr<Transaction> &)>
&callback) = 0;
#ifdef __cpp_impl_coroutine
Task<std::shared_ptr<Transaction>> newTransactionCoro()
{
co_return co_await orm::internal::TrasactionAwaiter(this);
}
#endif
/**
* @brief Check if there is a connection successfully established.
*
@ -227,5 +294,22 @@ class Transaction : public DbClient
const std::function<void(bool)> &commitCallback) = 0;
};
#ifdef __cpp_impl_coroutine
inline void internal::TrasactionAwaiter::await_suspend(
std::coroutine_handle<> handle)
{
assert(client_ != nullptr);
client_->newTransactionAsync(
[this, handle](const std::shared_ptr<Transaction> transacton) {
if (transacton == nullptr)
setException(std::make_exception_ptr(
std::runtime_error("Failed to create transaction")));
else
setValue(transacton);
handle.resume();
});
}
#endif
} // namespace orm
} // namespace drogon

View File

@ -18,14 +18,14 @@
using namespace drogon::orm;
using namespace drogon;
internal::SqlBinder DbClient::operator<<(const std::string &sql)
orm::internal::SqlBinder DbClient::operator<<(const std::string &sql)
{
return internal::SqlBinder(sql, *this, type_);
return orm::internal::SqlBinder(sql, *this, type_);
}
internal::SqlBinder DbClient::operator<<(std::string &&sql)
orm::internal::SqlBinder DbClient::operator<<(std::string &&sql)
{
return internal::SqlBinder(std::move(sql), *this, type_);
return orm::internal::SqlBinder(std::move(sql), *this, type_);
}
std::shared_ptr<DbClient> DbClient::newPgClient(const std::string &connInfo,

View File

@ -31,9 +31,15 @@ using namespace drogon::orm;
#define RED "\033[31m" /* Red */
#define GREEN "\033[32m" /* Green */
#ifdef __cpp_impl_coroutine
constexpr int postgre_tests = 46;
constexpr int mysql_tests = 47;
constexpr int sqlite_tests = 49;
#else
constexpr int postgre_tests = 44;
constexpr int mysql_tests = 45;
constexpr int sqlite_tests = 47;
#endif
int test_count = 0;
int counter = 0;
@ -723,6 +729,40 @@ void doPostgreTest(const drogon::orm::DbClientPtr &clientPtr)
std::cerr << e.base().what() << std::endl;
testOutput(false, "postgresql - ORM mapper synchronous interface(0)");
}
#ifdef __cpp_impl_coroutine
auto coro_test = [clientPtr]() -> drogon::Task<> {
/// 7 Test coroutines.
/// This is by no means comprehensive. But coroutine API is esentially a
/// wrapper arround callbacks. The purpose is to test the interface
/// works 7.1 Basic queries
try
{
auto result =
co_await clientPtr->execSqlCoro("select * from users;");
testOutput(result.size() != 0,
"postgresql - DbClient coroutine interface(0)");
}
catch (const Failure &e)
{
std::cerr << e.what() << std::endl;
testOutput(false, "postgresql - DbClient coroutine interface(0)");
}
/// 7.2 Parameter binding
try
{
auto result = co_await clientPtr->execSqlCoro(
"select * from users where 1=$1;", 1);
testOutput(result.size() != 0,
"postgresql - DbClient coroutine interface(1)");
}
catch (const Failure &e)
{
std::cerr << e.what() << std::endl;
testOutput(false, "postgresql - DbClient coroutine interface(1)");
}
};
drogon::sync_wait(coro_test());
#endif
}
void doMysqlTest(const drogon::orm::DbClientPtr &clientPtr)
@ -1341,6 +1381,40 @@ void doMysqlTest(const drogon::orm::DbClientPtr &clientPtr)
std::cerr << e.base().what() << std::endl;
testOutput(false, "mysql - ORM mapper synchronous interface(0)");
}
#ifdef __cpp_impl_coroutine
auto coro_test = [clientPtr]() -> drogon::Task<> {
/// 7 Test coroutines.
/// This is by no means comprehensive. But coroutine API is esentially a
/// wrapper arround callbacks. The purpose is to test the interface
/// works 7.1 Basic queries
try
{
auto result =
co_await clientPtr->execSqlCoro("select * from users;");
testOutput(result.size() != 0,
"postgresql - DbClient coroutine interface(0)");
}
catch (const Failure &e)
{
std::cerr << e.what() << std::endl;
testOutput(false, "postgresql - DbClient coroutine interface(0)");
}
/// 7.2 Parameter binding
try
{
auto result = co_await clientPtr->execSqlCoro(
"select * from users where 1=?;", 1);
testOutput(result.size() != 0,
"postgresql - DbClient coroutine interface(1)");
}
catch (const Failure &e)
{
std::cerr << e.what() << std::endl;
testOutput(false, "postgresql - DbClient coroutine interface(1)");
}
};
drogon::sync_wait(coro_test());
#endif
}
void doSqliteTest(const drogon::orm::DbClientPtr &clientPtr)
@ -1987,6 +2061,40 @@ void doSqliteTest(const drogon::orm::DbClientPtr &clientPtr)
std::cerr << e.base().what() << std::endl;
testOutput(false, "sqlite3 - ORM mapper synchronous interface(0)");
}
#ifdef __cpp_impl_coroutine
auto coro_test = [clientPtr]() -> drogon::Task<> {
/// 7 Test coroutines.
/// This is by no means comprehensive. But coroutine API is esentially a
/// wrapper arround callbacks. The purpose is to test the interface
/// works 7.1 Basic queries
try
{
auto result =
co_await clientPtr->execSqlCoro("select * from users;");
testOutput(result.size() != 0,
"postgresql - DbClient coroutine interface(0)");
}
catch (const Failure &e)
{
std::cerr << e.what() << std::endl;
testOutput(false, "postgresql - DbClient coroutine interface(0)");
}
/// 7.2 Parameter binding
try
{
auto result = co_await clientPtr->execSqlCoro(
"select * from users where 1=?;", 1);
testOutput(result.size() != 0,
"postgresql - DbClient coroutine interface(1)");
}
catch (const Failure &e)
{
std::cerr << e.what() << std::endl;
testOutput(false, "postgresql - DbClient coroutine interface(1)");
}
};
drogon::sync_wait(coro_test());
#endif
}
int main(int argc, char *argv[])

10
test.sh
View File

@ -89,6 +89,16 @@ if [ $? -ne 0 ]; then
exit -1
fi
# Test websocket client coroutine
if [ -f ./websocket_coro_test ]; then
echo "Test WebSocket w/ coroutine"
./websocket_coro_test -t
if [ $? -ne 0 ]; then
echo "Error in testing WebSocket with coroutine"
exit -1
fi
fi
#Test pipelining
echo "Test the pipelining"
./pipelining_test