Compare commits

...

18 Commits

Author SHA1 Message Date
Martin Chang
9d80aaa1e9 update README 2023-11-08 11:18:15 +08:00
Martin Chang
75f197ecd4 remove forced printing in debug 2023-11-08 11:04:20 +08:00
Martin Chang
8ce2b853d9 fix MSVC build 2023-11-08 10:59:08 +08:00
Martin Chang
781ef3c194 track window size 2023-11-08 10:43:19 +08:00
Martin Chang
3eddcafe73 send request body 2023-11-08 10:32:19 +08:00
Martin Chang
09a8634838 handle CR LF in header 2023-11-08 10:23:17 +08:00
Martin Chang
3ea63787e7 respect max concurrent streams 2023-11-08 10:20:20 +08:00
Martin Chang
7202330f10 reply to pings 2023-11-08 10:15:44 +08:00
Martin Chang
e27b800b33 handle out-of-order core frames 2023-11-08 10:06:36 +08:00
Martin Chang
98678dd331 handle response with no body 2023-11-08 09:49:54 +08:00
Martin Chang
f2a7ac8b2f handle error 2023-11-08 00:41:36 +08:00
Martin Chang
89786de0fe slight cleanup for error handling 2023-11-07 17:16:45 +08:00
Martin Chang
49cfea3b4a if content-length is present, cehck it matches the amount of data in DATA frame 2023-11-07 16:11:54 +08:00
Martin Chang
a3a6267577 fix 2023-11-07 15:50:11 +08:00
Martin Chang
69f592b726 release streamId on fail to decode header 2023-11-07 15:49:44 +08:00
Martin Chang
a0a9f7a337 enable handling of multiple streams 2023-11-07 15:48:54 +08:00
Martin Chang
1a27e8bf1e some minor improvments 2023-11-07 14:03:21 +08:00
Martin Chang
05a18675fe optimize frame serialization 2023-11-07 13:33:29 +08:00
5 changed files with 815 additions and 401 deletions

View File

@ -13,7 +13,7 @@ Drogon is a cross-platform framework, It supports Linux, macOS, FreeBSD, OpenBSD
* Use a non-blocking I/O network lib based on epoll (kqueue under macOS/FreeBSD) to provide high-concurrency, high-performance network IO, please visit the [TFB Tests Results](https://www.techempower.com/benchmarks/#section=data-r19&hw=ph&test=composite) for more details;
* Provide a completely asynchronous programming mode;
* Support Http1.0/1.1 (server side and client side);
* Support HTTP/2 (and 1.0/1.1) client and HTTP 1.1/1.0 server
* Based on template, a simple reflection mechanism is implemented to completely decouple the main program framework, controllers and views.
* Support cookies and built-in sessions;
* Support back-end rendering, the controller generates the data to the view to generate the Html page. Views are described by CSP template files, C++ codes are embedded into Html pages through CSP tags. And the drogon command-line tool automatically generates the C++ code files for compilation;

View File

@ -16,7 +16,7 @@ int main()
{
trantor::Logger::setLogLevel(trantor::Logger::kTrace);
{
auto client = HttpClient::newHttpClient("https://clehaxze.tw:8844",
auto client = HttpClient::newHttpClient("https://clehaxze.tw",
nullptr,
false,
false);
@ -48,7 +48,7 @@ int main()
req->setParameter("wd", "wx");
req->setParameter("oq", "wx");
for (int i = 0; i < 1; ++i)
for (int i = 0; i < 2; ++i)
{
client->sendRequest(
req, [](ReqResult result, const HttpResponsePtr &response) {
@ -76,7 +76,6 @@ int main()
});
LOG_INFO << "send request";
}
}
app().run();
}
}

View File

@ -3,9 +3,93 @@
#include <variant>
using namespace drogon;
using namespace drogon::internal;
static const std::string_view h2_preamble = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
static std::vector<uint8_t> s2vec(const std::string &str)
{
std::vector<uint8_t> vec(str.size());
memcpy(vec.data(), str.data(), str.size());
return vec;
}
static std::optional<size_t> stosz(const std::string &str)
{
try
{
return std::stoull(str);
}
catch (const std::exception &e)
{
return std::nullopt;
}
}
enum class H2FrameType
{
Data = 0x0,
Headers = 0x1,
Priority = 0x2,
RstStream = 0x3,
Settings = 0x4,
PushPromise = 0x5,
Ping = 0x6,
GoAway = 0x7,
WindowUpdate = 0x8,
Continuation = 0x9,
AltSvc = 0xa,
// UNUSED = 0xb, // 0xb is removed from the spec
Origin = 0xc,
NumEntries
};
enum class H2SettingsKey
{
HeaderTableSize = 0x1,
EnablePush = 0x2,
MaxConcurrentStreams = 0x3,
InitialWindowSize = 0x4,
MaxFrameSize = 0x5,
MaxHeaderListSize = 0x6,
NumEntries
};
enum class H2HeadersFlags
{
EndStream = 0x1,
EndHeaders = 0x4,
Padded = 0x8,
Priority = 0x20
};
enum class H2DataFlags
{
EndStream = 0x1,
Padded = 0x8
};
enum class H2PingFlags
{
Ack = 0x1
};
static GoAwayFrame goAway(int32_t sid,
const std::string &msg,
StreamCloseErrorCode ec)
{
GoAwayFrame frame;
frame.additionalDebugData = s2vec(msg);
frame.errorCode = (uint32_t)ec;
frame.lastStreamId = sid;
return frame;
}
namespace drogon::internal
{
// Quick and dirty ByteStream implementation and extensions so we can use it
// to read from the buffer, safely. At least it checks for buffer overflows
// in debug mode.
@ -139,6 +223,23 @@ struct OByteStream
buffer.append((char *)ptr, size);
}
void overwriteU24BE(size_t offset, uint32_t value)
{
assert(value <= 0xffffff);
assert(offset <= buffer.readableBytes() - 3);
auto ptr = (uint8_t *)buffer.peek() + offset;
ptr[0] = value >> 16;
ptr[1] = value >> 8;
ptr[2] = value;
}
void overwriteU8(size_t offset, uint8_t value)
{
assert(offset <= buffer.readableBytes() - 1);
auto ptr = (uint8_t *)buffer.peek() + offset;
ptr[0] = value;
}
uint8_t *peek()
{
return (uint8_t *)buffer.peek();
@ -147,57 +248,7 @@ struct OByteStream
trantor::MsgBuffer buffer;
};
enum class H2FrameType
{
Data = 0x0,
Headers = 0x1,
Priority = 0x2,
RstStream = 0x3,
Settings = 0x4,
PushPromise = 0x5,
Ping = 0x6,
GoAway = 0x7,
WindowUpdate = 0x8,
Continuation = 0x9,
AltSvc = 0xa,
// UNUSED = 0xb, // 0xb is removed from the spec
Origin = 0xc,
NumEntries
};
enum class H2SettingsKey
{
HeaderTableSize = 0x1,
EnablePush = 0x2,
MaxConcurrentStreams = 0x3,
InitialWindowSize = 0x4,
MaxFrameSize = 0x5,
MaxHeaderListSize = 0x6,
NumEntries
};
enum class H2HeadersFlags
{
EndStream = 0x1,
EndHeaders = 0x4,
Padded = 0x8,
Priority = 0x20
};
enum class H2DataFlags
{
EndStream = 0x1,
Padded = 0x8
};
struct SettingsFrame
{
bool ack = false;
std::vector<std::pair<uint16_t, uint32_t>> settings;
static std::optional<SettingsFrame> parse(ByteStream &payload,
std::optional<SettingsFrame> SettingsFrame::parse(ByteStream &payload,
uint8_t flags)
{
if (payload.size() % 6 != 0)
@ -228,7 +279,7 @@ struct SettingsFrame
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
bool SettingsFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = (ack ? 0x1 : 0x0);
for (auto &[key, value] : settings)
@ -238,13 +289,8 @@ struct SettingsFrame
}
return true;
}
};
struct WindowUpdateFrame
{
uint32_t windowSizeIncrement = 0;
static std::optional<WindowUpdateFrame> parse(ByteStream &payload,
std::optional<WindowUpdateFrame> WindowUpdateFrame::parse(ByteStream &payload,
uint8_t flags)
{
if (payload.size() != 4)
@ -259,7 +305,7 @@ struct WindowUpdateFrame
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
bool WindowUpdateFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
if (windowSizeIncrement & (1U << 31))
@ -270,19 +316,9 @@ struct WindowUpdateFrame
stream.writeU32BE(windowSizeIncrement);
return true;
}
};
struct HeadersFrame
{
uint8_t padLength = 0;
bool exclusive = false;
uint32_t streamDependency = 0;
uint8_t weight = 0;
std::vector<uint8_t> headerBlockFragment;
bool endHeaders = false;
bool endStream = false;
static std::optional<HeadersFrame> parse(ByteStream &payload, uint8_t flags)
std::optional<HeadersFrame> HeadersFrame::parse(ByteStream &payload,
uint8_t flags)
{
bool endStream = flags & (uint8_t)H2HeadersFlags::EndStream;
bool endHeaders = flags & (uint8_t)H2HeadersFlags::EndHeaders;
@ -323,7 +359,7 @@ struct HeadersFrame
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
bool HeadersFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
if (padLength > 0)
@ -348,15 +384,9 @@ struct HeadersFrame
stream.write(headerBlockFragment.data(), headerBlockFragment.size());
return true;
}
};
struct GoAwayFrame
{
uint32_t lastStreamId = 0;
uint32_t errorCode = 0;
std::vector<uint8_t> additionalDebugData;
static std::optional<GoAwayFrame> parse(ByteStream &payload, uint8_t flags)
std::optional<GoAwayFrame> GoAwayFrame::parse(ByteStream &payload,
uint8_t flags)
{
if (payload.size() < 8)
{
@ -372,7 +402,7 @@ struct GoAwayFrame
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
bool GoAwayFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
stream.writeU32BE(lastStreamId);
@ -380,15 +410,8 @@ struct GoAwayFrame
stream.write(additionalDebugData.data(), additionalDebugData.size());
return true;
}
};
struct DataFrame
{
uint8_t padLength = 0;
std::vector<uint8_t> data;
bool endStream = false;
static std::optional<DataFrame> parse(ByteStream &payload, uint8_t flags)
std::optional<DataFrame> DataFrame::parse(ByteStream &payload, uint8_t flags)
{
bool endStream = flags & (uint8_t)H2DataFlags::EndStream;
bool padded = flags & (uint8_t)H2DataFlags::Padded;
@ -416,7 +439,7 @@ struct DataFrame
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
bool DataFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
stream.write(data.data(), data.size());
@ -428,33 +451,50 @@ struct DataFrame
}
return true;
}
};
using H2Frame = std::variant<SettingsFrame,
WindowUpdateFrame,
HeadersFrame,
GoAwayFrame,
DataFrame>;
std::optional<PingFrame> PingFrame::parse(ByteStream &payload, uint8_t flags)
{
if (payload.size() != 8)
{
LOG_ERROR << "Invalid ping frame length";
return std::nullopt;
}
PingFrame frame;
payload.read(frame.opaqueData.data(), frame.opaqueData.size());
return frame;
}
bool PingFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = ack ? (uint8_t)H2PingFlags::Ack : 0x0;
stream.write(opaqueData.data(), opaqueData.size());
return true;
}
} // namespace drogon::internal
// Print the HEX and ASCII representation of the buffer side by side
// 16 bytes per line. Same function as the xdd command in linux.
static void dump_hex_beautiful(const void *ptr, size_t size)
static std::string dump_hex_beautiful(const void *ptr, size_t size)
{
std::stringstream ss;
ss << "\n";
for (size_t i = 0; i < size; i += 16)
{
printf("%08zx: ", i);
ss << std::setw(8) << std::setfill('0') << std::hex << i << ": ";
for (size_t j = 0; j < 16; ++j)
{
if (i + j < size)
{
printf("%02x ", ((unsigned char *)ptr)[i + j]);
ss << std::setw(2) << std::setfill('0') << std::hex
<< (int)((unsigned char *)ptr)[i + j] << " ";
}
else
{
printf(" ");
ss << " ";
}
}
printf(" ");
ss << " ";
for (size_t j = 0; j < 16; ++j)
{
if (i + j < size)
@ -462,22 +502,27 @@ static void dump_hex_beautiful(const void *ptr, size_t size)
if (((unsigned char *)ptr)[i + j] >= 32 &&
((unsigned char *)ptr)[i + j] < 127)
{
printf("%c", ((unsigned char *)ptr)[i + j]);
ss << (char)((unsigned char *)ptr)[i + j];
}
else
{
printf(".");
ss << ".";
}
}
}
printf("\n");
ss << "\n";
}
return ss.str();
}
static std::vector<uint8_t> serializeFrame(const H2Frame &frame,
size_t streamId)
static trantor::MsgBuffer serializeFrame(const H2Frame &frame, size_t streamId)
{
OByteStream buffer;
buffer.writeU24BE(0); // Placeholder for length
buffer.writeU8(0); // Placeholder for type
buffer.writeU8(0); // Placeholder for flags
buffer.writeU32BE(streamId);
uint8_t type;
uint8_t flags;
bool ok = false;
@ -505,6 +550,18 @@ static std::vector<uint8_t> serializeFrame(const H2Frame &frame,
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::GoAway;
}
else if (std::holds_alternative<DataFrame>(frame))
{
const auto &f = std::get<DataFrame>(frame);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::Data;
}
else if (std::holds_alternative<PingFrame>(frame))
{
const auto &f = std::get<PingFrame>(frame);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::Ping;
}
else
{
LOG_ERROR << "Unsupported frame type";
@ -517,23 +574,11 @@ static std::vector<uint8_t> serializeFrame(const H2Frame &frame,
abort();
}
std::vector<uint8_t> full_frame;
full_frame.reserve(9 + buffer.buffer.readableBytes());
size_t length = buffer.buffer.readableBytes();
assert(length <= 0xffffff);
full_frame.push_back(length >> 16);
full_frame.push_back(length >> 8);
full_frame.push_back(length);
full_frame.push_back(type);
full_frame.push_back(flags);
full_frame.push_back(streamId >> 24);
full_frame.push_back(streamId >> 16);
full_frame.push_back(streamId >> 8);
full_frame.push_back(streamId);
full_frame.insert(full_frame.end(),
buffer.buffer.peek(),
buffer.buffer.peek() + length);
return full_frame;
auto length = buffer.buffer.readableBytes() - 9;
buffer.overwriteU24BE(0, length);
buffer.overwriteU8(3, type);
buffer.overwriteU8(4, flags);
return buffer.buffer;
}
// return streamId, frame, error and should continue parsing
@ -588,6 +633,8 @@ static std::tuple<std::optional<H2Frame>, size_t, uint8_t, bool> parseH2Frame(
frame = HeadersFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::Data)
frame = DataFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::Ping)
frame = PingFrame::parse(payload, flags);
else
{
LOG_WARN << "Unsupported H2 frame type: " << (int)type;
@ -611,18 +658,32 @@ static std::tuple<std::optional<H2Frame>, size_t, uint8_t, bool> parseH2Frame(
void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
HttpReqCallback &&callback)
{
connPtr->getLoop()->assertInLoopThread();
if (!serverSettingsReceived)
{
bufferedRequests.emplace_back(req, std::move(callback));
bufferedRequests.push({req, std::move(callback)});
return;
}
if (streams.size() >= maxConcurrentStreams)
{
LOG_TRACE << "Too many streams in flight. Buffering request";
bufferedRequests.push({req, std::move(callback)});
return;
}
const int32_t streamId = nextStreamId();
assert(streamId % 2 == 1);
LOG_TRACE << "Sending HTTP/2 request: streamId=" << streamId;
if (streams.find(streamId) != streams.end())
{
LOG_FATAL << "Stream id already in use! This should not happen";
abort();
connPtr->send(
serializeFrame(goAway(streamId,
"replicated internal stream id",
StreamCloseErrorCode::InternalError),
0));
errorCallback(ReqResult::BadResponse);
return;
}
@ -664,16 +725,46 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
abort();
return;
}
// TODO: Send CONTINUATION frames if the header block fragment is too large
if (n > 0x7fff)
{
LOG_ERROR << "Header block fragment too large";
abort();
return;
}
frame.headerBlockFragment.resize(n);
frame.endHeaders = true;
auto &stream = createStream(streamId);
if (req->body().length() == 0)
frame.endStream = true;
LOG_TRACE << "Sending headers frame";
auto f = serializeFrame(frame, streamId);
dump_hex_beautiful(f.data(), f.size());
connPtr->send(f.data(), f.size());
LOG_TRACE << dump_hex_beautiful(f.peek(), f.readableBytes());
connPtr->send(f);
streams[streamId] = internal::H2Stream();
streams[streamId].callback = std::move(callback);
stream.callback = std::move(callback);
stream.request = req;
// TODO: Don't dump the entire body into TCP at once
if (req->body().length() == 0)
{
LOG_TRACE << "No body to send";
return;
}
DataFrame dataFrame;
for (size_t i = 0; i < req->body().length(); i += maxFrameSize)
{
size_t readSize = (std::min)(maxFrameSize, req->body().length() - i);
std::vector<uint8_t> buffer;
buffer.resize(readSize);
memcpy(buffer.data(), req->body().data() + i, readSize);
dataFrame.data = std::move(buffer);
dataFrame.endStream = (i + maxFrameSize >= req->body().length());
LOG_TRACE << "Sending data frame: size=" << dataFrame.data.size()
<< " endStream=" << dataFrame.endStream;
connPtr->send(serializeFrame(dataFrame, streamId));
}
}
void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
@ -682,9 +773,16 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
LOG_TRACE << "HTTP/2 message received:";
assert(bytesReceived_ != nullptr);
*bytesReceived_ += msg->readableBytes();
dump_hex_beautiful(msg->peek(), msg->readableBytes());
LOG_TRACE << dump_hex_beautiful(msg->peek(), msg->readableBytes());
while (true)
{
if (avaliableWindowSize < windowIncreaseThreshold)
{
WindowUpdateFrame windowUpdateFrame;
windowUpdateFrame.windowSizeIncrement = windowIncreaseSize;
connPtr->send(serializeFrame(windowUpdateFrame, 0));
}
// FIXME: The code cannot distinguish between a out-of-data and
// unsupported frame type. We need to fix this as it should be handled
// differently.
@ -704,15 +802,36 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
}
auto &frame = *frameOpt;
// TODO: Figure out how to dispatch the frame to the right stream
// special case for PING frame. It is the only frame that is not
// associated with a stream
if (std::holds_alternative<PingFrame>(frame))
{
auto &f = std::get<PingFrame>(frame);
if (f.ack)
{
LOG_TRACE << "Ping frame received with ACK flag set";
continue;
}
LOG_TRACE << "Ping frame received. Sending ACK";
PingFrame ackFrame;
ackFrame.ack = true;
ackFrame.opaqueData = f.opaqueData;
connPtr->send(serializeFrame(ackFrame, 0));
continue;
}
if (streamId != 0)
{
handleFrameForStream(frame, streamId, flags);
continue;
}
// This point forawrd, we are handling frames for stream 0
if (std::holds_alternative<WindowUpdateFrame>(frame))
{
auto &f = std::get<WindowUpdateFrame>(frame);
if (streamId == 0)
{
avaliableWindowSize += f.windowSizeIncrement;
}
}
else if (std::holds_alternative<SettingsFrame>(frame))
{
auto &f = std::get<SettingsFrame>(frame);
@ -720,11 +839,15 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
{
if (key == (uint16_t)H2SettingsKey::MaxConcurrentStreams)
{
if (streamId == 0)
if (value == 0)
{
connPtr->send(serializeFrame(
goAway(streamId,
"MaxConcurrentStreams cannot be 0",
StreamCloseErrorCode::ProtocolError),
0));
}
maxConcurrentStreams = value;
else
LOG_TRACE << "Ignoring max concurrent streams due to "
"streamId != 0";
}
else if (key == (uint16_t)H2SettingsKey::MaxFrameSize)
{
@ -736,26 +859,27 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
}
}
if (streamId == 0 && !serverSettingsReceived)
if (!serverSettingsReceived)
{
LOG_TRACE << "Server settings received. Sending our own "
"settings and WindowUpdate";
SettingsFrame settingsFrame;
settingsFrame.settings.emplace_back(
(uint16_t)H2SettingsKey::EnablePush, 0); // Disable push
auto b = serializeFrame(settingsFrame, 0);
connPtr->send((const char *)b.data(), b.size());
connPtr->send(serializeFrame(settingsFrame, 0));
WindowUpdateFrame windowUpdateFrame;
// TODO: Keep track and update the window size
windowUpdateFrame.windowSizeIncrement = 200 * 1024 * 1024;
auto b2 = serializeFrame(windowUpdateFrame, 0);
connPtr->send((const char *)b2.data(), b2.size());
windowUpdateFrame.windowSizeIncrement = windowIncreaseSize;
connPtr->send(serializeFrame(windowUpdateFrame, 0));
serverSettingsReceived = true;
for (auto &[req, cb] : bufferedRequests)
while (!bufferedRequests.empty() &&
streams.size() < maxConcurrentStreams)
{
auto &[req, cb] = bufferedRequests.front();
sendRequestInLoop(req, std::move(cb));
bufferedRequests.clear();
bufferedRequests.pop();
}
}
// Somehow nghttp2 wants us to send ACK after sending our
@ -765,92 +889,38 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
LOG_TRACE << "Acknowledge settings frame";
SettingsFrame ackFrame;
ackFrame.ack = true;
auto b = serializeFrame(ackFrame, 0);
connPtr->send((const char *)b.data(), b.size());
connPtr->send(serializeFrame(ackFrame, streamId));
}
else if (std::holds_alternative<HeadersFrame>(frame))
{
auto &f = std::get<HeadersFrame>(frame);
LOG_TRACE << "Headers frame received: size="
<< f.headerBlockFragment.size();
hpack::HPacker::KeyValueVector headers;
int n = hpackRx.decode(f.headerBlockFragment.data(),
f.headerBlockFragment.size(),
headers);
if (n < 0)
{
LOG_ERROR << "Failed to decode headers";
abort();
return;
}
for (auto &[key, value] : headers)
LOG_TRACE << " " << key << ": " << value;
auto it = streams.find(streamId);
if (it == streams.end())
{
LOG_ERROR << "Headers frame received for unknown stream id: "
<< streamId;
// TODO: Send GoAway frame
return;
}
it->second.response = std::make_shared<HttpResponseImpl>();
for (const auto &[key, value] : headers)
{
// TODO: Filter more pseudo headers
if (key == ":status")
continue;
// TODO: Validate content-length is either not present or
// the same as the body size sent by DATA frames
if (key == "content-length")
continue;
if (key == ":status")
{
// TODO: Validate status code
it->second.response->setStatusCode(
(drogon::HttpStatusCode)std::stoi(value));
continue;
}
it->second.response->addHeader(key, value);
}
// Should never show up on stream 0
LOG_FATAL << "Protocol error: HEADERS frame on stream 0";
errorCallback(ReqResult::BadResponse);
}
else if (std::holds_alternative<DataFrame>(frame))
{
auto &f = std::get<DataFrame>(frame);
LOG_TRACE << "Data frame received: size=" << f.data.size();
auto it = streams.find(streamId);
if (it == streams.end())
{
LOG_ERROR << "Data frame received for unknown stream id: "
<< streamId;
return;
}
it->second.body.append((char *)f.data.data(), f.data.size());
if ((flags & (uint8_t)H2DataFlags::EndStream) != 0)
{
// TODO: Optmize setting body
std::string body(it->second.body.peek(),
it->second.body.readableBytes());
auto headers = it->second.response->headers();
it->second.response->setBody(std::move(body));
// FIXME: Store the actuall request object
auto req = HttpRequest::newHttpRequest();
respCallback(it->second.response,
{req, it->second.callback},
connPtr);
}
LOG_FATAL << "Protocol error: DATA frame on stream 0";
errorCallback(ReqResult::BadResponse);
}
else if (std::holds_alternative<GoAwayFrame>(frame))
{
LOG_ERROR << "Go away frame received. Die!";
auto &f = std::get<GoAwayFrame>(frame);
// TODO: Depening on the streamId, we need to kill the entire
// connection or just the stream
if (f.errorCode != 0)
{
LOG_ERROR << "Go away frame on stream 0 received. Die!";
errorCallback(ReqResult::BadResponse);
}
else
{
// We shouldn't have any requests in flight if the server is
// sending us a go away frame to gracefully shutdown
// But in case we do, we should treat them as network failures
assert(streams.empty());
errorCallback(ReqResult::NetworkFailure);
}
connPtr->shutdown();
}
else
{
// TODO: Remove this once we support all frame types
// in that case it'll be a parsing error or bad server
@ -866,3 +936,229 @@ Http2Transport::Http2Transport(trantor::TcpConnectionPtr connPtr,
{
connPtr->send(h2_preamble.data(), h2_preamble.length());
}
void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
int32_t streamId,
uint8_t flags)
{
auto it = streams.find(streamId);
if (it == streams.end())
{
LOG_ERROR << "Non-existent stream id: " << streamId;
connPtr->send(serializeFrame(
goAway(streamId,
"Non-existent stream id " + std::to_string(streamId),
StreamCloseErrorCode::ProtocolError),
0));
return;
}
auto &stream = it->second;
if (std::holds_alternative<HeadersFrame>(frame))
{
if (stream.state != StreamState::ExpectingHeaders)
{
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::ProtocolError,
"Unexpected headers frame");
return;
}
auto &f = std::get<HeadersFrame>(frame);
LOG_TRACE << "Headers frame received: size="
<< f.headerBlockFragment.size();
hpack::HPacker::KeyValueVector headers;
int n = hpackRx.decode(f.headerBlockFragment.data(),
f.headerBlockFragment.size(),
headers);
if (n < 0)
{
LOG_ERROR << "Failed to decode headers";
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::CompressionError,
"Failed to decode headers");
return;
}
for (auto &[key, value] : headers)
LOG_TRACE << " " << key << ": " << value;
it->second.response = std::make_shared<HttpResponseImpl>();
for (const auto &[key, value] : headers)
{
// TODO: Filter more pseudo headers
if (key == "content-length")
{
auto sz = stosz(value);
if (!sz)
{
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::ProtocolError,
"Invalid content-length header");
return;
}
it->second.contentLength = std::move(sz);
}
if (key == ":status")
{
// TODO: Validate status code
it->second.response->setStatusCode(
(drogon::HttpStatusCode)std::stoi(value));
continue;
}
// Anti request smuggling. We look for \r or \n in the header
// name or value. If we find one, we abort the stream.
if (key.find_first_of("\r\n") != std::string::npos ||
value.find_first_of("\r\n") != std::string::npos)
{
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::ProtocolError,
"CR or LF found in header name or value");
return;
}
it->second.response->addHeader(key, value);
}
if ((flags & (uint8_t)H2HeadersFlags::EndHeaders) == 0)
{
LOG_ERROR << "We don't support CONTINUATION frames yet!";
stream.state = StreamState::ExpectingContinuation;
abort();
}
// There is no body in the response.
if ((flags & (uint8_t)H2HeadersFlags::EndStream))
{
stream.state = StreamState::Finished;
streamFinished(stream);
return;
}
stream.state = StreamState::ExpectingData;
}
else if (std::holds_alternative<DataFrame>(frame))
{
auto &f = std::get<DataFrame>(frame);
// TODO: Make sure this logic fits RFC
if (f.data.size() > avaliableWindow)
{
LOG_TRACE << "Data frame received: size=" << f.data.size()
<< " but avaliableWindow=" << avaliableWindow;
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::FlowControlError,
"Too much data");
}
avaliableWindowSize -= f.data.size();
if (stream.state != StreamState::ExpectingData)
{
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::ProtocolError,
"Unexpected data frame");
return;
}
LOG_TRACE << "Data frame received: size=" << f.data.size();
stream.body.append((char *)f.data.data(), f.data.size());
if ((flags & (uint8_t)H2DataFlags::EndStream) != 0)
{
if (stream.contentLength &&
stream.body.readableBytes() != *stream.contentLength)
{
LOG_ERROR << "Content-length mismatch";
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::ProtocolError,
"Content-length mismatch");
return;
}
// TODO: Optmize setting body
std::string body(stream.body.peek(), stream.body.readableBytes());
stream.response->setBody(std::move(body));
streamFinished(stream);
return;
}
}
else if (std::holds_alternative<GoAwayFrame>(frame))
{
auto &f = std::get<GoAwayFrame>(frame);
LOG_TRACE << "Go away frame received: lastStreamId=" << f.lastStreamId
<< " errorCode=" << f.errorCode << " additionalDebugData="
<< std::string(f.additionalDebugData.begin(),
f.additionalDebugData.end());
stream.callback(ReqResult::BadResponse, nullptr);
}
else
{
LOG_ERROR << "Unsupported frame type for stream: " << streamId;
}
}
internal::H2Stream &Http2Transport::createStream(int32_t streamId)
{
auto it = streams.find(streamId);
if (it != streams.end())
{
LOG_FATAL << "Stream id already in use! This should not happen";
abort();
}
auto &stream = streams[streamId];
stream.streamId = streamId;
return stream;
}
void Http2Transport::streamFinished(internal::H2Stream &stream)
{
assert(stream.request != nullptr);
assert(stream.callback);
auto it = streams.find(stream.streamId);
assert(it != streams.end());
respCallback(stream.response, {stream.request, stream.callback}, connPtr);
streams.erase(it);
retireStreamId(stream.streamId);
}
void Http2Transport::streamFinished(int32_t streamId,
ReqResult result,
StreamCloseErrorCode errorCode,
std::string errorMsg)
{
auto it = streams.find(streamId);
assert(it != streams.end());
connPtr->send(serializeFrame(goAway(streamId, errorMsg, errorCode), 0));
it->second.callback(result, nullptr);
streams.erase(it);
retireStreamId(streamId);
if (bufferedRequests.empty())
return;
auto &[req, cb] = bufferedRequests.front();
sendRequestInLoop(req, std::move(cb));
bufferedRequests.pop();
}
void Http2Transport::onError(ReqResult result)
{
connPtr->getLoop()->assertInLoopThread();
for (auto &[streamId, stream] : streams)
stream.callback(result, nullptr);
streams.clear();
while (!bufferedRequests.empty())
{
auto &[req, cb] = bufferedRequests.front();
cb(result, nullptr);
bufferedRequests.pop();
}
if (bufferedRequests.empty())
return;
auto &[req, cb] = bufferedRequests.front();
sendRequestInLoop(req, std::move(cb));
bufferedRequests.pop();
}

View File

@ -5,11 +5,93 @@
// TOOD: Write our own HPACK implementation
#include "hpack/HPacker.h"
#include <variant>
namespace drogon
{
namespace internal
{
struct ByteStream;
struct OByteStream;
struct SettingsFrame
{
bool ack = false;
std::vector<std::pair<uint16_t, uint32_t>> settings;
static std::optional<SettingsFrame> parse(ByteStream &payload,
uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
struct WindowUpdateFrame
{
uint32_t windowSizeIncrement = 0;
static std::optional<WindowUpdateFrame> parse(ByteStream &payload,
uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
struct HeadersFrame
{
uint8_t padLength = 0;
bool exclusive = false;
uint32_t streamDependency = 0;
uint8_t weight = 0;
std::vector<uint8_t> headerBlockFragment;
bool endHeaders = false;
bool endStream = false;
static std::optional<HeadersFrame> parse(ByteStream &payload,
uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
struct GoAwayFrame
{
uint32_t lastStreamId = 0;
uint32_t errorCode = 0;
std::vector<uint8_t> additionalDebugData;
static std::optional<GoAwayFrame> parse(ByteStream &payload, uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
struct DataFrame
{
uint8_t padLength = 0;
std::vector<uint8_t> data;
bool endStream = false;
static std::optional<DataFrame> parse(ByteStream &payload, uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
struct PingFrame
{
std::array<uint8_t, 8> opaqueData;
bool ack = false;
static std::optional<PingFrame> parse(ByteStream &payload, uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
using H2Frame = std::variant<SettingsFrame,
WindowUpdateFrame,
HeadersFrame,
GoAwayFrame,
DataFrame,
PingFrame>;
enum class StreamState
{
ExpectingHeaders,
ExpectingContinuation,
ExpectingData,
Finished,
};
// Virtual stream that holds properties for the HTTP/2 stream
// Defaults to stream 0 global properties
@ -17,11 +99,32 @@ struct H2Stream
{
HttpReqCallback callback;
HttpResponseImplPtr response;
HttpRequestPtr request;
trantor::MsgBuffer body;
std::optional<size_t> contentLength;
int32_t streamId = 0;
StreamState state = StreamState::ExpectingHeaders;
};
} // namespace internal
enum class StreamCloseErrorCode
{
NoError = 0x0,
ProtocolError = 0x1,
InternalError = 0x2,
FlowControlError = 0x3,
SettingsTimeout = 0x4,
StreamClosed = 0x5,
FrameSizeError = 0x6,
RefusedStream = 0x7,
Cancel = 0x8,
CompressionError = 0x9,
ConnectError = 0xa,
EnhanceYourCalm = 0xb,
InadequateSecurity = 0xc,
Http11Required = 0xd,
};
class Http2Transport : public HttpTransport
{
private:
@ -42,9 +145,21 @@ class Http2Transport : public HttpTransport
size_t maxFrameSize = 16384;
size_t avaliableWindowSize = 0;
// Configuration settings
const size_t windowIncreaseThreshold = 32768;
const size_t windowIncreaseSize = 10 * 1024 * 1024; // 10 MiB
// Set after server settings are received
bool serverSettingsReceived = false;
std::vector<std::pair<HttpRequestPtr, HttpReqCallback>> bufferedRequests;
std::queue<std::pair<HttpRequestPtr, HttpReqCallback>> bufferedRequests;
size_t avaliableWindow = 10 * 1024 * 1024; // 10 MiB
internal::H2Stream &createStream(int32_t streamId);
void streamFinished(internal::H2Stream &stream);
void streamFinished(int32_t streamId,
ReqResult result,
StreamCloseErrorCode errorCode,
std::string errorMsg = "");
int32_t nextStreamId()
{
@ -78,6 +193,10 @@ class Http2Transport : public HttpTransport
}
}
void handleFrameForStream(const internal::H2Frame &frame,
int32_t streamId,
uint8_t flags);
public:
Http2Transport(trantor::TcpConnectionPtr connPtr,
size_t *bytesSent,
@ -99,10 +218,7 @@ class Http2Transport : public HttpTransport
"HTTP/2 handleConnectionClose not implemented");
}
void onError(ReqResult result) override
{
throw std::runtime_error("HTTP/2 onError not implemented");
}
void onError(ReqResult result) override;
protected:
void onServerSettingsReceived(){};

View File

@ -109,6 +109,9 @@ const char *HttpResponseImpl::versionString() const
case Version::kHttp11:
result = "HTTP/1.1";
break;
case Version::kHttp2:
result = "HTTP/2";
break;
default:
break;