mirror of
https://github.com/drogonframework/drogon.git
synced 2025-07-23 00:01:23 -04:00
Compare commits
16 Commits
d50577b709
...
52d0fdd25d
Author | SHA1 | Date | |
---|---|---|---|
|
52d0fdd25d | ||
|
0eb2cdabe7 | ||
|
e12260a0b7 | ||
|
c82e8208ab | ||
|
e13a8b930f | ||
|
4a2eecf03d | ||
|
45a2b1d0d3 | ||
|
d69699ceb2 | ||
|
5f52a80358 | ||
|
ab566b3524 | ||
|
e565f38d7a | ||
|
5d434577ff | ||
|
2eef512537 | ||
|
847b580cf5 | ||
|
83606eb5a6 | ||
|
f5c4863ad0 |
@ -16,7 +16,7 @@ int main()
|
||||
{
|
||||
trantor::Logger::setLogLevel(trantor::Logger::kTrace);
|
||||
{
|
||||
auto client = HttpClient::newHttpClient("https://clehaxze.tw",
|
||||
auto client = HttpClient::newHttpClient("https://clehaxze.tw:8844",
|
||||
nullptr,
|
||||
false,
|
||||
false);
|
||||
|
@ -18,7 +18,11 @@ static std::optional<size_t> stosz(const std::string &str)
|
||||
{
|
||||
try
|
||||
{
|
||||
return std::stoull(str);
|
||||
size_t idx = 0;
|
||||
size_t res = std::stoull(str, &idx);
|
||||
if (idx != str.size())
|
||||
return std::nullopt;
|
||||
return res;
|
||||
}
|
||||
catch (const std::exception &)
|
||||
{
|
||||
@ -107,7 +111,7 @@ struct ByteStream
|
||||
|
||||
uint32_t readU24BE()
|
||||
{
|
||||
assert(offset <= length - 3);
|
||||
assert(length >= 3 && offset <= length - 3);
|
||||
uint32_t res =
|
||||
ptr[offset] << 16 | ptr[offset + 1] << 8 | ptr[offset + 2];
|
||||
offset += 3;
|
||||
@ -116,7 +120,7 @@ struct ByteStream
|
||||
|
||||
uint32_t readU32BE()
|
||||
{
|
||||
assert(offset <= length - 4);
|
||||
assert(length >= 4 && offset <= length - 4);
|
||||
uint32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
|
||||
ptr[offset + 2] << 8 | ptr[offset + 3];
|
||||
offset += 4;
|
||||
@ -125,7 +129,7 @@ struct ByteStream
|
||||
|
||||
std::pair<bool, int32_t> readBI32BE()
|
||||
{
|
||||
assert(offset <= length - 4);
|
||||
assert(length >= 4 && offset <= length - 4);
|
||||
int32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
|
||||
ptr[offset + 2] << 8 | ptr[offset + 3];
|
||||
offset += 4;
|
||||
@ -137,7 +141,7 @@ struct ByteStream
|
||||
|
||||
uint16_t readU16BE()
|
||||
{
|
||||
assert(offset <= length - 2);
|
||||
assert(length >= 2 && offset <= length - 2);
|
||||
uint16_t res = ptr[offset] << 8 | ptr[offset + 1];
|
||||
offset += 2;
|
||||
return res;
|
||||
@ -145,13 +149,13 @@ struct ByteStream
|
||||
|
||||
uint8_t readU8()
|
||||
{
|
||||
assert(offset <= length - 1);
|
||||
assert(length >= 1 && offset <= length - 1);
|
||||
return ptr[offset++];
|
||||
}
|
||||
|
||||
void read(uint8_t *buffer, size_t size)
|
||||
{
|
||||
assert(offset <= length - size || size == 0);
|
||||
assert((length >= size && offset <= length - size) || size == 0);
|
||||
memcpy(buffer, ptr + offset, size);
|
||||
offset += size;
|
||||
}
|
||||
@ -171,7 +175,7 @@ struct ByteStream
|
||||
|
||||
void skip(size_t n)
|
||||
{
|
||||
assert(offset <= length - n || n == 0);
|
||||
assert((length >= n && offset <= length - n) || n == 0);
|
||||
offset += n;
|
||||
}
|
||||
|
||||
@ -227,6 +231,7 @@ struct OByteStream
|
||||
{
|
||||
assert(value <= 0xffffff);
|
||||
assert(offset <= buffer.readableBytes() - 3);
|
||||
assert(buffer.writableBytes() >= 3);
|
||||
auto ptr = (uint8_t *)buffer.peek() + offset;
|
||||
ptr[0] = value >> 16;
|
||||
ptr[1] = value >> 8;
|
||||
@ -236,6 +241,7 @@ struct OByteStream
|
||||
void overwriteU8(size_t offset, uint8_t value)
|
||||
{
|
||||
assert(offset <= buffer.readableBytes() - 1);
|
||||
assert(buffer.writableBytes() >= 1);
|
||||
auto ptr = (uint8_t *)buffer.peek() + offset;
|
||||
ptr[0] = value;
|
||||
}
|
||||
@ -325,11 +331,25 @@ std::optional<HeadersFrame> HeadersFrame::parse(ByteStream &payload,
|
||||
bool padded = flags & (uint8_t)H2HeadersFlags::Padded;
|
||||
bool priority = flags & (uint8_t)H2HeadersFlags::Priority;
|
||||
|
||||
if (payload.size() == 0)
|
||||
{
|
||||
LOG_TRACE << "Header size cannot be 0";
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
HeadersFrame frame;
|
||||
if (padded)
|
||||
{
|
||||
frame.padLength = payload.readU8();
|
||||
}
|
||||
|
||||
size_t minSize = frame.padLength + (priority ? 5 : 0);
|
||||
if (payload.size() < minSize)
|
||||
{
|
||||
LOG_TRACE << "Invalid headers frame length";
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (priority)
|
||||
{
|
||||
auto [exclusive, streamDependency] = payload.readBI32BE();
|
||||
@ -346,6 +366,7 @@ std::optional<HeadersFrame> HeadersFrame::parse(ByteStream &payload,
|
||||
frame.endStream = true;
|
||||
}
|
||||
|
||||
assert(payload.remaining() >= frame.padLength);
|
||||
int64_t payloadSize = payload.remaining() - frame.padLength;
|
||||
if (payloadSize < 0)
|
||||
{
|
||||
@ -421,11 +442,20 @@ std::optional<DataFrame> DataFrame::parse(ByteStream &payload, uint8_t flags)
|
||||
{
|
||||
frame.padLength = payload.readU8();
|
||||
}
|
||||
|
||||
size_t minSize = frame.padLength;
|
||||
if (payload.size() < minSize)
|
||||
{
|
||||
LOG_TRACE << "Invalid data frame length";
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (endStream)
|
||||
{
|
||||
frame.endStream = true;
|
||||
}
|
||||
|
||||
assert(payload.remaining() >= frame.padLength);
|
||||
size_t payloadSize = payload.remaining() - frame.padLength;
|
||||
if (payloadSize < 0)
|
||||
{
|
||||
@ -447,7 +477,7 @@ std::optional<DataFrame> DataFrame::parse(ByteStream &payload, uint8_t flags)
|
||||
|
||||
bool DataFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||
{
|
||||
flags = 0x0;
|
||||
flags = (endStream ? (uint8_t)H2DataFlags::EndStream : 0x0);
|
||||
stream.write(data.data(), data.size());
|
||||
if (padLength > 0)
|
||||
{
|
||||
@ -477,6 +507,95 @@ bool PingFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||
return true;
|
||||
}
|
||||
|
||||
std::optional<ContinuationFrame> ContinuationFrame::parse(ByteStream &payload,
|
||||
uint8_t flags)
|
||||
{
|
||||
bool endHeaders = flags & (uint8_t)H2HeadersFlags::EndHeaders;
|
||||
ContinuationFrame frame;
|
||||
if (endHeaders)
|
||||
{
|
||||
frame.endHeaders = true;
|
||||
}
|
||||
|
||||
frame.headerBlockFragment.resize(payload.remaining());
|
||||
payload.read(frame.headerBlockFragment.data(),
|
||||
frame.headerBlockFragment.size());
|
||||
return frame;
|
||||
}
|
||||
|
||||
bool ContinuationFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||
{
|
||||
flags = 0x0;
|
||||
if (endHeaders)
|
||||
flags |= (uint8_t)H2HeadersFlags::EndHeaders;
|
||||
stream.write(headerBlockFragment.data(), headerBlockFragment.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
std::optional<PushPromiseFrame> PushPromiseFrame::parse(ByteStream &payload,
|
||||
uint8_t flags)
|
||||
{
|
||||
bool endHeaders = flags & (uint8_t)H2HeadersFlags::EndHeaders;
|
||||
bool padded = flags & (uint8_t)H2HeadersFlags::Padded;
|
||||
|
||||
PushPromiseFrame frame;
|
||||
if (padded)
|
||||
frame.padLength = payload.readU8();
|
||||
if (endHeaders)
|
||||
frame.endHeaders = true;
|
||||
|
||||
size_t minSize = frame.padLength + 4;
|
||||
if (payload.size() < minSize)
|
||||
{
|
||||
LOG_TRACE << "Invalid push promise frame length";
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto [_, promisedStreamId] = payload.readBI32BE();
|
||||
frame.promisedStreamId = promisedStreamId;
|
||||
assert(payload.remaining() >= frame.padLength);
|
||||
frame.headerBlockFragment.resize(payload.remaining() - frame.padLength);
|
||||
payload.read(frame.headerBlockFragment.data(),
|
||||
frame.headerBlockFragment.size());
|
||||
payload.skip(frame.padLength);
|
||||
return frame;
|
||||
}
|
||||
|
||||
bool PushPromiseFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||
{
|
||||
flags = 0x0;
|
||||
if (endHeaders)
|
||||
flags |= (uint8_t)H2HeadersFlags::EndHeaders;
|
||||
assert(promisedStreamId > 0);
|
||||
stream.writeU32BE(promisedStreamId);
|
||||
stream.write(headerBlockFragment.data(), headerBlockFragment.size());
|
||||
if (padLength > 0)
|
||||
{
|
||||
flags |= (uint8_t)H2HeadersFlags::Padded;
|
||||
stream.writeU8(padLength);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::optional<RstStreamFrame> RstStreamFrame::parse(ByteStream &payload,
|
||||
uint8_t flags)
|
||||
{
|
||||
if (payload.size() != 4)
|
||||
{
|
||||
LOG_TRACE << "Invalid RST_STREAM frame length";
|
||||
return std::nullopt;
|
||||
}
|
||||
RstStreamFrame frame;
|
||||
frame.errorCode = payload.readU32BE();
|
||||
return frame;
|
||||
}
|
||||
|
||||
bool RstStreamFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||
{
|
||||
flags = 0x0;
|
||||
stream.writeU32BE(errorCode);
|
||||
return true;
|
||||
}
|
||||
} // namespace drogon::internal
|
||||
|
||||
// Print the HEX and ASCII representation of the buffer side by side
|
||||
@ -568,6 +687,24 @@ static trantor::MsgBuffer serializeFrame(const H2Frame &frame, int32_t streamId)
|
||||
ok = f.serialize(buffer, flags);
|
||||
type = (uint8_t)H2FrameType::Ping;
|
||||
}
|
||||
else if (std::holds_alternative<ContinuationFrame>(frame))
|
||||
{
|
||||
const auto &f = std::get<ContinuationFrame>(frame);
|
||||
ok = f.serialize(buffer, flags);
|
||||
type = (uint8_t)H2FrameType::Continuation;
|
||||
}
|
||||
else if (std::holds_alternative<PushPromiseFrame>(frame))
|
||||
{
|
||||
const auto &f = std::get<PushPromiseFrame>(frame);
|
||||
ok = f.serialize(buffer, flags);
|
||||
type = (uint8_t)H2FrameType::PushPromise;
|
||||
}
|
||||
else if (std::holds_alternative<RstStreamFrame>(frame))
|
||||
{
|
||||
const auto &f = std::get<RstStreamFrame>(frame);
|
||||
ok = f.serialize(buffer, flags);
|
||||
type = (uint8_t)H2FrameType::RstStream;
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_ERROR << "Unsupported frame type";
|
||||
@ -646,6 +783,12 @@ static std::tuple<std::optional<H2Frame>, uint32_t, uint8_t, bool> parseH2Frame(
|
||||
frame = DataFrame::parse(payload, flags);
|
||||
else if (type == (uint8_t)H2FrameType::Ping)
|
||||
frame = PingFrame::parse(payload, flags);
|
||||
else if (type == (uint8_t)H2FrameType::Continuation)
|
||||
frame = ContinuationFrame::parse(payload, flags);
|
||||
else if (type == (uint8_t)H2FrameType::PushPromise)
|
||||
frame = PushPromiseFrame::parse(payload, flags);
|
||||
else if (type == (uint8_t)H2FrameType::RstStream)
|
||||
frame = RstStreamFrame::parse(payload, flags);
|
||||
else
|
||||
{
|
||||
LOG_WARN << "Unsupported H2 frame type: " << (int)type;
|
||||
@ -670,13 +813,7 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
||||
HttpReqCallback &&callback)
|
||||
{
|
||||
connPtr->getLoop()->assertInLoopThread();
|
||||
if (!serverSettingsReceived)
|
||||
{
|
||||
bufferedRequests.push({req, std::move(callback)});
|
||||
return;
|
||||
}
|
||||
|
||||
if (streams.size() >= maxConcurrentStreams)
|
||||
if (streams.size() + 1 >= maxConcurrentStreams)
|
||||
{
|
||||
LOG_TRACE << "Too many streams in flight. Buffering request";
|
||||
bufferedRequests.push({req, std::move(callback)});
|
||||
@ -697,15 +834,11 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
||||
}
|
||||
|
||||
auto headers = req->headers();
|
||||
HeadersFrame frame;
|
||||
frame.padLength = 0;
|
||||
frame.exclusive = false;
|
||||
frame.streamDependency = 0;
|
||||
frame.weight = 0;
|
||||
frame.headerBlockFragment.resize(maxCompressiedHeaderSize);
|
||||
std::vector<uint8_t> encodedHeaders(maxCompressiedHeaderSize);
|
||||
|
||||
LOG_TRACE << "Sending HTTP/2 headers: size=" << headers.size();
|
||||
hpack::HPacker::KeyValueVector headersToEncode;
|
||||
headersToEncode.reserve(headers.size() + 5);
|
||||
const std::array<std::string_view, 2> headersToSkip = {
|
||||
{"host", "connection"}};
|
||||
headersToEncode.emplace_back(":method", req->methodString());
|
||||
@ -726,8 +859,8 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
||||
for (auto &[key, value] : headersToEncode)
|
||||
LOG_TRACE << " " << key << ": " << value;
|
||||
int n = hpackTx.encode(headersToEncode,
|
||||
frame.headerBlockFragment.data(),
|
||||
frame.headerBlockFragment.size());
|
||||
encodedHeaders.data(),
|
||||
encodedHeaders.size());
|
||||
if (n < 0)
|
||||
{
|
||||
LOG_TRACE << "Failed to encode headers. Internal error or header "
|
||||
@ -742,17 +875,54 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
||||
abort();
|
||||
return;
|
||||
}
|
||||
frame.headerBlockFragment.resize(n);
|
||||
frame.endHeaders = true;
|
||||
|
||||
encodedHeaders.resize(n);
|
||||
LOG_TRACE << "Encoded headers size: " << encodedHeaders.size();
|
||||
|
||||
bool haveBody = req->body().length() > 0;
|
||||
auto &stream = createStream(streamId);
|
||||
if (req->body().length() == 0)
|
||||
frame.endStream = true;
|
||||
LOG_TRACE << "Sending headers frame";
|
||||
auto f = serializeFrame(frame, streamId);
|
||||
LOG_TRACE << dump_hex_beautiful(f.peek(), f.readableBytes());
|
||||
connPtr->send(f);
|
||||
bool needsContinuation = encodedHeaders.size() > maxFrameSize;
|
||||
for (size_t i = 0; i < encodedHeaders.size(); i += maxFrameSize)
|
||||
{
|
||||
bool isFirst = i == 0;
|
||||
bool isLast = i + maxFrameSize >= encodedHeaders.size();
|
||||
size_t dataSize = (std::min)(maxFrameSize, encodedHeaders.size() - i);
|
||||
|
||||
auto frame = [&]() -> H2Frame {
|
||||
if (isFirst)
|
||||
{
|
||||
HeadersFrame frame;
|
||||
frame.headerBlockFragment.resize(dataSize);
|
||||
memcpy(frame.headerBlockFragment.data(),
|
||||
encodedHeaders.data() + i,
|
||||
dataSize);
|
||||
frame.endHeaders = isLast;
|
||||
frame.endStream = (!haveBody && isLast);
|
||||
return frame;
|
||||
}
|
||||
ContinuationFrame frame;
|
||||
frame.headerBlockFragment.resize(dataSize);
|
||||
assert(encodedHeaders.size() > i + dataSize);
|
||||
memcpy(frame.headerBlockFragment.data(),
|
||||
encodedHeaders.data() + i,
|
||||
dataSize);
|
||||
frame.endHeaders = (!haveBody && isLast);
|
||||
return frame;
|
||||
}();
|
||||
|
||||
auto f = serializeFrame(frame, streamId);
|
||||
LOG_TRACE << "Sending " << (isFirst ? "HEADERS" : "CONTINUATION")
|
||||
<< " frame:";
|
||||
LOG_TRACE << dump_hex_beautiful(f.peek(), f.readableBytes());
|
||||
connPtr->send(serializeFrame(frame, streamId));
|
||||
}
|
||||
if (needsContinuation && !haveBody)
|
||||
{
|
||||
DataFrame frame;
|
||||
frame.endStream = true;
|
||||
connPtr->send(serializeFrame(frame, streamId));
|
||||
return;
|
||||
}
|
||||
stream.callback = std::move(callback);
|
||||
stream.request = req;
|
||||
|
||||
@ -763,29 +933,41 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
||||
return;
|
||||
}
|
||||
|
||||
if (req->body().length() > stream.avaliableTxWindow)
|
||||
{
|
||||
LOG_ERROR << "HTTP/2 body too large to fit in INITIAL_WINDOW_SIZE. Not "
|
||||
"supported yet.";
|
||||
abort();
|
||||
return;
|
||||
}
|
||||
size_t bodySize = req->body().length();
|
||||
bool sendEverything =
|
||||
bodySize <= stream.avaliableTxWindow && bodySize <= avaliableTxWindow;
|
||||
size_t maxSendSize = bodySize;
|
||||
maxSendSize = (std::min)(maxSendSize, stream.avaliableTxWindow);
|
||||
maxSendSize = (std::min)(maxSendSize, avaliableTxWindow);
|
||||
|
||||
DataFrame dataFrame;
|
||||
for (size_t i = 0; i < req->body().length(); i += maxFrameSize)
|
||||
size_t i;
|
||||
for (i = 0; i < maxSendSize; i += maxFrameSize)
|
||||
{
|
||||
size_t readSize = (std::min)(maxFrameSize, req->body().length() - i);
|
||||
size_t readSize = (std::min)(maxFrameSize, bodySize - i);
|
||||
std::vector<uint8_t> buffer;
|
||||
buffer.resize(readSize);
|
||||
memcpy(buffer.data(), req->body().data() + i, readSize);
|
||||
dataFrame.data = std::move(buffer);
|
||||
dataFrame.endStream = (i + maxFrameSize >= req->body().length());
|
||||
dataFrame.endStream = (i + maxFrameSize >= bodySize);
|
||||
LOG_TRACE << "Sending data frame: size=" << dataFrame.data.size()
|
||||
<< " endStream=" << dataFrame.endStream;
|
||||
connPtr->send(serializeFrame(dataFrame, streamId));
|
||||
|
||||
stream.avaliableTxWindow -= dataFrame.data.size();
|
||||
avaliableRxWindow -= dataFrame.data.size();
|
||||
avaliableTxWindow -= dataFrame.data.size();
|
||||
}
|
||||
|
||||
if (!sendEverything)
|
||||
{
|
||||
auto it = pendingDataSend.find(streamId);
|
||||
if (it != pendingDataSend.end())
|
||||
{
|
||||
LOG_FATAL << "Stream id already in use! This should not happen";
|
||||
abort();
|
||||
}
|
||||
|
||||
pendingDataSend.emplace(streamId, i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -806,6 +988,9 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
avaliableRxWindow += windowIncreaseSize;
|
||||
}
|
||||
|
||||
if (msg->readableBytes() == 0)
|
||||
break;
|
||||
|
||||
// FIXME: The code cannot distinguish between a out-of-data and
|
||||
// unsupported frame type. We need to fix this as it should be handled
|
||||
// differently.
|
||||
@ -825,6 +1010,7 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
}
|
||||
auto &frame = *frameOpt;
|
||||
|
||||
// special case for PING and GOAWAY. These are all global frames
|
||||
if (std::holds_alternative<GoAwayFrame>(frame))
|
||||
{
|
||||
auto &f = std::get<GoAwayFrame>(frame);
|
||||
@ -840,18 +1026,13 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
{
|
||||
if (streamId > f.lastStreamId)
|
||||
{
|
||||
streamFinished(streamId,
|
||||
ReqResult::BadResponse,
|
||||
StreamCloseErrorCode::RefusedStream);
|
||||
responseErrored(streamId, ReqResult::BadResponse);
|
||||
}
|
||||
}
|
||||
// TODO: Should be half-closed but transport doesn't support it yet
|
||||
connPtr->shutdown();
|
||||
}
|
||||
|
||||
// special case for PING frame. It is the only frame that is not
|
||||
// associated with a stream
|
||||
if (std::holds_alternative<PingFrame>(frame))
|
||||
else if (std::holds_alternative<PingFrame>(frame))
|
||||
{
|
||||
auto &f = std::get<PingFrame>(frame);
|
||||
if (f.ack)
|
||||
@ -867,6 +1048,32 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we are expecting a CONTINUATION frame, we should not receive
|
||||
// HEADERS or CONTINUATION from other streams
|
||||
if (expectngContinuationStreamId != 0 &&
|
||||
(std::holds_alternative<HeadersFrame>(frame) ||
|
||||
(std::holds_alternative<ContinuationFrame>(frame) &&
|
||||
streamId != expectngContinuationStreamId)))
|
||||
{
|
||||
LOG_TRACE << "Protocol error: unexpected HEADERS or "
|
||||
"CONTINUATION frame";
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::ProtocolError,
|
||||
"Expecting CONTINUATION frame for stream " +
|
||||
std::to_string(expectngContinuationStreamId));
|
||||
return;
|
||||
}
|
||||
|
||||
if (std::holds_alternative<PushPromiseFrame>(frame))
|
||||
{
|
||||
LOG_TRACE << "Push promise frame received. Not supported yet. "
|
||||
"Connection will die";
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::ProtocolError,
|
||||
"Push promise not supported");
|
||||
return;
|
||||
}
|
||||
|
||||
if (streamId != 0)
|
||||
{
|
||||
handleFrameForStream(frame, streamId, flags);
|
||||
@ -878,6 +1085,14 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
{
|
||||
auto &f = std::get<WindowUpdateFrame>(frame);
|
||||
avaliableTxWindow += f.windowSizeIncrement;
|
||||
|
||||
// HACK: Notify stream we have more window size available
|
||||
auto it = pendingDataSend.begin();
|
||||
if (it == pendingDataSend.end())
|
||||
continue;
|
||||
auto hackFrame = f;
|
||||
hackFrame.windowSizeIncrement = 0;
|
||||
handleFrameForStream(hackFrame, it->first, 0);
|
||||
}
|
||||
else if (std::holds_alternative<SettingsFrame>(frame))
|
||||
{
|
||||
@ -888,7 +1103,7 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
{
|
||||
hpackRx.setMaxTableSize(value);
|
||||
}
|
||||
if (key == (uint16_t)H2SettingsKey::MaxConcurrentStreams)
|
||||
else if (key == (uint16_t)H2SettingsKey::MaxConcurrentStreams)
|
||||
{
|
||||
// Note: MAX_CONCURRENT_STREAMS can be 0, which means
|
||||
// the client is not allowed to send any request. I doubt
|
||||
@ -927,27 +1142,6 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
LOG_TRACE << "Unsupported settings key: " << key;
|
||||
}
|
||||
}
|
||||
|
||||
if (!serverSettingsReceived)
|
||||
{
|
||||
LOG_TRACE
|
||||
<< "Server settings received. Sending our own WindowUpdate";
|
||||
|
||||
WindowUpdateFrame windowUpdateFrame;
|
||||
windowUpdateFrame.windowSizeIncrement = windowIncreaseSize;
|
||||
connPtr->send(serializeFrame(windowUpdateFrame, 0));
|
||||
avaliableRxWindow = initialWindowSize;
|
||||
|
||||
serverSettingsReceived = true;
|
||||
while (!bufferedRequests.empty() &&
|
||||
streams.size() < maxConcurrentStreams)
|
||||
{
|
||||
auto &[req, cb] = bufferedRequests.front();
|
||||
sendRequestInLoop(req, std::move(cb));
|
||||
bufferedRequests.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Somehow nghttp2 wants us to send ACK after sending our
|
||||
// preferences??
|
||||
if (flags == 1)
|
||||
@ -961,12 +1155,30 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
{
|
||||
// Should never show up on stream 0
|
||||
LOG_FATAL << "Protocol error: HEADERS frame on stream 0";
|
||||
errorCallback(ReqResult::BadResponse);
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::ProtocolError,
|
||||
"HEADERS frame on stream 0");
|
||||
}
|
||||
else if (std::holds_alternative<DataFrame>(frame))
|
||||
{
|
||||
LOG_FATAL << "Protocol error: DATA frame on stream 0";
|
||||
errorCallback(ReqResult::BadResponse);
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::ProtocolError,
|
||||
"DATA frame on stream 0");
|
||||
}
|
||||
else if (std::holds_alternative<ContinuationFrame>(frame))
|
||||
{
|
||||
LOG_FATAL << "Protocol error: CONTINUATION frame on stream 0";
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::ProtocolError,
|
||||
"CONTINUATION frame on stream 0");
|
||||
}
|
||||
else if (std::holds_alternative<RstStreamFrame>(frame))
|
||||
{
|
||||
LOG_FATAL << "Protocol error: RST_STREAM frame on stream 0";
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::ProtocolError,
|
||||
"RST_STREAM frame on stream 0");
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -975,6 +1187,65 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
||||
}
|
||||
}
|
||||
|
||||
bool Http2Transport::parseAndApplyHeaders(internal::H2Stream &stream,
|
||||
const void *data,
|
||||
size_t size)
|
||||
{
|
||||
hpack::HPacker::KeyValueVector headers;
|
||||
int n = hpackRx.decode((const uint8_t *)data, size, headers);
|
||||
auto streamId = stream.streamId;
|
||||
if (n < 0)
|
||||
{
|
||||
LOG_TRACE << "Failed to decode headers";
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::CompressionError,
|
||||
"Failed to decode headers");
|
||||
return false;
|
||||
}
|
||||
for (auto &[key, value] : headers)
|
||||
LOG_TRACE << " " << key << ": " << value;
|
||||
assert(stream.response == nullptr);
|
||||
stream.response = std::make_shared<HttpResponseImpl>();
|
||||
for (const auto &[key, value] : headers)
|
||||
{
|
||||
// TODO: Filter more pseudo headers
|
||||
if (key == "content-length")
|
||||
{
|
||||
auto sz = stosz(value);
|
||||
if (!sz)
|
||||
{
|
||||
responseErrored(streamId, ReqResult::BadResponse);
|
||||
return false;
|
||||
}
|
||||
stream.contentLength = std::move(sz);
|
||||
}
|
||||
if (key == ":status")
|
||||
{
|
||||
auto status = stosz(value);
|
||||
if (!status)
|
||||
{
|
||||
responseErrored(streamId, ReqResult::BadResponse);
|
||||
return false;
|
||||
}
|
||||
// TODO: Validate status code
|
||||
stream.response->setStatusCode((HttpStatusCode)*status);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Anti request smuggling. We look for \r or \n in the header
|
||||
// name or value. If we find one, we abort the stream.
|
||||
if (key.find_first_of("\r\n") != std::string::npos ||
|
||||
value.find_first_of("\r\n") != std::string::npos)
|
||||
{
|
||||
responseErrored(streamId, ReqResult::BadResponse);
|
||||
return false;
|
||||
}
|
||||
|
||||
stream.response->addHeader(key, value);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Http2Transport::Http2Transport(trantor::TcpConnectionPtr connPtr,
|
||||
size_t *bytesSent,
|
||||
size_t *bytesReceived)
|
||||
@ -1011,6 +1282,16 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
|
||||
|
||||
if (std::holds_alternative<HeadersFrame>(frame))
|
||||
{
|
||||
if ((flags & (uint8_t)H2HeadersFlags::EndHeaders) == 0)
|
||||
{
|
||||
auto &f = std::get<HeadersFrame>(frame);
|
||||
headerBufferRx.append((char *)f.headerBlockFragment.data(),
|
||||
f.headerBlockFragment.size());
|
||||
expectngContinuationStreamId = streamId;
|
||||
stream.state = StreamState::ExpectingContinuation;
|
||||
return;
|
||||
}
|
||||
|
||||
if (stream.state != StreamState::ExpectingHeaders)
|
||||
{
|
||||
killConnection(streamId,
|
||||
@ -1019,79 +1300,51 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
|
||||
return;
|
||||
}
|
||||
auto &f = std::get<HeadersFrame>(frame);
|
||||
LOG_TRACE << "Headers frame received: size="
|
||||
<< f.headerBlockFragment.size();
|
||||
hpack::HPacker::KeyValueVector headers;
|
||||
int n = hpackRx.decode(f.headerBlockFragment.data(),
|
||||
f.headerBlockFragment.size(),
|
||||
headers);
|
||||
if (n < 0)
|
||||
{
|
||||
LOG_TRACE << "Failed to decode headers";
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::CompressionError,
|
||||
"Failed to decode headers");
|
||||
// This function handles error itself
|
||||
if (!parseAndApplyHeaders(stream,
|
||||
f.headerBlockFragment.data(),
|
||||
f.headerBlockFragment.size()))
|
||||
return;
|
||||
}
|
||||
for (auto &[key, value] : headers)
|
||||
LOG_TRACE << " " << key << ": " << value;
|
||||
it->second.response = std::make_shared<HttpResponseImpl>();
|
||||
for (const auto &[key, value] : headers)
|
||||
{
|
||||
// TODO: Filter more pseudo headers
|
||||
if (key == "content-length")
|
||||
{
|
||||
auto sz = stosz(value);
|
||||
if (!sz)
|
||||
{
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::ProtocolError,
|
||||
"Invalid content-length");
|
||||
return;
|
||||
}
|
||||
it->second.contentLength = std::move(sz);
|
||||
}
|
||||
if (key == ":status")
|
||||
{
|
||||
// TODO: Validate status code
|
||||
it->second.response->setStatusCode(
|
||||
(drogon::HttpStatusCode)std::stoi(value));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Anti request smuggling. We look for \r or \n in the header
|
||||
// name or value. If we find one, we abort the stream.
|
||||
if (key.find_first_of("\r\n") != std::string::npos ||
|
||||
value.find_first_of("\r\n") != std::string::npos)
|
||||
{
|
||||
streamFinished(streamId,
|
||||
ReqResult::BadResponse,
|
||||
StreamCloseErrorCode::ProtocolError);
|
||||
return;
|
||||
}
|
||||
|
||||
it->second.response->addHeader(key, value);
|
||||
}
|
||||
|
||||
if ((flags & (uint8_t)H2HeadersFlags::EndHeaders) == 0)
|
||||
{
|
||||
LOG_TRACE << "We don't support CONTINUATION frames yet!";
|
||||
stream.state = StreamState::ExpectingContinuation;
|
||||
abort();
|
||||
}
|
||||
// There is no body in the response.
|
||||
if ((flags & (uint8_t)H2HeadersFlags::EndStream))
|
||||
{
|
||||
stream.state = StreamState::Finished;
|
||||
streamFinished(stream);
|
||||
responseSuccess(stream);
|
||||
return;
|
||||
}
|
||||
stream.state = StreamState::ExpectingData;
|
||||
}
|
||||
else if (std::holds_alternative<ContinuationFrame>(frame))
|
||||
{
|
||||
auto &f = std::get<ContinuationFrame>(frame);
|
||||
if (stream.state != StreamState::ExpectingContinuation)
|
||||
{
|
||||
killConnection(streamId,
|
||||
StreamCloseErrorCode::ProtocolError,
|
||||
"Unexpected continuation frame");
|
||||
return;
|
||||
}
|
||||
|
||||
headerBufferRx.append((char *)f.headerBlockFragment.data(),
|
||||
f.headerBlockFragment.size());
|
||||
bool endHeaders = (flags & (uint8_t)H2HeadersFlags::EndHeaders) != 0;
|
||||
if (endHeaders)
|
||||
{
|
||||
stream.state = StreamState::ExpectingData;
|
||||
expectngContinuationStreamId = 0;
|
||||
bool ok = parseAndApplyHeaders(stream,
|
||||
headerBufferRx.peek(),
|
||||
headerBufferRx.readableBytes());
|
||||
headerBufferRx.retrieveAll();
|
||||
if (!ok)
|
||||
LOG_TRACE << "Failed to parse headers in continuation frame";
|
||||
return;
|
||||
}
|
||||
}
|
||||
else if (std::holds_alternative<DataFrame>(frame))
|
||||
{
|
||||
auto &f = std::get<DataFrame>(frame);
|
||||
// TODO: Make sure this logic fits RFC
|
||||
if (avaliableRxWindow < f.data.size())
|
||||
{
|
||||
killConnection(streamId,
|
||||
@ -1126,15 +1379,13 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
|
||||
stream.body.readableBytes() != *stream.contentLength)
|
||||
{
|
||||
LOG_TRACE << "Content-length mismatch";
|
||||
streamFinished(streamId,
|
||||
ReqResult::BadResponse,
|
||||
StreamCloseErrorCode::ProtocolError);
|
||||
responseErrored(streamId, ReqResult::BadResponse);
|
||||
return;
|
||||
}
|
||||
// TODO: Optmize setting body
|
||||
std::string body(stream.body.peek(), stream.body.readableBytes());
|
||||
stream.response->setBody(std::move(body));
|
||||
streamFinished(stream);
|
||||
responseSuccess(stream);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1150,6 +1401,55 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
|
||||
{
|
||||
auto &f = std::get<WindowUpdateFrame>(frame);
|
||||
stream.avaliableTxWindow += f.windowSizeIncrement;
|
||||
if (avaliableTxWindow == 0)
|
||||
return;
|
||||
|
||||
auto it = pendingDataSend.find(streamId);
|
||||
if (it == pendingDataSend.end())
|
||||
return;
|
||||
|
||||
size_t i = 0;
|
||||
size_t sendOffset = it->second;
|
||||
assert(stream.request != nullptr);
|
||||
assert(stream.request->body().length() > sendOffset);
|
||||
size_t maxSendSize = stream.request->body().length() - sendOffset;
|
||||
maxSendSize = (std::min)(maxSendSize, stream.avaliableTxWindow);
|
||||
maxSendSize = (std::min)(maxSendSize, avaliableTxWindow);
|
||||
bool sendEverything =
|
||||
maxSendSize == stream.request->body().length() - sendOffset;
|
||||
for (i = 0; i < maxSendSize; i += maxFrameSize)
|
||||
{
|
||||
size_t readSize =
|
||||
(std::min)(maxFrameSize,
|
||||
stream.request->body().length() - sendOffset - i);
|
||||
std::vector<uint8_t> buffer;
|
||||
buffer.resize(readSize);
|
||||
memcpy(buffer.data(),
|
||||
stream.request->body().data() + sendOffset + i,
|
||||
readSize);
|
||||
DataFrame dataFrame;
|
||||
dataFrame.data = std::move(buffer);
|
||||
dataFrame.endStream =
|
||||
(i + maxFrameSize >=
|
||||
stream.request->body().length() - sendOffset);
|
||||
LOG_TRACE << "Sending data frame: size=" << dataFrame.data.size()
|
||||
<< " endStream=" << dataFrame.endStream;
|
||||
connPtr->send(serializeFrame(dataFrame, streamId));
|
||||
|
||||
stream.avaliableTxWindow -= dataFrame.data.size();
|
||||
avaliableTxWindow -= dataFrame.data.size();
|
||||
}
|
||||
|
||||
if (sendEverything)
|
||||
pendingDataSend.erase(it);
|
||||
else
|
||||
it->second = sendOffset + i;
|
||||
}
|
||||
else if (std::holds_alternative<RstStreamFrame>(frame))
|
||||
{
|
||||
auto &f = std::get<RstStreamFrame>(frame);
|
||||
LOG_TRACE << "RST_STREAM frame received: errorCode=" << f.errorCode;
|
||||
responseErrored(streamId, ReqResult::BadResponse);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -1170,7 +1470,7 @@ internal::H2Stream &Http2Transport::createStream(int32_t streamId)
|
||||
return stream;
|
||||
}
|
||||
|
||||
void Http2Transport::streamFinished(internal::H2Stream &stream)
|
||||
void Http2Transport::responseSuccess(internal::H2Stream &stream)
|
||||
{
|
||||
assert(stream.request != nullptr);
|
||||
assert(stream.callback);
|
||||
@ -1182,6 +1482,10 @@ void Http2Transport::streamFinished(internal::H2Stream &stream)
|
||||
respCallback(stream.response, {stream.request, stream.callback}, connPtr);
|
||||
streams.erase(it);
|
||||
|
||||
auto it2 = pendingDataSend.find(streamId);
|
||||
if (it2 != pendingDataSend.end())
|
||||
pendingDataSend.erase(it2);
|
||||
|
||||
if (bufferedRequests.empty())
|
||||
return;
|
||||
auto &[req, cb] = bufferedRequests.front();
|
||||
@ -1189,15 +1493,17 @@ void Http2Transport::streamFinished(internal::H2Stream &stream)
|
||||
bufferedRequests.pop();
|
||||
}
|
||||
|
||||
void Http2Transport::streamFinished(int32_t streamId,
|
||||
ReqResult result,
|
||||
StreamCloseErrorCode errorCode)
|
||||
void Http2Transport::responseErrored(int32_t streamId, ReqResult result)
|
||||
{
|
||||
auto it = streams.find(streamId);
|
||||
assert(it != streams.end());
|
||||
it->second.callback(result, nullptr);
|
||||
streams.erase(it);
|
||||
|
||||
auto it2 = pendingDataSend.find(streamId);
|
||||
if (it2 != pendingDataSend.end())
|
||||
pendingDataSend.erase(it2);
|
||||
|
||||
if (bufferedRequests.empty())
|
||||
return;
|
||||
auto &[req, cb] = bufferedRequests.front();
|
||||
@ -1218,6 +1524,8 @@ void Http2Transport::onError(ReqResult result)
|
||||
cb(result, nullptr);
|
||||
bufferedRequests.pop();
|
||||
}
|
||||
|
||||
pendingDataSend.clear();
|
||||
}
|
||||
|
||||
void Http2Transport::killConnection(int32_t lastStreamId,
|
||||
@ -1239,5 +1547,7 @@ bool Http2Transport::handleConnectionClose()
|
||||
return false;
|
||||
for (auto &[streamId, stream] : streams)
|
||||
stream.callback(ReqResult::BadResponse, nullptr);
|
||||
streams.clear();
|
||||
pendingDataSend.clear();
|
||||
return true;
|
||||
}
|
@ -78,12 +78,46 @@ struct PingFrame
|
||||
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
||||
};
|
||||
|
||||
struct ContinuationFrame
|
||||
{
|
||||
std::vector<uint8_t> headerBlockFragment;
|
||||
bool endHeaders = false;
|
||||
|
||||
static std::optional<ContinuationFrame> parse(ByteStream &payload,
|
||||
uint8_t flags);
|
||||
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
||||
};
|
||||
|
||||
struct RstStreamFrame
|
||||
{
|
||||
uint32_t errorCode = 0;
|
||||
|
||||
static std::optional<RstStreamFrame> parse(ByteStream &payload,
|
||||
uint8_t flags);
|
||||
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
||||
};
|
||||
|
||||
struct PushPromiseFrame
|
||||
{
|
||||
uint8_t padLength = 0;
|
||||
bool endHeaders = false;
|
||||
int32_t promisedStreamId = 0;
|
||||
std::vector<uint8_t> headerBlockFragment;
|
||||
|
||||
static std::optional<PushPromiseFrame> parse(ByteStream &payload,
|
||||
uint8_t flags);
|
||||
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
||||
};
|
||||
|
||||
using H2Frame = std::variant<SettingsFrame,
|
||||
WindowUpdateFrame,
|
||||
HeadersFrame,
|
||||
GoAwayFrame,
|
||||
DataFrame,
|
||||
PingFrame>;
|
||||
PingFrame,
|
||||
ContinuationFrame,
|
||||
PushPromiseFrame,
|
||||
RstStreamFrame>;
|
||||
|
||||
enum class StreamState
|
||||
{
|
||||
@ -138,11 +172,13 @@ class Http2Transport : public HttpTransport
|
||||
hpack::HPacker hpackTx;
|
||||
hpack::HPacker hpackRx;
|
||||
|
||||
std::priority_queue<int32_t> usibleStreamIds;
|
||||
int32_t currentStreamId = 1;
|
||||
std::unordered_map<int32_t, internal::H2Stream> streams;
|
||||
bool serverSettingsReceived = false;
|
||||
std::queue<std::pair<HttpRequestPtr, HttpReqCallback>> bufferedRequests;
|
||||
trantor::MsgBuffer headerBufferRx;
|
||||
int32_t expectngContinuationStreamId = 0;
|
||||
|
||||
std::unordered_map<int32_t, size_t> pendingDataSend;
|
||||
// TODO: Handle server-initiated stream creation
|
||||
|
||||
// HTTP/2 client-wide settings (can be changed by server)
|
||||
@ -160,10 +196,8 @@ class Http2Transport : public HttpTransport
|
||||
size_t avaliableRxWindow = 65535;
|
||||
|
||||
internal::H2Stream &createStream(int32_t streamId);
|
||||
void streamFinished(internal::H2Stream &stream);
|
||||
void streamFinished(int32_t streamId,
|
||||
ReqResult result,
|
||||
StreamCloseErrorCode errorCode);
|
||||
void responseSuccess(internal::H2Stream &stream);
|
||||
void responseErrored(int32_t streamId, ReqResult result);
|
||||
|
||||
int32_t nextStreamId()
|
||||
{
|
||||
@ -181,6 +215,10 @@ class Http2Transport : public HttpTransport
|
||||
StreamCloseErrorCode errorCode,
|
||||
std::string errorMsg = "");
|
||||
|
||||
bool parseAndApplyHeaders(internal::H2Stream &stream,
|
||||
const void *data,
|
||||
size_t len);
|
||||
|
||||
public:
|
||||
Http2Transport(trantor::TcpConnectionPtr connPtr,
|
||||
size_t *bytesSent,
|
||||
|
Loading…
x
Reference in New Issue
Block a user