Compare commits

...

12 Commits

Author SHA1 Message Date
Martin Chang
ab73516ae1 fix msvc build 2023-11-07 10:59:26 +08:00
Martin Chang
ffe1f6b022 fix more clang warning 2023-11-07 10:43:18 +08:00
Martin Chang
0c7e4de89c fix clang compile warnings 2023-11-07 10:41:24 +08:00
Martin Chang
910283265d cleanup parsing API 2023-11-07 10:39:21 +08:00
Martin Chang
fc4b669ad6 Support serialized GoAwau 2023-11-06 22:44:45 +08:00
Martin Chang
37d44b44bb workaround clang compile error 2023-11-06 22:32:10 +08:00
Martin Chang
21561979c8 fix not decompressing HTTP response 2023-11-06 22:26:05 +08:00
Martin Chang
51b68dc711 initial populate response data and header 2023-11-06 22:16:18 +08:00
Martin Chang
1e62aae76e format 2023-11-06 21:54:11 +08:00
Martin Chang
de139b1448 single resp callback called 2023-11-06 21:13:52 +08:00
Martin Chang
8fed7603cb parse data frames 2023-11-06 19:36:43 +08:00
Martin Chang
1165f35c00 slight cleanup, parse h2 header 2023-11-06 17:10:09 +08:00
6 changed files with 516 additions and 191 deletions

View File

@ -39,6 +39,18 @@ struct ByteStream
return res;
}
std::pair<bool, int32_t> readBI32BE()
{
assert(offset <= length - 4);
int32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
ptr[offset + 2] << 8 | ptr[offset + 3];
offset += 4;
constexpr int32_t mask = 0x7fffffff;
bool flag = res & (~mask);
res &= mask;
return {flag, res};
}
uint16_t readU16BE()
{
assert(offset <= length - 2);
@ -53,9 +65,29 @@ struct ByteStream
return ptr[offset++];
}
void read(uint8_t *buffer, size_t size)
{
assert(offset <= length - size || size == 0);
memcpy(buffer, ptr + offset, size);
offset += size;
}
void read(std::vector<uint8_t> &buffer, size_t size)
{
buffer.resize(buffer.size() + size);
read(buffer.data(), size);
}
std::vector<uint8_t> read(size_t size)
{
std::vector<uint8_t> buffer;
read(buffer, size);
return buffer;
}
void skip(size_t n)
{
assert(offset <= length - n);
assert(offset <= length - n || n == 0);
offset += n;
}
@ -75,6 +107,46 @@ struct ByteStream
size_t offset = 0;
};
// DITTO but for serialization
struct OByteStream
{
void writeU24BE(uint32_t value)
{
assert(value <= 0xffffff);
value = htonl(value);
buffer.append((char *)&value + 1, 3);
}
void writeU32BE(uint32_t value)
{
value = htonl(value);
buffer.append((char *)&value, 4);
}
void writeU16BE(uint16_t value)
{
value = htons(value);
buffer.append((char *)&value, 2);
}
void writeU8(uint8_t value)
{
buffer.append((char *)&value, 1);
}
void write(const uint8_t *ptr, size_t size)
{
buffer.append((char *)ptr, size);
}
uint8_t *peek()
{
return (uint8_t *)buffer.peek();
}
trantor::MsgBuffer buffer;
};
enum class H2FrameType
{
Data = 0x0,
@ -106,14 +178,98 @@ enum class H2SettingsKey
NumEntries
};
enum class H2HeadersFlags
{
EndStream = 0x1,
EndHeaders = 0x4,
Padded = 0x8,
Priority = 0x20
};
enum class H2DataFlags
{
EndStream = 0x1,
Padded = 0x8
};
struct SettingsFrame
{
bool ack = false;
std::vector<std::pair<uint16_t, uint32_t>> settings;
static std::optional<SettingsFrame> parse(ByteStream &payload,
uint8_t flags)
{
if (payload.size() % 6 != 0)
{
LOG_ERROR << "Invalid settings frame length";
return std::nullopt;
}
SettingsFrame frame;
if ((flags & 0x1) != 0)
{
frame.ack = true;
if (payload.size() != 0)
{
LOG_ERROR << "Settings frame with ACK flag set should have "
"empty payload";
return std::nullopt;
}
return frame;
}
for (size_t i = 0; i < payload.size(); i += 6)
{
uint16_t key = payload.readU16BE();
uint32_t value = payload.readU32BE();
frame.settings.emplace_back(key, value);
}
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
{
flags = (ack ? 0x1 : 0x0);
for (auto &[key, value] : settings)
{
stream.writeU16BE(key);
stream.writeU32BE(value);
}
return true;
}
};
struct WindowUpdateFrame
{
uint32_t windowSizeIncrement;
uint32_t windowSizeIncrement = 0;
static std::optional<WindowUpdateFrame> parse(ByteStream &payload,
uint8_t flags)
{
if (payload.size() != 4)
{
LOG_ERROR << "Invalid window update frame length";
return std::nullopt;
}
WindowUpdateFrame frame;
// MSB is reserved for future use
auto [_, windowSizeIncrement] = payload.readBI32BE();
frame.windowSizeIncrement = windowSizeIncrement;
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
if (windowSizeIncrement & (1U << 31))
{
LOG_ERROR << "MSB of windowSizeIncrement should be 0";
return false;
}
stream.writeU32BE(windowSizeIncrement);
return true;
}
};
struct HeadersFrame
@ -123,17 +279,162 @@ struct HeadersFrame
uint32_t streamDependency = 0;
uint8_t weight = 0;
std::vector<uint8_t> headerBlockFragment;
bool endHeaders = false;
bool endStream = false;
static std::optional<HeadersFrame> parse(ByteStream &payload, uint8_t flags)
{
bool endStream = flags & (uint8_t)H2HeadersFlags::EndStream;
bool endHeaders = flags & (uint8_t)H2HeadersFlags::EndHeaders;
bool padded = flags & (uint8_t)H2HeadersFlags::Padded;
bool priority = flags & (uint8_t)H2HeadersFlags::Priority;
HeadersFrame frame;
if (padded)
{
frame.padLength = payload.readU8();
}
if (priority)
{
auto [exclusive, streamDependency] = payload.readBI32BE();
frame.exclusive = exclusive;
frame.streamDependency = streamDependency;
frame.weight = payload.readU8();
}
if (endHeaders)
{
frame.endHeaders = true;
}
if (endStream)
{
frame.endStream = true;
}
int64_t payloadSize = payload.remaining() - frame.padLength;
if (payloadSize < 0)
{
LOG_ERROR << "headers padding is larger than the payload size";
return std::nullopt;
}
frame.headerBlockFragment.resize(payloadSize);
payload.read(frame.headerBlockFragment.data(),
frame.headerBlockFragment.size());
payload.skip(frame.padLength);
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
if (padLength > 0)
{
flags |= (uint8_t)H2HeadersFlags::Padded;
stream.writeU8(padLength);
}
if (exclusive)
{
flags |= (uint8_t)H2HeadersFlags::Priority;
uint32_t streamDependency = this->streamDependency;
if (exclusive)
streamDependency |= 1U << 31;
stream.writeU32BE(streamDependency);
stream.writeU8(weight);
}
if (endHeaders)
flags |= (uint8_t)H2HeadersFlags::EndHeaders;
if (endStream)
flags |= (uint8_t)H2HeadersFlags::EndStream;
stream.write(headerBlockFragment.data(), headerBlockFragment.size());
return true;
}
};
struct GoAwayFrame
{
uint32_t lastStreamId;
uint32_t errorCode;
uint32_t lastStreamId = 0;
uint32_t errorCode = 0;
std::vector<uint8_t> additionalDebugData;
static std::optional<GoAwayFrame> parse(ByteStream &payload, uint8_t flags)
{
if (payload.size() < 8)
{
LOG_ERROR << "Invalid go away frame length";
return std::nullopt;
}
GoAwayFrame frame;
frame.lastStreamId = payload.readU32BE();
frame.errorCode = payload.readU32BE();
frame.additionalDebugData.resize(payload.remaining());
for (size_t i = 0; i < frame.additionalDebugData.size(); ++i)
frame.additionalDebugData[i] = payload.readU8();
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
stream.writeU32BE(lastStreamId);
stream.writeU32BE(errorCode);
stream.write(additionalDebugData.data(), additionalDebugData.size());
return true;
}
};
using H2Frame =
std::variant<SettingsFrame, WindowUpdateFrame, HeadersFrame, GoAwayFrame>;
struct DataFrame
{
uint8_t padLength = 0;
std::vector<uint8_t> data;
bool endStream = false;
static std::optional<DataFrame> parse(ByteStream &payload, uint8_t flags)
{
bool endStream = flags & (uint8_t)H2DataFlags::EndStream;
bool padded = flags & (uint8_t)H2DataFlags::Padded;
DataFrame frame;
if (padded)
{
frame.padLength = payload.readU8();
}
if (endStream)
{
frame.endStream = true;
}
int32_t payloadSize = payload.remaining() - frame.padLength;
if (payloadSize < 0)
{
LOG_ERROR << "data padding is larger than the payload size";
return std::nullopt;
}
frame.data.resize(payloadSize);
payload.read(frame.data.data(), frame.data.size());
payload.skip(frame.padLength);
return frame;
}
bool serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
stream.write(data.data(), data.size());
if (padLength > 0)
{
flags |= (uint8_t)H2DataFlags::Padded;
for (size_t i = 0; i < padLength; ++i)
stream.writeU8(0x0);
}
return true;
}
};
using H2Frame = std::variant<SettingsFrame,
WindowUpdateFrame,
HeadersFrame,
GoAwayFrame,
DataFrame>;
// Print the HEX and ASCII representation of the buffer side by side
// 16 bytes per line. Same function as the xdd command in linux.
@ -173,165 +474,52 @@ static void dump_hex_beautiful(const void *ptr, size_t size)
}
}
static std::optional<SettingsFrame> parseSettingsFrame(ByteStream &payload)
{
if (payload.size() % 6 != 0)
{
LOG_ERROR << "Invalid settings frame length";
return std::nullopt;
}
SettingsFrame frame;
LOG_TRACE << "Settings frame:";
for (size_t i = 0; i < payload.size(); i += 6)
{
uint16_t key = payload.readU16BE();
uint32_t value = payload.readU32BE();
frame.settings.emplace_back(key, value);
LOG_TRACE << " key=" << key << " value=" << value;
}
return frame;
}
static std::optional<WindowUpdateFrame> parseWindowUpdateFrame(
ByteStream &payload)
{
if (payload.size() != 4)
{
LOG_ERROR << "Invalid window update frame length";
return std::nullopt;
}
WindowUpdateFrame frame;
// MSB is reserved for future use
frame.windowSizeIncrement = payload.readU32BE() & 0x7fffffff;
LOG_TRACE << "Window update frame: windowSizeIncrement="
<< frame.windowSizeIncrement;
return frame;
}
static std::optional<GoAwayFrame> parseGoAwayFrame(ByteStream &payload)
{
if (payload.size() < 8)
{
LOG_ERROR << "Invalid go away frame length";
return std::nullopt;
}
GoAwayFrame frame;
frame.lastStreamId = payload.readU32BE();
frame.errorCode = payload.readU32BE();
frame.additionalDebugData.resize(payload.remaining());
for (size_t i = 0; i < frame.additionalDebugData.size(); ++i)
frame.additionalDebugData[i] = payload.readU8();
LOG_TRACE << "Go away frame: lastStreamId=" << frame.lastStreamId
<< " errorCode=" << frame.errorCode
<< " additionalDebugData=" << frame.additionalDebugData.size();
return frame;
}
static std::optional<HeadersFrame> parseHeadersFrame(ByteStream &payload)
{
HeadersFrame frame;
frame.padLength = payload.readU8();
uint32_t streamDependency = payload.readU32BE();
frame.exclusive = streamDependency & (1U << 31);
frame.streamDependency = streamDependency & ((1U << 31) - 1);
frame.weight = payload.readU8();
frame.headerBlockFragment.resize(payload.remaining());
for (size_t i = 0; i < frame.headerBlockFragment.size(); ++i)
frame.headerBlockFragment[i] = payload.readU8();
// TODO: Handle padding
payload.skip(frame.padLength);
return frame;
}
static std::vector<uint8_t> serializeHeadsFrame(const HeadersFrame &frame)
{
std::vector<uint8_t> buffer;
buffer.reserve(6 + frame.headerBlockFragment.size() + frame.padLength);
// buffer.push_back(frame.padLength);
// uint32_t streamDependency = frame.streamDependency;
// if (frame.exclusive)
// streamDependency |= 1U << 31;
// buffer.push_back(streamDependency >> 24);
// buffer.push_back(streamDependency >> 16);
// buffer.push_back(streamDependency >> 8);
// buffer.push_back(streamDependency);
// buffer.push_back(frame.weight);
for (size_t i = 0; i < frame.headerBlockFragment.size(); ++i)
buffer.push_back(frame.headerBlockFragment[i]);
for (size_t i = 0; i < frame.padLength; ++i)
buffer.push_back(0x0);
return buffer;
}
static std::vector<uint8_t> serializeSettingsFrame(const SettingsFrame &frame)
{
if (frame.settings.size() == 0)
return std::vector<uint8_t>();
std::vector<uint8_t> buffer;
buffer.reserve(6 * frame.settings.size());
for (auto &[key, value] : frame.settings)
{
buffer.push_back(key >> 8);
buffer.push_back(key);
buffer.push_back(value >> 24);
buffer.push_back(value >> 16);
buffer.push_back(value >> 8);
buffer.push_back(value);
}
return buffer;
}
static std::vector<uint8_t> serializeWindowUpdateFrame(
const WindowUpdateFrame &frame)
{
std::vector<uint8_t> buffer;
buffer.reserve(4);
buffer.push_back(frame.windowSizeIncrement >> 24);
buffer.push_back(frame.windowSizeIncrement >> 16);
buffer.push_back(frame.windowSizeIncrement >> 8);
buffer.push_back(frame.windowSizeIncrement);
return buffer;
}
static std::vector<uint8_t> serializeFrame(const H2Frame &frame,
size_t streamId,
uint8_t flags = 0)
size_t streamId)
{
std::vector<uint8_t> buffer;
OByteStream buffer;
uint8_t type;
uint8_t flags;
bool ok = false;
if (std::holds_alternative<HeadersFrame>(frame))
{
const auto &f = std::get<HeadersFrame>(frame);
buffer = serializeHeadsFrame(f);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::Headers;
}
else if (std::holds_alternative<SettingsFrame>(frame))
{
const auto &f = std::get<SettingsFrame>(frame);
buffer = serializeSettingsFrame(f);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::Settings;
}
else if (std::holds_alternative<WindowUpdateFrame>(frame))
{
const auto &f = std::get<WindowUpdateFrame>(frame);
buffer = serializeWindowUpdateFrame(f);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::WindowUpdate;
}
else if (std::holds_alternative<GoAwayFrame>(frame))
{
const auto &f = std::get<GoAwayFrame>(frame);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::GoAway;
}
else
{
LOG_ERROR << "Unsupported frame type";
abort();
}
if (!ok)
{
LOG_ERROR << "Failed to serialize frame";
abort();
}
std::vector<uint8_t> full_frame;
full_frame.reserve(9 + buffer.size());
size_t length = buffer.size();
full_frame.reserve(9 + buffer.buffer.readableBytes());
size_t length = buffer.buffer.readableBytes();
assert(length <= 0xffffff);
full_frame.push_back(length >> 16);
full_frame.push_back(length >> 8);
@ -342,7 +530,9 @@ static std::vector<uint8_t> serializeFrame(const H2Frame &frame,
full_frame.push_back(streamId >> 16);
full_frame.push_back(streamId >> 8);
full_frame.push_back(streamId);
full_frame.insert(full_frame.end(), buffer.begin(), buffer.end());
full_frame.insert(full_frame.end(),
buffer.buffer.peek(),
buffer.buffer.peek() + length);
return full_frame;
}
@ -351,13 +541,13 @@ static std::vector<uint8_t> serializeFrame(const H2Frame &frame,
// We need to handle both cases. Also it could happen that the TCP stream
// just cuts off in the middle of a frame (or header). We need to handle that
// too.
static std::tuple<std::optional<H2Frame>, size_t, bool> parseH2Frame(
static std::tuple<std::optional<H2Frame>, size_t, uint8_t, bool> parseH2Frame(
trantor::MsgBuffer *msg)
{
if (msg->readableBytes() < 9)
{
LOG_TRACE << "Not enough bytes to parse H2 frame header";
return {std::nullopt, 0, false};
return {std::nullopt, 0, 0, false};
}
uint8_t *ptr = (uint8_t *)msg->peek();
@ -368,7 +558,7 @@ static std::tuple<std::optional<H2Frame>, size_t, bool> parseH2Frame(
if (msg->readableBytes() < length + 9)
{
LOG_TRACE << "Not enough bytes to parse H2 frame";
return {std::nullopt, 0, false};
return {std::nullopt, 0, 0, false};
}
const uint8_t type = header.readU8();
@ -380,7 +570,7 @@ static std::tuple<std::optional<H2Frame>, size_t, bool> parseH2Frame(
{
// TODO: Handle fatal protocol error
LOG_ERROR << "Invalid H2 frame type: " << (int)type;
return {std::nullopt, streamId, true};
return {std::nullopt, streamId, 0, true};
}
LOG_TRACE << "H2 frame: length=" << length << " type=" << (int)type
@ -389,18 +579,20 @@ static std::tuple<std::optional<H2Frame>, size_t, bool> parseH2Frame(
ByteStream payload(ptr + 9, length);
std::optional<H2Frame> frame;
if (type == (uint8_t)H2FrameType::Settings)
frame = parseSettingsFrame(payload);
frame = SettingsFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::WindowUpdate)
frame = parseWindowUpdateFrame(payload);
frame = WindowUpdateFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::GoAway)
frame = parseGoAwayFrame(payload);
frame = GoAwayFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::Headers)
frame = parseHeadersFrame(payload);
frame = HeadersFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::Data)
frame = DataFrame::parse(payload, flags);
else
{
LOG_WARN << "Unsupported H2 frame type: " << (int)type;
msg->retrieve(length + 9);
return {std::nullopt, streamId, false};
return {std::nullopt, streamId, 0, false};
}
if (payload.remaining() != 0)
@ -409,15 +601,15 @@ static std::tuple<std::optional<H2Frame>, size_t, bool> parseH2Frame(
msg->retrieve(length + 9);
if (!frame)
{
LOG_ERROR << "Failed to parse H2 frame";
return {std::nullopt, streamId, true};
LOG_ERROR << "Failed to parse H2 frame of type: " << (int)type;
return {std::nullopt, streamId, 0, true};
}
return {frame, streamId, false};
return {frame, streamId, flags, false};
}
void drogon::Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
HttpReqCallback &&callback)
void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
HttpReqCallback &&callback)
{
if (!serverSettingsReceived)
{
@ -425,23 +617,15 @@ void drogon::Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
return;
}
// HACK: Acknowledge the settings frame, move this somewhere appropriate
LOG_TRACE << "Acknowledge settings frame";
SettingsFrame settings;
auto sb = serializeFrame(settings, 0, 0);
dump_hex_beautiful(sb.data(), sb.size());
connPtr->send(sb.data(), sb.size());
sb = serializeFrame(settings, 0, 0x1);
dump_hex_beautiful(sb.data(), sb.size());
connPtr->send(sb.data(), sb.size());
const int32_t streamId = nextStreamId();
LOG_TRACE << "Sending HTTP/2 request: streamId=" << streamId;
if (streams.find(streamId) != streams.end())
{
LOG_FATAL << "Stream id already in use! This should not happen";
abort();
return;
}
WindowUpdateFrame windowUpdate;
windowUpdate.windowSizeIncrement = 200 * 1024 * 1024; // 200MB
auto wu = serializeFrame(windowUpdate, 0);
dump_hex_beautiful(wu.data(), wu.size());
connPtr->send(wu.data(), wu.size());
static hpack::HPacker hpack;
auto headers = req->headers();
HeadersFrame frame;
frame.padLength = 0;
@ -471,24 +655,25 @@ void drogon::Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
LOG_TRACE << "Final headers size: " << headersToEncode.size();
for (auto &[key, value] : headersToEncode)
LOG_TRACE << " " << key << ": " << value;
int n = hpack.encode(headersToEncode,
frame.headerBlockFragment.data(),
frame.headerBlockFragment.size());
int n = hpackTx.encode(headersToEncode,
frame.headerBlockFragment.data(),
frame.headerBlockFragment.size());
if (n < 0)
{
LOG_ERROR << "Failed to encode headers";
abort();
return;
}
LOG_TRACE << "Encoded headers size: " << n;
frame.headerBlockFragment.resize(n);
LOG_TRACE << "Encoded headers:";
dump_hex_beautiful(frame.headerBlockFragment.data(),
frame.headerBlockFragment.size());
frame.endHeaders = true;
frame.endStream = true;
LOG_TRACE << "Sending headers frame";
auto f = serializeFrame(frame, 1, 0x5);
auto f = serializeFrame(frame, streamId);
dump_hex_beautiful(f.data(), f.size());
connPtr->send(f.data(), f.size());
streams[streamId] = internal::H2Stream();
streams[streamId].callback = std::move(callback);
}
void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
@ -500,7 +685,10 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
dump_hex_beautiful(msg->peek(), msg->readableBytes());
while (true)
{
auto [frameOpt, streamId, error] = parseH2Frame(msg);
// FIXME: The code cannot distinguish between a out-of-data and
// unsupported frame type. We need to fix this as it should be handled
// differently.
auto [frameOpt, streamId, flags, error] = parseH2Frame(msg);
if (error && streamId == 0)
{
@ -550,11 +738,109 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
if (streamId == 0 && !serverSettingsReceived)
{
LOG_TRACE << "Server settings received. Sending our own "
"settings and WindowUpdate";
SettingsFrame settingsFrame;
settingsFrame.settings.emplace_back(
(uint16_t)H2SettingsKey::EnablePush, 0); // Disable push
auto b = serializeFrame(settingsFrame, 0);
connPtr->send((const char *)b.data(), b.size());
WindowUpdateFrame windowUpdateFrame;
// TODO: Keep track and update the window size
windowUpdateFrame.windowSizeIncrement = 200 * 1024 * 1024;
auto b2 = serializeFrame(windowUpdateFrame, 0);
connPtr->send((const char *)b2.data(), b2.size());
serverSettingsReceived = true;
for (auto &[req, cb] : bufferedRequests)
sendRequestInLoop(req, std::move(cb));
bufferedRequests.clear();
}
// Somehow nghttp2 wants us to send ACK after sending our
// preferences??
if (flags == 1)
continue;
LOG_TRACE << "Acknowledge settings frame";
SettingsFrame ackFrame;
ackFrame.ack = true;
auto b = serializeFrame(ackFrame, 0);
connPtr->send((const char *)b.data(), b.size());
}
else if (std::holds_alternative<HeadersFrame>(frame))
{
auto &f = std::get<HeadersFrame>(frame);
LOG_TRACE << "Headers frame received: size="
<< f.headerBlockFragment.size();
hpack::HPacker::KeyValueVector headers;
int n = hpackRx.decode(f.headerBlockFragment.data(),
f.headerBlockFragment.size(),
headers);
if (n < 0)
{
LOG_ERROR << "Failed to decode headers";
abort();
return;
}
for (auto &[key, value] : headers)
LOG_TRACE << " " << key << ": " << value;
auto it = streams.find(streamId);
if (it == streams.end())
{
LOG_ERROR << "Headers frame received for unknown stream id: "
<< streamId;
// TODO: Send GoAway frame
return;
}
it->second.response = std::make_shared<HttpResponseImpl>();
for (const auto &[key, value] : headers)
{
// TODO: Filter more pseudo headers
if (key == ":status")
continue;
// TODO: Validate content-length is either not present or
// the same as the body size sent by DATA frames
if (key == "content-length")
continue;
if (key == ":status")
{
// TODO: Validate status code
it->second.response->setStatusCode(
(drogon::HttpStatusCode)std::stoi(value));
continue;
}
it->second.response->addHeader(key, value);
}
}
else if (std::holds_alternative<DataFrame>(frame))
{
auto &f = std::get<DataFrame>(frame);
LOG_TRACE << "Data frame received: size=" << f.data.size();
auto it = streams.find(streamId);
if (it == streams.end())
{
LOG_ERROR << "Data frame received for unknown stream id: "
<< streamId;
return;
}
it->second.body.append((char *)f.data.data(), f.data.size());
if ((flags & (uint8_t)H2DataFlags::EndStream) != 0)
{
// TODO: Optmize setting body
std::string body(it->second.body.peek(),
it->second.body.readableBytes());
auto headers = it->second.response->headers();
it->second.response->setBody(std::move(body));
// FIXME: Store the actuall request object
auto req = HttpRequest::newHttpRequest();
respCallback(it->second.response,
{req, it->second.callback},
connPtr);
}
}
else if (std::holds_alternative<GoAwayFrame>(frame))
{

View File

@ -1,6 +1,7 @@
#pragma once
#include "HttpTransport.h"
#include "HttpResponseImpl.h"
// TOOD: Write our own HPACK implementation
#include "hpack/HPacker.h"
@ -14,10 +15,9 @@ namespace internal
// Defaults to stream 0 global properties
struct H2Stream
{
size_t maxConcurrentStreams = 100;
size_t initialWindowSize = 65535;
size_t maxFrameSize = 16384;
size_t avaliableWindowSize = 0;
HttpReqCallback callback;
HttpResponseImplPtr response;
trantor::MsgBuffer body;
};
} // namespace internal
@ -28,6 +28,13 @@ class Http2Transport : public HttpTransport
trantor::TcpConnectionPtr connPtr;
size_t *bytesSent_;
size_t *bytesReceived_;
hpack::HPacker hpackTx;
hpack::HPacker hpackRx;
std::priority_queue<int32_t> usibleStreamIds;
int32_t streamIdTop = 1;
std::unordered_map<int32_t, internal::H2Stream> streams;
// TODO: Handle server-initiated stream creation
// HTTP/2 client-wide settings
size_t maxConcurrentStreams = 100;
@ -39,6 +46,38 @@ class Http2Transport : public HttpTransport
bool serverSettingsReceived = false;
std::vector<std::pair<HttpRequestPtr, HttpReqCallback>> bufferedRequests;
int32_t nextStreamId()
{
if (usibleStreamIds.empty())
{
int32_t id = streamIdTop;
streamIdTop += 2;
return id;
}
int32_t id = usibleStreamIds.top();
usibleStreamIds.pop();
return id;
}
void retireStreamId(int32_t id)
{
if (id == streamIdTop - 2)
{
streamIdTop -= 2;
while (!usibleStreamIds.empty() &&
usibleStreamIds.top() == streamIdTop - 2)
{
usibleStreamIds.pop();
streamIdTop -= 2;
assert(streamIdTop >= 1);
}
}
else
{
usibleStreamIds.push(id);
}
}
public:
Http2Transport(trantor::TcpConnectionPtr connPtr,
size_t *bytesSent,
@ -54,7 +93,7 @@ class Http2Transport : public HttpTransport
return 0;
}
bool handleConnectionClose()
bool handleConnectionClose() override
{
throw std::runtime_error(
"HTTP/2 handleConnectionClose not implemented");

View File

@ -59,10 +59,10 @@ class HttpAppFrameworkImpl final : public HttpAppFramework
PluginBase *getPlugin(const std::string &name) override;
std::shared_ptr<PluginBase> getSharedPlugin(
const std::string &name) override;
void addPlugins(const Json::Value &configs);
void addPlugins(const Json::Value &configs) override;
void addPlugin(const std::string &name,
const std::vector<std::string> &dependencies,
const Json::Value &config);
const Json::Value &config) override;
HttpAppFramework &addListener(
const std::string &ip,
uint16_t port,

View File

@ -55,7 +55,7 @@ class Http1xTransport : public HttpTransport
return pipeliningCallbacks_.size();
}
bool handleConnectionClose();
bool handleConnectionClose() override;
void onError(ReqResult result) override
{

View File

@ -149,7 +149,7 @@ class HttpRequestImpl : public HttpRequest
return routingParams_;
}
void setRoutingParameters(std::vector<std::string> &&params)
void setRoutingParameters(std::vector<std::string> &&params) override
{
routingParams_ = std::move(params);
}

View File

@ -383,7 +383,7 @@ class DROGON_EXPORT HttpResponseImpl : public HttpResponse
utils::gzipDecompress(bodyPtr_->data(), bodyPtr_->length());
removeHeaderBy("content-encoding");
bodyPtr_ =
std::make_shared<HttpMessageStringBody>(move(gunzipBody));
std::make_shared<HttpMessageStringBody>(std::move(gunzipBody));
addHeader("content-length", std::to_string(bodyPtr_->length()));
}
}
@ -396,7 +396,7 @@ class DROGON_EXPORT HttpResponseImpl : public HttpResponse
utils::brotliDecompress(bodyPtr_->data(), bodyPtr_->length());
removeHeaderBy("content-encoding");
bodyPtr_ =
std::make_shared<HttpMessageStringBody>(move(gunzipBody));
std::make_shared<HttpMessageStringBody>(std::move(gunzipBody));
addHeader("content-length", std::to_string(bodyPtr_->length()));
}
}