Add WebSocket client support

This commit is contained in:
antao 2019-04-06 23:06:38 +08:00
parent 62f268d837
commit b96eb04859
28 changed files with 932 additions and 179 deletions

View File

@ -22,7 +22,7 @@ Drogon's main application platform is Linux. It also supports Mac OS and FreeBSD
* Provide a convenient and flexible routing solution from the path to the controller handler;
* Support filter chains to facilitate the execution of unified logic (such as login verification, Http Method constraint verification, etc.) before controllers;
* Support https (based on OpenSSL);
* Support WebSocket (server side);
* Support WebSocket (server side and client side);
* Support JSON format request and response, very friendly to the Restful API application development;
* Support file download and upload;
* Support gzip compression transmission;

View File

@ -22,7 +22,7 @@ Drogon的主要应用平台是Linux也支持Mac OS、FreeBSD目前还不
* 非常方便灵活的路径(path)到控制器处理函数(handler)的映射方案;
* 支持过滤器(filter)链,方便在控制器之前执行统一的逻辑(如登录验证、Http Method约束验证等)
* 支持https(基于OpenSSL实现);
* 支持websocket(server端);
* 支持websocket(server端和client端);
* 支持Json格式请求和应答, 对Restful API应用开发非常友好;
* 支持文件下载和上传,支持sendfile系统调用
* 支持gzip压缩传输

View File

@ -132,7 +132,8 @@ void create_controller::newWebsockControllerHeaderFile(std::ofstream &file, cons
file << indent << "{\n";
file << indent << "public:\n";
file << indent << " virtual void handleNewMessage(const WebSocketConnectionPtr&,\n";
file << indent << " std::string &&)override;\n";
file << indent << " std::string &&,\n";
file << indent << " const WebSocketMessageType &) override;\n";
file << indent << " virtual void handleNewConnection(const HttpRequestPtr &,\n";
file << indent << " const WebSocketConnectionPtr&)override;\n";
file << indent << " virtual void handleConnectionClosed(const WebSocketConnectionPtr&)override;\n";
@ -160,7 +161,7 @@ void create_controller::newWebsockControllerSourceFile(std::ofstream &file, cons
file << "using namespace " << namespacename << ";\n";
class_name = className.substr(pos + 2);
}
file << "void " << class_name << "::handleNewMessage(const WebSocketConnectionPtr& wsConnPtr,std::string &&message)\n";
file << "void " << class_name << "::handleNewMessage(const WebSocketConnectionPtr& wsConnPtr, std::string &&message, const WebSocketMessageType &type)\n";
file << "{\n";
file << " //write your application logic here\n";
file << "}\n";

View File

@ -27,6 +27,7 @@ add_executable(client ${DIR_CLIENT})
add_executable(benchmark ${DIR_BENCHMARK})
add_executable(webapp_test simple_example_test/main.cc)
add_executable(pipeline_test simple_example_test/HttpPipelineTest.cc)
add_executable(websocket_test simple_example_test/WebSocketTest.cc)
add_custom_command(TARGET webapp POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different

View File

@ -1,14 +1,20 @@
#include "WebSocketTest.h"
using namespace example;
void WebSocketTest::handleNewMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message)
void WebSocketTest::handleNewMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message, const WebSocketMessageType &type)
{
//write your application logic here
LOG_DEBUG << "new websocket message:" << message;
if (type == WebSocketMessageType::Ping)
{
LOG_DEBUG << "recv a ping";
}
}
void WebSocketTest::handleConnectionClosed(const WebSocketConnectionPtr &)
{
LOG_DEBUG << "websocket closed!";
}
void WebSocketTest::handleNewConnection(const HttpRequestPtr &,
const WebSocketConnectionPtr &conn)
{

View File

@ -7,7 +7,8 @@ class WebSocketTest : public drogon::WebSocketController<WebSocketTest>
{
public:
virtual void handleNewMessage(const WebSocketConnectionPtr &,
std::string &&) override;
std::string &&,
const WebSocketMessageType &) override;
virtual void handleConnectionClosed(const WebSocketConnectionPtr &) override;
virtual void handleNewConnection(const HttpRequestPtr &, const WebSocketConnectionPtr &) override;
WS_PATH_LIST_BEGIN

View File

@ -0,0 +1,57 @@
#include <drogon/WebSocketClient.h>
#include <drogon/HttpAppFramework.h>
#include <trantor/net/EventLoopThread.h>
#include <iostream>
using namespace drogon;
int main(int argc, char *argv[])
{
auto wsPtr = WebSocketClient::newWebSocketClient("127.0.0.1", 8848);
auto req = HttpRequest::newHttpRequest();
bool continually = true;
if (argc > 1)
{
if (std::string(argv[1]) == "-t")
continually = false;
}
req->setPath("/chat");
wsPtr->setMessageHandler([continually](const std::string &message, const WebSocketClientPtr &wsPtr, const WebSocketMessageType &type) {
std::cout << "new message:" << message << std::endl;
if (type == WebSocketMessageType::Pong)
{
std::cout << "recv a pong" << std::endl;
if (!continually)
{
app().getLoop()->quit();
}
}
});
wsPtr->setConnectionClosedHandler([](const WebSocketClientPtr &wsPtr) {
std::cout << "ws closed!" << std::endl;
});
wsPtr->connectToServer(req, [continually](ReqResult r, const HttpResponsePtr &resp, const WebSocketClientPtr &wsPtr) {
if (r == ReqResult::Ok)
{
std::cout << "ws connected!" << std::endl;
wsPtr->getConnection()->send("hello");
}
else
{
std::cout << "ws failed!" << std::endl;
if (!continually)
{
exit(-1);
}
}
});
wsPtr->setHeartbeatMessage("", 1.0);
app().getLoop()->runAfter(5.0, [continually]() {
if (!continually)
{
exit(-1);
}
});
app().run();
}

View File

@ -16,6 +16,7 @@
#include <drogon/HttpRequest.h>
#include <drogon/HttpResponse.h>
#include <drogon/HttpTypes.h>
#include <trantor/utils/NonCopyable.h>
#include <trantor/net/EventLoop.h>
#include <functional>
@ -23,19 +24,11 @@
namespace drogon
{
enum class ReqResult
{
Ok,
BadResponse,
NetworkFailure,
BadServerAddress,
Timeout
};
typedef std::function<void(ReqResult, const HttpResponsePtr &response)> HttpReqCallback;
class HttpClient;
typedef std::shared_ptr<HttpClient> HttpClientPtr;
typedef std::function<void(ReqResult, const HttpResponsePtr &)> HttpReqCallback;
/// Asynchronous http client
/**
* HttpClient implementation object uses the HttpAppFramework's event loop by default,

View File

@ -70,6 +70,9 @@ class HttpResponse
virtual const std::string &getHeader(const std::string &key, const std::string &defaultVal = std::string()) const = 0;
virtual const std::string &getHeader(std::string &&key, const std::string &defaultVal = std::string()) const = 0;
/// Get all headers of the response
virtual const std::unordered_map<std::string, std::string> &headers() const = 0;
/// Add a header.
virtual void addHeader(const std::string &key, const std::string &value) = 0;
virtual void addHeader(const std::string &key, std::string &&value) = 0;

View File

@ -109,4 +109,13 @@ enum HttpMethod
Invalid
};
enum class ReqResult
{
Ok,
BadResponse,
NetworkFailure,
BadServerAddress,
Timeout
};
} // namespace drogon

View File

@ -0,0 +1,102 @@
/**
*
* WebSocketClient.h
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/HttpRequest.h>
#include <drogon/HttpResponse.h>
#include <drogon/WebSocketConnection.h>
#include <trantor/net/EventLoop.h>
#include <string>
#include <functional>
#include <memory>
namespace drogon
{
class WebSocketClient;
typedef std::shared_ptr<WebSocketClient> WebSocketClientPtr;
typedef std::function<void(ReqResult, const HttpResponsePtr &, const WebSocketClientPtr &)> WebSocketRequestCallback;
/// WebSocket client abstract class
class WebSocketClient
{
public:
/// Get the WebSocket connection that is typically used to send messages.
virtual const WebSocketConnectionPtr &getConnection() = 0;
/// Set messages handler. When a message is recieved from the server, the @param callback is called.
virtual void setMessageHandler(const std::function<void(std::string &&message, const WebSocketClientPtr &, const WebSocketMessageType &)> &callback) = 0;
/// Set the connection handler. When the connection is established or closed, the @param callback is called with a bool
/// parameter.
virtual void setConnectionClosedHandler(const std::function<void(const WebSocketClientPtr &)> &callback) = 0;
/// Set the heartbeat(ping) message sent to the server.
/**
* NOTE:
* Both the server and the client in Drogon automatically send the pong message after receiving the ping message.
*/
virtual void setHeartbeatMessage(const std::string &message, double interval) = 0;
/// Connect to the server.
virtual void connectToServer(const HttpRequestPtr &request, const WebSocketRequestCallback &callback) = 0;
/// Get the event loop of the client;
virtual trantor::EventLoop *getLoop() = 0;
/// Use ip and port to connect to server
/**
* If useSSL is set to true, the client
* connects to the server using SSL.
*
* If the loop parameter is set to nullptr, the client
* uses the HttpAppFramework's event loop, otherwise it
* runs in the loop identified by the parameter.
*
* Note: The @param ip support for both ipv4 and ipv6 address
*/
static WebSocketClientPtr newWebSocketClient(const std::string &ip,
uint16_t port,
bool useSSL = false,
trantor::EventLoop *loop = nullptr);
/// Use hostString to connect to server
/**
* Examples for hostString:
* wss://www.google.com
* ws://www.google.com
* wss://127.0.0.1:8080/
* ws://127.0.0.1
*
* The @param hostString must be prefixed by 'ws://' or 'wss://'
* and doesn't support for ipv6 address if the host is in ip format
*
* If the @param loop is set to nullptr, the client
* uses the HttpAppFramework's main event loop, otherwise it
* runs in the loop identified by the parameter.
*
* NOTE:
* Don't add path and parameters in hostString, the request path
* and parameters should be set in HttpRequestPtr when calling
* the connectToServer() method.
*
*/
static WebSocketClientPtr newWebSocketClient(const std::string &hostString,
trantor::EventLoop *loop = nullptr);
virtual ~WebSocketClient() {}
};
} // namespace drogon

View File

@ -21,13 +21,24 @@
#include <memory>
namespace drogon
{
enum class WebSocketMessageType
{
Text,
Binary,
Ping,
Pong,
Close,
Unknown
};
class WebSocketConnection
{
public:
WebSocketConnection() = default;
virtual ~WebSocketConnection(){};
virtual void send(const char *msg, uint64_t len) = 0;
virtual void send(const std::string &msg) = 0;
virtual void send(const char *msg, uint64_t len, const WebSocketMessageType &type = WebSocketMessageType::Text) = 0;
virtual void send(const std::string &msg, const WebSocketMessageType &type = WebSocketMessageType::Text) = 0;
virtual const trantor::InetAddress &localAddr() const = 0;
virtual const trantor::InetAddress &peerAddr() const = 0;

View File

@ -41,14 +41,18 @@ namespace drogon
class WebSocketControllerBase : public virtual DrObjectBase
{
public:
//on new data received
//Call this function when a new message is received
virtual void handleNewMessage(const WebSocketConnectionPtr &,
std::string &&) = 0;
//after new websocket connection established
std::string &&,
const WebSocketMessageType &) = 0;
//Call this function after the new connection of WebSocket is established.
virtual void handleNewConnection(const HttpRequestPtr &,
const WebSocketConnectionPtr &) = 0;
//after connection closed
//Call this function after the WebSocket connection is closed
virtual void handleConnectionClosed(const WebSocketConnectionPtr &) = 0;
virtual ~WebSocketControllerBase() {}
};

View File

@ -319,7 +319,7 @@ void HttpAppFrameworkImpl::run()
}
serverPtr->setHttpAsyncCallback(std::bind(&HttpAppFrameworkImpl::onAsyncRequest, this, _1, _2));
serverPtr->setNewWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onNewWebsockRequest, this, _1, _2, _3));
serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage, this, _1, _2));
serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage, this, _1, _2, _3));
serverPtr->setDisconnectWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onWebsockDisconnect, this, _1));
serverPtr->setConnectionCallback(std::bind(&HttpAppFrameworkImpl::onConnection, this, _1));
serverPtr->kickoffIdleConnections(_idleConnectionTimeout);
@ -360,7 +360,7 @@ void HttpAppFrameworkImpl::run()
serverPtr->setIoLoopNum(_threadNum);
serverPtr->setHttpAsyncCallback(std::bind(&HttpAppFrameworkImpl::onAsyncRequest, this, _1, _2));
serverPtr->setNewWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onNewWebsockRequest, this, _1, _2, _3));
serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage, this, _1, _2));
serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage, this, _1, _2, _3));
serverPtr->setDisconnectWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onWebsockDisconnect, this, _1));
serverPtr->setConnectionCallback(std::bind(&HttpAppFrameworkImpl::onConnection, this, _1));
serverPtr->kickoffIdleConnections(_idleConnectionTimeout);
@ -548,14 +548,14 @@ void HttpAppFrameworkImpl::onConnection(const TcpConnectionPtr &conn)
}
}
void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message)
void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message, const WebSocketMessageType &type)
{
auto wsConnImplPtr = std::dynamic_pointer_cast<WebSocketConnectionImpl>(wsConnPtr);
assert(wsConnImplPtr);
auto ctrl = wsConnImplPtr->controller();
if (ctrl)
{
ctrl->handleNewMessage(wsConnPtr, std::move(message));
ctrl->handleNewMessage(wsConnPtr, std::move(message), type);
}
}

View File

@ -178,7 +178,7 @@ class HttpAppFrameworkImpl : public HttpAppFramework
void onNewWebsockRequest(const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionPtr &wsConnPtr);
void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message);
void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message, const WebSocketMessageType &type);
void onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr);
void onConnection(const TcpConnectionPtr &conn);
void addHttpPath(const std::string &path,

View File

@ -318,12 +318,12 @@ void HttpClientImpl::onRecvMessage(const trantor::TcpConnectionPtr &connPtr, tra
HttpClientPtr HttpClient::newHttpClient(const std::string &ip, uint16_t port, bool useSSL, trantor::EventLoop *loop)
{
bool isIpv6 = ip.find(":") == std::string::npos ? false : true;
return std::make_shared<HttpClientImpl>(loop == nullptr ? HttpAppFrameworkImpl::instance().getLoop() : loop, trantor::InetAddress(ip, port, isIpv6), useSSL);
return std::make_shared<HttpClientImpl>(loop == nullptr ? app().getLoop() : loop, trantor::InetAddress(ip, port, isIpv6), useSSL);
}
HttpClientPtr HttpClient::newHttpClient(const std::string &hostString, trantor::EventLoop *loop)
{
return std::make_shared<HttpClientImpl>(loop == nullptr ? HttpAppFrameworkImpl::instance().getLoop() : loop, hostString);
return std::make_shared<HttpClientImpl>(loop == nullptr ? app().getLoop() : loop, hostString);
}
void HttpClientImpl::onError(ReqResult result)

View File

@ -15,8 +15,8 @@
#pragma once
#include "HttpRequestImpl.h"
#include "WebSockectConnectionImpl.h"
#include <trantor/utils/MsgBuffer.h>
#include <drogon/WebSocketConnection.h>
#include <drogon/HttpResponse.h>
#include <deque>
#include <mutex>
@ -72,11 +72,11 @@ class HttpRequestParser
}
return false;
}
const WebSocketConnectionPtr &webSocketConn() const
const WebSocketConnectionImplPtr &webSocketConn() const
{
return _websockConnPtr;
}
void setWebsockConnection(const WebSocketConnectionPtr &conn)
void setWebsockConnection(const WebSocketConnectionImplPtr &conn)
{
_websockConnPtr = conn;
}
@ -99,7 +99,7 @@ class HttpRequestParser
trantor::EventLoop *_loop;
HttpRequestImplPtr _request;
bool _firstRequest = true;
WebSocketConnectionPtr _websockConnPtr;
WebSocketConnectionImplPtr _websockConnPtr;
std::deque<std::pair<HttpRequestPtr, HttpResponsePtr>> _requestPipeLine;
size_t _requestsCounter = 0;
std::weak_ptr<trantor::TcpConnection> _conn;

View File

@ -122,6 +122,11 @@ class HttpResponseImpl : public HttpResponse
return getHeaderBy(key, defaultVal);
}
virtual const std::unordered_map<std::string, std::string> &headers() const override
{
return _headers;
}
const std::string &getHeaderBy(const std::string &lowerKey, const std::string &defaultVal = std::string()) const
{
auto iter = _headers.find(lowerKey);

View File

@ -118,8 +118,18 @@ bool HttpResponseParser::parseResponse(MsgBuffer *buf)
}
else
{
_state = HttpResponseParseState::kExpectClose;
hasMore = true;
if (_response->statusCode() == k101SwitchingProtocols &&
_response->getHeaderBy("upgrade") == "websocket")
{
//The Websocket response may not have a content-length header.
_state = HttpResponseParseState::kGotAll;
hasMore = false;
}
else
{
_state = HttpResponseParseState::kExpectClose;
hasMore = true;
}
}
}
}

View File

@ -27,80 +27,6 @@ using namespace std::placeholders;
using namespace drogon;
using namespace trantor;
// Return false if any error
static bool parseWebsockMessage(MsgBuffer *buffer, std::string &message)
{
assert(message.empty());
if (buffer->readableBytes() >= 2)
{
auto secondByte = (*buffer)[1];
size_t length = secondByte & 127;
int isMasked = (secondByte & 0x80);
if (isMasked != 0)
{
LOG_TRACE << "data encoded!";
}
else
LOG_TRACE << "plain data";
size_t indexFirstMask = 2;
if (length == 126)
{
indexFirstMask = 4;
}
else if (length == 127)
{
indexFirstMask = 10;
}
if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask)
{
if (indexFirstMask == 4)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
LOG_TRACE << "bytes[2]=" << (unsigned char)(*buffer)[2];
LOG_TRACE << "bytes[3]=" << (unsigned char)(*buffer)[3];
}
else if (indexFirstMask == 10)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
length = (length << 8) + (unsigned char)(*buffer)[4];
length = (length << 8) + (unsigned char)(*buffer)[5];
length = (length << 8) + (unsigned char)(*buffer)[6];
length = (length << 8) + (unsigned char)(*buffer)[7];
length = (length << 8) + (unsigned char)(*buffer)[8];
length = (length << 8) + (unsigned char)(*buffer)[9];
// length=*((uint64_t *)(buffer->peek()+2));
// length=ntohll(length);
}
else
{
LOG_ERROR << "Websock parsing failed!";
return false;
}
}
LOG_TRACE << "websocket message len=" << length;
if (buffer->readableBytes() >= (indexFirstMask + 4 + length))
{
auto masks = buffer->peek() + indexFirstMask;
int indexFirstDataByte = indexFirstMask + 4;
auto rawData = buffer->peek() + indexFirstDataByte;
message.resize(length);
LOG_TRACE << "rawData[0]=" << (unsigned char)rawData[0];
LOG_TRACE << "masks[0]=" << (unsigned char)masks[0];
for (size_t i = 0; i < length; i++)
{
message[i] = (rawData[i] ^ masks[i % 4]);
}
buffer->retrieve(indexFirstMask + 4 + length);
LOG_TRACE << "got message len=" << message.length();
return true;
}
}
return true;
}
static bool isWebSocket(const HttpRequestImplPtr &req)
{
auto &headers = req->headers();
@ -196,73 +122,83 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn,
int counter = 0;
// With the pipelining feature or web socket, it is possible to receice multiple messages at once, so
// the while loop is necessary
while (buf->readableBytes() > 0)
if (requestParser->webSocketConn())
{
if (requestParser->webSocketConn())
//Websocket payload
while (buf->readableBytes() > 0)
{
//Websocket payload
while (1)
std::string message;
WebSocketMessageType type;
auto success = parseWebsockMessage(buf, message, type);
if (success)
{
std::string message;
auto success = parseWebsockMessage(buf, message);
if (success)
if (type == WebSocketMessageType::Ping)
{
if (message.empty())
break;
else
{
_webSocketMessageCallback(requestParser->webSocketConn(), std::move(message));
}
//ping
requestParser->webSocketConn()->send(message, WebSocketMessageType::Pong);
}
else
else if (type == WebSocketMessageType::Close)
{
//Websock error!
//close
conn->shutdown();
return;
}
}
return;
}
if (requestParser->isStop())
{
//The number of requests has reached the limit.
buf->retrieveAll();
return;
}
if (!requestParser->parseRequest(buf))
{
requestParser->reset();
return;
}
if (requestParser->gotAll())
{
requestParser->requestImpl()->setPeerAddr(conn->peerAddr());
requestParser->requestImpl()->setLocalAddr(conn->localAddr());
requestParser->requestImpl()->setCreationDate(trantor::Date::date());
if (requestParser->firstReq() && isWebSocket(requestParser->requestImpl()))
{
auto wsConn = std::make_shared<WebSocketConnectionImpl>(conn);
_newWebsocketCallback(requestParser->requestImpl(),
[=](const HttpResponsePtr &resp) mutable {
if (resp->statusCode() == k101SwitchingProtocols)
{
requestParser->setWebsockConnection(wsConn);
}
auto httpString = std::dynamic_pointer_cast<HttpResponseImpl>(resp)->renderToString();
conn->send(httpString);
},
wsConn);
_webSocketMessageCallback(requestParser->webSocketConn(), std::move(message), type);
}
else
onRequest(conn, requestParser->requestImpl());
requestParser->reset();
counter++;
if (counter > 1)
LOG_TRACE << "More than one requests are parsed (" << counter << ")";
{
//Websock error!
conn->shutdown();
return;
}
}
else
return;
}
else
{
while (buf->readableBytes() > 0)
{
return;
if (requestParser->isStop())
{
//The number of requests has reached the limit.
buf->retrieveAll();
return;
}
if (!requestParser->parseRequest(buf))
{
requestParser->reset();
return;
}
if (requestParser->gotAll())
{
requestParser->requestImpl()->setPeerAddr(conn->peerAddr());
requestParser->requestImpl()->setLocalAddr(conn->localAddr());
requestParser->requestImpl()->setCreationDate(trantor::Date::date());
if (requestParser->firstReq() && isWebSocket(requestParser->requestImpl()))
{
auto wsConn = std::make_shared<WebSocketConnectionImpl>(conn);
_newWebsocketCallback(requestParser->requestImpl(),
[=](const HttpResponsePtr &resp) mutable {
if (resp->statusCode() == k101SwitchingProtocols)
{
requestParser->setWebsockConnection(wsConn);
}
auto httpString = std::dynamic_pointer_cast<HttpResponseImpl>(resp)->renderToString();
conn->send(httpString);
},
wsConn);
}
else
onRequest(conn, requestParser->requestImpl());
requestParser->reset();
counter++;
if (counter > 1)
LOG_TRACE << "More than one requests are parsed (" << counter << ")";
}
else
{
return;
}
}
}
}

View File

@ -39,7 +39,7 @@ class HttpServer : trantor::NonCopyable
WebSocketNewAsyncCallback;
typedef std::function<void(const WebSocketConnectionPtr &)>
WebSocketDisconnetCallback;
typedef std::function<void(const WebSocketConnectionPtr &, std::string &&message)>
typedef std::function<void(const WebSocketConnectionPtr &, std::string &&, const WebSocketMessageType &)>
WebSocketMessageCallback;
HttpServer(EventLoop *loop,

View File

@ -13,6 +13,8 @@
*/
#include "HttpUtils.h"
#include <drogon/utils/Utilities.h>
#include <trantor/utils/Logger.h>
namespace drogon
{
@ -373,4 +375,110 @@ const string_view &statusCodeToString(int code)
}
}
// Return false if any error
bool parseWebsockMessage(trantor::MsgBuffer *buffer, std::string &message, WebSocketMessageType &type)
{
assert(message.empty());
if (buffer->readableBytes() >= 2)
{
unsigned char opcode = (*buffer)[0] & 0x0f;
switch (opcode)
{
case 1:
type = WebSocketMessageType::Text;
break;
case 2:
type = WebSocketMessageType::Binary;
break;
case 8:
type = WebSocketMessageType::Close;
break;
case 9:
type = WebSocketMessageType::Ping;
break;
case 10:
type = WebSocketMessageType::Pong;
break;
default:
type = WebSocketMessageType::Unknown;
break;
}
auto secondByte = (*buffer)[1];
size_t length = secondByte & 127;
int isMasked = (secondByte & 0x80);
if (isMasked != 0)
{
LOG_TRACE << "data encoded!";
}
else
LOG_TRACE << "plain data";
size_t indexFirstMask = 2;
if (length == 126)
{
indexFirstMask = 4;
}
else if (length == 127)
{
indexFirstMask = 10;
}
if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask)
{
if (indexFirstMask == 4)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
}
else if (indexFirstMask == 10)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
length = (length << 8) + (unsigned char)(*buffer)[4];
length = (length << 8) + (unsigned char)(*buffer)[5];
length = (length << 8) + (unsigned char)(*buffer)[6];
length = (length << 8) + (unsigned char)(*buffer)[7];
length = (length << 8) + (unsigned char)(*buffer)[8];
length = (length << 8) + (unsigned char)(*buffer)[9];
// length=*((uint64_t *)(buffer->peek()+2));
// length=ntohll(length);
}
else
{
LOG_ERROR << "Websock parsing failed!";
return false;
}
}
if (isMasked != 0)
{
if (buffer->readableBytes() >= (indexFirstMask + 4 + length))
{
auto masks = buffer->peek() + indexFirstMask;
int indexFirstDataByte = indexFirstMask + 4;
auto rawData = buffer->peek() + indexFirstDataByte;
message.resize(length);
for (size_t i = 0; i < length; i++)
{
message[i] = (rawData[i] ^ masks[i % 4]);
}
buffer->retrieve(indexFirstMask + 4 + length);
LOG_TRACE << "got message len=" << message.length();
return true;
}
}
else
{
if (buffer->readableBytes() >= (indexFirstMask + length))
{
auto rawData = buffer->peek() + indexFirstMask;
message.append(rawData, length);
buffer->retrieve(indexFirstMask + length);
LOG_TRACE << "got message len=" << message.length();
return true;
}
}
}
return true;
}
} // namespace drogon

View File

@ -16,6 +16,9 @@
#include <string>
#include <drogon/HttpTypes.h>
#include <drogon/config.h>
#include <trantor/utils/MsgBuffer.h>
#include <drogon/WebSocketConnection.h>
#if CXX_STD >= 17
#include <string_view>
typedef std::string_view string_view;
@ -28,5 +31,6 @@ namespace drogon
const string_view &webContentTypeToString(ContentType contenttype);
const string_view &statusCodeToString(int code);
bool parseWebsockMessage(trantor::MsgBuffer *buffer, std::string &message, WebSocketMessageType &type);
} // namespace drogon

View File

@ -14,16 +14,42 @@
#include "WebSockectConnectionImpl.h"
#include <trantor/net/TcpConnection.h>
#include <thread>
using namespace drogon;
WebSocketConnectionImpl::WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn)
WebSocketConnectionImpl::WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer)
: _tcpConn(conn),
_localAddr(conn->localAddr()),
_peerAddr(conn->peerAddr())
_peerAddr(conn->peerAddr()),
_isServer(isServer)
{
}
void WebSocketConnectionImpl::send(const char *msg, uint64_t len)
void WebSocketConnectionImpl::send(const char *msg, uint64_t len, const WebSocketMessageType &type)
{
unsigned char opcode;
if (type == WebSocketMessageType::Text)
opcode = 1;
else if (type == WebSocketMessageType::Binary)
opcode = 2;
else if (type == WebSocketMessageType::Close)
opcode = 8;
else if (type == WebSocketMessageType::Ping)
opcode = 9;
else if (type == WebSocketMessageType::Pong)
opcode = 10;
else
{
opcode = 0;
assert(0);
}
sendWsData(msg, len, opcode);
}
void WebSocketConnectionImpl::sendWsData(const char *msg, size_t len, unsigned char opcode)
{
LOG_TRACE << "send " << len << " bytes";
auto conn = _tcpConn.lock();
if (conn)
@ -31,14 +57,13 @@ void WebSocketConnectionImpl::send(const char *msg, uint64_t len)
//Format the frame
std::string bytesFormatted;
bytesFormatted.resize(len + 10);
bytesFormatted[0] = char(129);
bytesFormatted[0] = char(0x80 | (opcode & 0x0f));
int indexStartRawData = -1;
if (len <= 125)
{
bytesFormatted[1] = len;
indexStartRawData = 2;
}
else if (len <= 65535)
@ -64,17 +89,35 @@ void WebSocketConnectionImpl::send(const char *msg, uint64_t len)
indexStartRawData = 10;
}
if (!_isServer)
{
//Add masking key;
static std::once_flag once;
std::call_once(once, []() {
std::srand(time(nullptr));
});
int random = std::rand();
bytesFormatted[1] = (bytesFormatted[1] | 0x80);
bytesFormatted.resize(indexStartRawData + 4 + len);
*((int *)&bytesFormatted[indexStartRawData]) = random;
for (size_t i = 0; i < len; i++)
{
bytesFormatted[indexStartRawData + 4 + i] = (msg[i] ^ bytesFormatted[indexStartRawData + (i % 4)]);
}
}
else
{
bytesFormatted.resize(indexStartRawData);
bytesFormatted.append(msg, len);
}
bytesFormatted.resize(indexStartRawData);
LOG_TRACE << "fheadlen=" << bytesFormatted.length();
bytesFormatted.append(msg, len);
LOG_TRACE << "send formatted frame len=" << len << " flen=" << bytesFormatted.length();
conn->send(bytesFormatted);
}
}
void WebSocketConnectionImpl::send(const std::string &msg)
void WebSocketConnectionImpl::send(const std::string &msg, const WebSocketMessageType &type)
{
send(msg.data(), msg.length());
send(msg.data(), msg.length(), type);
}
const trantor::InetAddress &WebSocketConnectionImpl::localAddr() const
{
@ -132,3 +175,4 @@ any *WebSocketConnectionImpl::WebSocketConnectionImpl::getMutableContext()
{
return &_context;
}

View File

@ -21,10 +21,10 @@ namespace drogon
class WebSocketConnectionImpl : public WebSocketConnection
{
public:
explicit WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn);
explicit WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer = true);
virtual void send(const char *msg, uint64_t len) override;
virtual void send(const std::string &msg) override;
virtual void send(const char *msg, uint64_t len, const WebSocketMessageType &type = WebSocketMessageType::Text) override;
virtual void send(const std::string &msg, const WebSocketMessageType &type = WebSocketMessageType::Text) override;
virtual const trantor::InetAddress &localAddr() const override;
virtual const trantor::InetAddress &peerAddr() const override;
@ -54,5 +54,10 @@ class WebSocketConnectionImpl : public WebSocketConnection
trantor::InetAddress _peerAddr;
WebSocketControllerBasePtr _ctrlPtr;
any _context;
bool _isServer = true;
void sendWsData(const char *msg, size_t len, unsigned char opcode);
};
typedef std::shared_ptr<WebSocketConnectionImpl> WebSocketConnectionImplPtr;
} // namespace drogon

View File

@ -0,0 +1,342 @@
/**
*
* WebSocketClientImpl.cc
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include "WebSocketClientImpl.h"
#include "HttpRequestImpl.h"
#include "HttpUtils.h"
#include "HttpResponseParser.h"
#include <drogon/HttpAppFramework.h>
#include <trantor/net/InetAddress.h>
#include <drogon/utils/Utilities.h>
#ifdef USE_OPENSSL
#include <openssl/sha.h>
#else
#include "ssl_funcs/Sha1.h"
#endif
using namespace drogon;
using namespace trantor;
void WebSocketClientImpl::setHeartbeatMessage(const std::string &message, double interval)
{
std::weak_ptr<WebSocketClientImpl> weakPtr = shared_from_this();
_heartbeatTimerId = _loop->runEvery(interval, [weakPtr, message]() {
auto thisPtr = weakPtr.lock();
if (thisPtr && thisPtr->_websockConnPtr)
{
thisPtr->_websockConnPtr->send(message, WebSocketMessageType::Ping);
}
});
}
WebSocketClientImpl::~WebSocketClientImpl()
{
_loop->invalidateTimer(_heartbeatTimerId);
}
void WebSocketClientImpl::connectToServerInLoop()
{
_loop->assertInLoopThread();
_upgradeRequest->addHeader("Connection", "Upgrade");
_upgradeRequest->addHeader("Upgrade", "websocket");
auto randStr = utils::genRandomString(16);
_wsKey = utils::base64Encode((const unsigned char *)randStr.data(), (unsigned int)randStr.length());
auto wsKey = _wsKey;
wsKey.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
unsigned char accKey[SHA_DIGEST_LENGTH];
SHA1(reinterpret_cast<const unsigned char *>(wsKey.c_str()), wsKey.length(), accKey);
_wsAccept = utils::base64Encode(accKey, SHA_DIGEST_LENGTH);
_upgradeRequest->addHeader("Sec-WebSocket-Key", _wsKey);
//_upgradeRequest->addHeader("Sec-WebSocket-Version","13");
assert(!_tcpClient);
bool hasIpv6Address = false;
if (_server.isIpV6())
{
auto ipaddr = _server.ip6NetEndian();
for (int i = 0; i < 4; i++)
{
if (ipaddr[i] != 0)
{
hasIpv6Address = true;
break;
}
}
}
if (_server.ipNetEndian() == 0 && !hasIpv6Address &&
!_domain.empty() &&
_server.portNetEndian() != 0)
{
//dns
//TODO: timeout should be set by user
if (InetAddress::resolve(_domain, &_server) == false)
{
_requestCallback(ReqResult::BadServerAddress, nullptr, shared_from_this());
return;
}
LOG_TRACE << "dns:domain=" << _domain << ";ip=" << _server.toIp();
}
if ((_server.ipNetEndian() != 0 || hasIpv6Address) && _server.portNetEndian() != 0)
{
LOG_TRACE << "New TcpClient," << _server.toIpPort();
_tcpClient = std::make_shared<trantor::TcpClient>(_loop, _server, "httpClient");
#ifdef USE_OPENSSL
if (_useSSL)
{
_tcpClient->enableSSL();
}
#endif
auto thisPtr = shared_from_this();
std::weak_ptr<WebSocketClientImpl> weakPtr = thisPtr;
_tcpClient->setConnectionCallback([weakPtr](const trantor::TcpConnectionPtr &connPtr) {
auto thisPtr = weakPtr.lock();
if (!thisPtr)
return;
if (connPtr->connected())
{
connPtr->setContext(HttpResponseParser(connPtr));
//send request;
LOG_TRACE << "Connection established!";
thisPtr->sendReq(connPtr);
}
else
{
LOG_TRACE << "connection disconnect";
thisPtr->_connectionClosedCallback(thisPtr);
thisPtr->_loop->runAfter(1.0, [thisPtr]() {
thisPtr->reconnect();
});
}
});
_tcpClient->setConnectionErrorCallback([weakPtr]() {
auto thisPtr = weakPtr.lock();
if (!thisPtr)
return;
//can't connect to server
thisPtr->_requestCallback(ReqResult::NetworkFailure, nullptr, thisPtr);
thisPtr->_loop->runAfter(1.0, [thisPtr]() {
thisPtr->reconnect();
});
});
_tcpClient->setMessageCallback([weakPtr](const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *msg) {
auto thisPtr = weakPtr.lock();
if (thisPtr)
{
thisPtr->onRecvMessage(connPtr, msg);
}
});
_tcpClient->connect();
}
else
{
_requestCallback(ReqResult::BadServerAddress, nullptr, shared_from_this());
return;
}
}
void WebSocketClientImpl::onRecvWsMessage(const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *msgBuffer)
{
std::string message;
WebSocketMessageType type;
auto success = parseWebsockMessage(msgBuffer, message, type);
if (success)
{
if (type == WebSocketMessageType::Close)
{
//close
connPtr->shutdown();
}
else if (type == WebSocketMessageType::Ping)
{
//ping
if (_websockConnPtr)
{
_websockConnPtr->send(message, WebSocketMessageType::Pong);
}
}
_messageCallback(std::move(message), shared_from_this(), type);
}
else
{
//Websock error!
connPtr->shutdown();
auto thisPtr = shared_from_this();
_loop->runAfter(1.0, [thisPtr]() {
thisPtr->reconnect();
});
return;
}
}
void WebSocketClientImpl::onRecvMessage(const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *msgBuffer)
{
if (_upgraded)
{
onRecvWsMessage(connPtr, msgBuffer);
return;
}
HttpResponseParser *responseParser = any_cast<HttpResponseParser>(connPtr->getMutableContext());
//LOG_TRACE << "###:" << msg->readableBytes();
if (!responseParser->parseResponse(msgBuffer))
{
_requestCallback(ReqResult::BadResponse, nullptr, shared_from_this());
connPtr->shutdown();
_websockConnPtr.reset();
_tcpClient.reset();
return;
}
if (responseParser->gotAll())
{
auto resp = responseParser->responseImpl();
responseParser->reset();
auto acceptStr = resp->getHeaderBy("sec-websocket-accept");
if (resp->statusCode() != k101SwitchingProtocols || acceptStr != _wsAccept)
{
_requestCallback(ReqResult::BadResponse, nullptr, shared_from_this());
connPtr->shutdown();
_websockConnPtr.reset();
_tcpClient.reset();
return;
}
auto &type = resp->getHeaderBy("content-type");
if (type.find("application/json") != std::string::npos)
{
resp->parseJson();
}
if (resp->getHeaderBy("content-encoding") == "gzip")
{
resp->gunzip();
}
_upgraded = true;
_websockConnPtr = std::make_shared<WebSocketConnectionImpl>(connPtr, false);
_requestCallback(ReqResult::Ok, resp, shared_from_this());
if (msgBuffer->readableBytes() > 0)
{
onRecvWsMessage(connPtr, msgBuffer);
}
}
else
{
return;
}
}
void WebSocketClientImpl::reconnect()
{
_tcpClient.reset();
_websockConnPtr.reset();
_upgraded = false;
connectToServerInLoop();
}
WebSocketClientImpl::WebSocketClientImpl(trantor::EventLoop *loop, const trantor::InetAddress &addr, bool useSSL)
: _loop(loop),
_server(addr),
_useSSL(useSSL)
{
}
WebSocketClientImpl::WebSocketClientImpl(trantor::EventLoop *loop, const std::string &hostString)
: _loop(loop)
{
auto lowerHost = hostString;
std::transform(lowerHost.begin(), lowerHost.end(), lowerHost.begin(), tolower);
if (lowerHost.find("wss://") != std::string::npos)
{
_useSSL = true;
lowerHost = lowerHost.substr(6);
}
else if (lowerHost.find("ws://") != std::string::npos)
{
_useSSL = false;
lowerHost = lowerHost.substr(5);
}
else
{
return;
}
auto pos = lowerHost.find(":");
if (pos != std::string::npos)
{
_domain = lowerHost.substr(0, pos);
auto portStr = lowerHost.substr(pos + 1);
pos = portStr.find("/");
if (pos != std::string::npos)
{
portStr = portStr.substr(0, pos);
}
auto port = atoi(portStr.c_str());
if (port > 0 && port < 65536)
{
_server = InetAddress(port);
}
}
else
{
_domain = lowerHost;
pos = _domain.find("/");
if (pos != std::string::npos)
{
_domain = _domain.substr(0, pos);
}
if (_useSSL)
{
_server = InetAddress(443);
}
else
{
_server = InetAddress(80);
}
}
LOG_TRACE << "userSSL=" << _useSSL << " domain=" << _domain;
}
void WebSocketClientImpl::sendReq(const trantor::TcpConnectionPtr &connPtr)
{
trantor::MsgBuffer buffer;
auto implPtr = std::dynamic_pointer_cast<HttpRequestImpl>(_upgradeRequest);
assert(implPtr);
implPtr->appendToBuffer(&buffer);
LOG_TRACE << "Send request:" << std::string(buffer.peek(), buffer.readableBytes());
connPtr->send(std::move(buffer));
}
WebSocketClientPtr WebSocketClient::newWebSocketClient(const std::string &ip,
uint16_t port,
bool useSSL,
trantor::EventLoop *loop)
{
bool isIpv6 = ip.find(":") == std::string::npos ? false : true;
return std::make_shared<WebSocketClientImpl>(loop == nullptr ? app().getLoop() : loop, trantor::InetAddress(ip, port, isIpv6), useSSL);
}
WebSocketClientPtr WebSocketClient::newWebSocketClient(const std::string &hostString,
trantor::EventLoop *loop)
{
return std::make_shared<WebSocketClientImpl>(loop == nullptr ? app().getLoop() : loop, hostString);
}

View File

@ -0,0 +1,102 @@
/**
*
* WebSocketClientImpl.h
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include "WebSockectConnectionImpl.h"
#include <drogon/WebSocketClient.h>
#include <trantor/utils/NonCopyable.h>
#include <trantor/net/EventLoop.h>
#include <trantor/net/TcpClient.h>
#include <string>
#include <memory>
namespace drogon
{
class WebSocketClientImpl : public WebSocketClient, public std::enable_shared_from_this<WebSocketClientImpl>
{
public:
virtual const WebSocketConnectionPtr &getConnection() override
{
return _websockConnPtr;
}
virtual void setMessageHandler(const std::function<void(std::string &&message,
const WebSocketClientPtr &,
const WebSocketMessageType &)> &callback) override
{
_messageCallback = callback;
}
virtual void setConnectionClosedHandler(const std::function<void(const WebSocketClientPtr &)> &callback) override
{
_connectionClosedCallback = callback;
}
virtual void setHeartbeatMessage(const std::string &message, double interval) override;
virtual void connectToServer(const HttpRequestPtr &request, const WebSocketRequestCallback &callback) override
{
if (_loop->isInLoopThread())
{
_upgradeRequest = request;
_requestCallback = callback;
connectToServerInLoop();
}
else
{
auto thisPtr = shared_from_this();
_loop->queueInLoop([request, callback, thisPtr] {
thisPtr->_upgradeRequest = request;
thisPtr->_requestCallback = callback;
thisPtr->connectToServerInLoop();
});
}
}
virtual trantor::EventLoop *getLoop() override { return _loop; }
WebSocketClientImpl(trantor::EventLoop *loop, const trantor::InetAddress &addr, bool useSSL = false);
WebSocketClientImpl(trantor::EventLoop *loop, const std::string &hostString);
~WebSocketClientImpl();
private:
std::shared_ptr<trantor::TcpClient> _tcpClient;
trantor::EventLoop *_loop;
trantor::InetAddress _server;
std::string _domain;
bool _useSSL;
bool _upgraded = false;
std::string _wsKey;
std::string _wsAccept;
trantor::TimerId _heartbeatTimerId;
HttpRequestPtr _upgradeRequest;
std::function<void(std::string &&message, const WebSocketClientPtr &, const WebSocketMessageType &)> _messageCallback;
std::function<void(const WebSocketClientPtr &)> _connectionClosedCallback;
WebSocketRequestCallback _requestCallback;
WebSocketConnectionPtr _websockConnPtr;
void connectToServerInLoop();
void sendReq(const trantor::TcpConnectionPtr &connPtr);
void onRecvMessage(const trantor::TcpConnectionPtr &, trantor::MsgBuffer *);
void onRecvWsMessage(const trantor::TcpConnectionPtr &, trantor::MsgBuffer *);
void reconnect();
};
} // namespace drogon

View File

@ -24,10 +24,19 @@ if [ $? -ne 0 ];then
exit -1
fi
#Test WebSocket
./websocket_test -t
if [ $? -ne 0 ];then
echo "Error in testing"
exit -1
fi
killall -9 webapp
#Test drogon_ctl
rm -rf drogon_test
drogon_ctl create project drogon_test
cd drogon_test/controllers