mirror of
https://github.com/drogonframework/drogon.git
synced 2025-07-23 00:01:23 -04:00
Compare commits
16 Commits
d50577b709
...
52d0fdd25d
Author | SHA1 | Date | |
---|---|---|---|
|
52d0fdd25d | ||
|
0eb2cdabe7 | ||
|
e12260a0b7 | ||
|
c82e8208ab | ||
|
e13a8b930f | ||
|
4a2eecf03d | ||
|
45a2b1d0d3 | ||
|
d69699ceb2 | ||
|
5f52a80358 | ||
|
ab566b3524 | ||
|
e565f38d7a | ||
|
5d434577ff | ||
|
2eef512537 | ||
|
847b580cf5 | ||
|
83606eb5a6 | ||
|
f5c4863ad0 |
@ -16,7 +16,7 @@ int main()
|
|||||||
{
|
{
|
||||||
trantor::Logger::setLogLevel(trantor::Logger::kTrace);
|
trantor::Logger::setLogLevel(trantor::Logger::kTrace);
|
||||||
{
|
{
|
||||||
auto client = HttpClient::newHttpClient("https://clehaxze.tw",
|
auto client = HttpClient::newHttpClient("https://clehaxze.tw:8844",
|
||||||
nullptr,
|
nullptr,
|
||||||
false,
|
false,
|
||||||
false);
|
false);
|
||||||
|
@ -18,7 +18,11 @@ static std::optional<size_t> stosz(const std::string &str)
|
|||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
return std::stoull(str);
|
size_t idx = 0;
|
||||||
|
size_t res = std::stoull(str, &idx);
|
||||||
|
if (idx != str.size())
|
||||||
|
return std::nullopt;
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
catch (const std::exception &)
|
catch (const std::exception &)
|
||||||
{
|
{
|
||||||
@ -107,7 +111,7 @@ struct ByteStream
|
|||||||
|
|
||||||
uint32_t readU24BE()
|
uint32_t readU24BE()
|
||||||
{
|
{
|
||||||
assert(offset <= length - 3);
|
assert(length >= 3 && offset <= length - 3);
|
||||||
uint32_t res =
|
uint32_t res =
|
||||||
ptr[offset] << 16 | ptr[offset + 1] << 8 | ptr[offset + 2];
|
ptr[offset] << 16 | ptr[offset + 1] << 8 | ptr[offset + 2];
|
||||||
offset += 3;
|
offset += 3;
|
||||||
@ -116,7 +120,7 @@ struct ByteStream
|
|||||||
|
|
||||||
uint32_t readU32BE()
|
uint32_t readU32BE()
|
||||||
{
|
{
|
||||||
assert(offset <= length - 4);
|
assert(length >= 4 && offset <= length - 4);
|
||||||
uint32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
|
uint32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
|
||||||
ptr[offset + 2] << 8 | ptr[offset + 3];
|
ptr[offset + 2] << 8 | ptr[offset + 3];
|
||||||
offset += 4;
|
offset += 4;
|
||||||
@ -125,7 +129,7 @@ struct ByteStream
|
|||||||
|
|
||||||
std::pair<bool, int32_t> readBI32BE()
|
std::pair<bool, int32_t> readBI32BE()
|
||||||
{
|
{
|
||||||
assert(offset <= length - 4);
|
assert(length >= 4 && offset <= length - 4);
|
||||||
int32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
|
int32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
|
||||||
ptr[offset + 2] << 8 | ptr[offset + 3];
|
ptr[offset + 2] << 8 | ptr[offset + 3];
|
||||||
offset += 4;
|
offset += 4;
|
||||||
@ -137,7 +141,7 @@ struct ByteStream
|
|||||||
|
|
||||||
uint16_t readU16BE()
|
uint16_t readU16BE()
|
||||||
{
|
{
|
||||||
assert(offset <= length - 2);
|
assert(length >= 2 && offset <= length - 2);
|
||||||
uint16_t res = ptr[offset] << 8 | ptr[offset + 1];
|
uint16_t res = ptr[offset] << 8 | ptr[offset + 1];
|
||||||
offset += 2;
|
offset += 2;
|
||||||
return res;
|
return res;
|
||||||
@ -145,13 +149,13 @@ struct ByteStream
|
|||||||
|
|
||||||
uint8_t readU8()
|
uint8_t readU8()
|
||||||
{
|
{
|
||||||
assert(offset <= length - 1);
|
assert(length >= 1 && offset <= length - 1);
|
||||||
return ptr[offset++];
|
return ptr[offset++];
|
||||||
}
|
}
|
||||||
|
|
||||||
void read(uint8_t *buffer, size_t size)
|
void read(uint8_t *buffer, size_t size)
|
||||||
{
|
{
|
||||||
assert(offset <= length - size || size == 0);
|
assert((length >= size && offset <= length - size) || size == 0);
|
||||||
memcpy(buffer, ptr + offset, size);
|
memcpy(buffer, ptr + offset, size);
|
||||||
offset += size;
|
offset += size;
|
||||||
}
|
}
|
||||||
@ -171,7 +175,7 @@ struct ByteStream
|
|||||||
|
|
||||||
void skip(size_t n)
|
void skip(size_t n)
|
||||||
{
|
{
|
||||||
assert(offset <= length - n || n == 0);
|
assert((length >= n && offset <= length - n) || n == 0);
|
||||||
offset += n;
|
offset += n;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,6 +231,7 @@ struct OByteStream
|
|||||||
{
|
{
|
||||||
assert(value <= 0xffffff);
|
assert(value <= 0xffffff);
|
||||||
assert(offset <= buffer.readableBytes() - 3);
|
assert(offset <= buffer.readableBytes() - 3);
|
||||||
|
assert(buffer.writableBytes() >= 3);
|
||||||
auto ptr = (uint8_t *)buffer.peek() + offset;
|
auto ptr = (uint8_t *)buffer.peek() + offset;
|
||||||
ptr[0] = value >> 16;
|
ptr[0] = value >> 16;
|
||||||
ptr[1] = value >> 8;
|
ptr[1] = value >> 8;
|
||||||
@ -236,6 +241,7 @@ struct OByteStream
|
|||||||
void overwriteU8(size_t offset, uint8_t value)
|
void overwriteU8(size_t offset, uint8_t value)
|
||||||
{
|
{
|
||||||
assert(offset <= buffer.readableBytes() - 1);
|
assert(offset <= buffer.readableBytes() - 1);
|
||||||
|
assert(buffer.writableBytes() >= 1);
|
||||||
auto ptr = (uint8_t *)buffer.peek() + offset;
|
auto ptr = (uint8_t *)buffer.peek() + offset;
|
||||||
ptr[0] = value;
|
ptr[0] = value;
|
||||||
}
|
}
|
||||||
@ -325,11 +331,25 @@ std::optional<HeadersFrame> HeadersFrame::parse(ByteStream &payload,
|
|||||||
bool padded = flags & (uint8_t)H2HeadersFlags::Padded;
|
bool padded = flags & (uint8_t)H2HeadersFlags::Padded;
|
||||||
bool priority = flags & (uint8_t)H2HeadersFlags::Priority;
|
bool priority = flags & (uint8_t)H2HeadersFlags::Priority;
|
||||||
|
|
||||||
|
if (payload.size() == 0)
|
||||||
|
{
|
||||||
|
LOG_TRACE << "Header size cannot be 0";
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
HeadersFrame frame;
|
HeadersFrame frame;
|
||||||
if (padded)
|
if (padded)
|
||||||
{
|
{
|
||||||
frame.padLength = payload.readU8();
|
frame.padLength = payload.readU8();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t minSize = frame.padLength + (priority ? 5 : 0);
|
||||||
|
if (payload.size() < minSize)
|
||||||
|
{
|
||||||
|
LOG_TRACE << "Invalid headers frame length";
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
if (priority)
|
if (priority)
|
||||||
{
|
{
|
||||||
auto [exclusive, streamDependency] = payload.readBI32BE();
|
auto [exclusive, streamDependency] = payload.readBI32BE();
|
||||||
@ -346,6 +366,7 @@ std::optional<HeadersFrame> HeadersFrame::parse(ByteStream &payload,
|
|||||||
frame.endStream = true;
|
frame.endStream = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert(payload.remaining() >= frame.padLength);
|
||||||
int64_t payloadSize = payload.remaining() - frame.padLength;
|
int64_t payloadSize = payload.remaining() - frame.padLength;
|
||||||
if (payloadSize < 0)
|
if (payloadSize < 0)
|
||||||
{
|
{
|
||||||
@ -421,11 +442,20 @@ std::optional<DataFrame> DataFrame::parse(ByteStream &payload, uint8_t flags)
|
|||||||
{
|
{
|
||||||
frame.padLength = payload.readU8();
|
frame.padLength = payload.readU8();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t minSize = frame.padLength;
|
||||||
|
if (payload.size() < minSize)
|
||||||
|
{
|
||||||
|
LOG_TRACE << "Invalid data frame length";
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
if (endStream)
|
if (endStream)
|
||||||
{
|
{
|
||||||
frame.endStream = true;
|
frame.endStream = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert(payload.remaining() >= frame.padLength);
|
||||||
size_t payloadSize = payload.remaining() - frame.padLength;
|
size_t payloadSize = payload.remaining() - frame.padLength;
|
||||||
if (payloadSize < 0)
|
if (payloadSize < 0)
|
||||||
{
|
{
|
||||||
@ -447,7 +477,7 @@ std::optional<DataFrame> DataFrame::parse(ByteStream &payload, uint8_t flags)
|
|||||||
|
|
||||||
bool DataFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
bool DataFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||||
{
|
{
|
||||||
flags = 0x0;
|
flags = (endStream ? (uint8_t)H2DataFlags::EndStream : 0x0);
|
||||||
stream.write(data.data(), data.size());
|
stream.write(data.data(), data.size());
|
||||||
if (padLength > 0)
|
if (padLength > 0)
|
||||||
{
|
{
|
||||||
@ -477,6 +507,95 @@ bool PingFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::optional<ContinuationFrame> ContinuationFrame::parse(ByteStream &payload,
|
||||||
|
uint8_t flags)
|
||||||
|
{
|
||||||
|
bool endHeaders = flags & (uint8_t)H2HeadersFlags::EndHeaders;
|
||||||
|
ContinuationFrame frame;
|
||||||
|
if (endHeaders)
|
||||||
|
{
|
||||||
|
frame.endHeaders = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
frame.headerBlockFragment.resize(payload.remaining());
|
||||||
|
payload.read(frame.headerBlockFragment.data(),
|
||||||
|
frame.headerBlockFragment.size());
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ContinuationFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||||
|
{
|
||||||
|
flags = 0x0;
|
||||||
|
if (endHeaders)
|
||||||
|
flags |= (uint8_t)H2HeadersFlags::EndHeaders;
|
||||||
|
stream.write(headerBlockFragment.data(), headerBlockFragment.size());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<PushPromiseFrame> PushPromiseFrame::parse(ByteStream &payload,
|
||||||
|
uint8_t flags)
|
||||||
|
{
|
||||||
|
bool endHeaders = flags & (uint8_t)H2HeadersFlags::EndHeaders;
|
||||||
|
bool padded = flags & (uint8_t)H2HeadersFlags::Padded;
|
||||||
|
|
||||||
|
PushPromiseFrame frame;
|
||||||
|
if (padded)
|
||||||
|
frame.padLength = payload.readU8();
|
||||||
|
if (endHeaders)
|
||||||
|
frame.endHeaders = true;
|
||||||
|
|
||||||
|
size_t minSize = frame.padLength + 4;
|
||||||
|
if (payload.size() < minSize)
|
||||||
|
{
|
||||||
|
LOG_TRACE << "Invalid push promise frame length";
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [_, promisedStreamId] = payload.readBI32BE();
|
||||||
|
frame.promisedStreamId = promisedStreamId;
|
||||||
|
assert(payload.remaining() >= frame.padLength);
|
||||||
|
frame.headerBlockFragment.resize(payload.remaining() - frame.padLength);
|
||||||
|
payload.read(frame.headerBlockFragment.data(),
|
||||||
|
frame.headerBlockFragment.size());
|
||||||
|
payload.skip(frame.padLength);
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PushPromiseFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||||
|
{
|
||||||
|
flags = 0x0;
|
||||||
|
if (endHeaders)
|
||||||
|
flags |= (uint8_t)H2HeadersFlags::EndHeaders;
|
||||||
|
assert(promisedStreamId > 0);
|
||||||
|
stream.writeU32BE(promisedStreamId);
|
||||||
|
stream.write(headerBlockFragment.data(), headerBlockFragment.size());
|
||||||
|
if (padLength > 0)
|
||||||
|
{
|
||||||
|
flags |= (uint8_t)H2HeadersFlags::Padded;
|
||||||
|
stream.writeU8(padLength);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<RstStreamFrame> RstStreamFrame::parse(ByteStream &payload,
|
||||||
|
uint8_t flags)
|
||||||
|
{
|
||||||
|
if (payload.size() != 4)
|
||||||
|
{
|
||||||
|
LOG_TRACE << "Invalid RST_STREAM frame length";
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
RstStreamFrame frame;
|
||||||
|
frame.errorCode = payload.readU32BE();
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RstStreamFrame::serialize(OByteStream &stream, uint8_t &flags) const
|
||||||
|
{
|
||||||
|
flags = 0x0;
|
||||||
|
stream.writeU32BE(errorCode);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
} // namespace drogon::internal
|
} // namespace drogon::internal
|
||||||
|
|
||||||
// Print the HEX and ASCII representation of the buffer side by side
|
// Print the HEX and ASCII representation of the buffer side by side
|
||||||
@ -568,6 +687,24 @@ static trantor::MsgBuffer serializeFrame(const H2Frame &frame, int32_t streamId)
|
|||||||
ok = f.serialize(buffer, flags);
|
ok = f.serialize(buffer, flags);
|
||||||
type = (uint8_t)H2FrameType::Ping;
|
type = (uint8_t)H2FrameType::Ping;
|
||||||
}
|
}
|
||||||
|
else if (std::holds_alternative<ContinuationFrame>(frame))
|
||||||
|
{
|
||||||
|
const auto &f = std::get<ContinuationFrame>(frame);
|
||||||
|
ok = f.serialize(buffer, flags);
|
||||||
|
type = (uint8_t)H2FrameType::Continuation;
|
||||||
|
}
|
||||||
|
else if (std::holds_alternative<PushPromiseFrame>(frame))
|
||||||
|
{
|
||||||
|
const auto &f = std::get<PushPromiseFrame>(frame);
|
||||||
|
ok = f.serialize(buffer, flags);
|
||||||
|
type = (uint8_t)H2FrameType::PushPromise;
|
||||||
|
}
|
||||||
|
else if (std::holds_alternative<RstStreamFrame>(frame))
|
||||||
|
{
|
||||||
|
const auto &f = std::get<RstStreamFrame>(frame);
|
||||||
|
ok = f.serialize(buffer, flags);
|
||||||
|
type = (uint8_t)H2FrameType::RstStream;
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
LOG_ERROR << "Unsupported frame type";
|
LOG_ERROR << "Unsupported frame type";
|
||||||
@ -646,6 +783,12 @@ static std::tuple<std::optional<H2Frame>, uint32_t, uint8_t, bool> parseH2Frame(
|
|||||||
frame = DataFrame::parse(payload, flags);
|
frame = DataFrame::parse(payload, flags);
|
||||||
else if (type == (uint8_t)H2FrameType::Ping)
|
else if (type == (uint8_t)H2FrameType::Ping)
|
||||||
frame = PingFrame::parse(payload, flags);
|
frame = PingFrame::parse(payload, flags);
|
||||||
|
else if (type == (uint8_t)H2FrameType::Continuation)
|
||||||
|
frame = ContinuationFrame::parse(payload, flags);
|
||||||
|
else if (type == (uint8_t)H2FrameType::PushPromise)
|
||||||
|
frame = PushPromiseFrame::parse(payload, flags);
|
||||||
|
else if (type == (uint8_t)H2FrameType::RstStream)
|
||||||
|
frame = RstStreamFrame::parse(payload, flags);
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
LOG_WARN << "Unsupported H2 frame type: " << (int)type;
|
LOG_WARN << "Unsupported H2 frame type: " << (int)type;
|
||||||
@ -670,13 +813,7 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
|||||||
HttpReqCallback &&callback)
|
HttpReqCallback &&callback)
|
||||||
{
|
{
|
||||||
connPtr->getLoop()->assertInLoopThread();
|
connPtr->getLoop()->assertInLoopThread();
|
||||||
if (!serverSettingsReceived)
|
if (streams.size() + 1 >= maxConcurrentStreams)
|
||||||
{
|
|
||||||
bufferedRequests.push({req, std::move(callback)});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (streams.size() >= maxConcurrentStreams)
|
|
||||||
{
|
{
|
||||||
LOG_TRACE << "Too many streams in flight. Buffering request";
|
LOG_TRACE << "Too many streams in flight. Buffering request";
|
||||||
bufferedRequests.push({req, std::move(callback)});
|
bufferedRequests.push({req, std::move(callback)});
|
||||||
@ -697,15 +834,11 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto headers = req->headers();
|
auto headers = req->headers();
|
||||||
HeadersFrame frame;
|
std::vector<uint8_t> encodedHeaders(maxCompressiedHeaderSize);
|
||||||
frame.padLength = 0;
|
|
||||||
frame.exclusive = false;
|
|
||||||
frame.streamDependency = 0;
|
|
||||||
frame.weight = 0;
|
|
||||||
frame.headerBlockFragment.resize(maxCompressiedHeaderSize);
|
|
||||||
|
|
||||||
LOG_TRACE << "Sending HTTP/2 headers: size=" << headers.size();
|
LOG_TRACE << "Sending HTTP/2 headers: size=" << headers.size();
|
||||||
hpack::HPacker::KeyValueVector headersToEncode;
|
hpack::HPacker::KeyValueVector headersToEncode;
|
||||||
|
headersToEncode.reserve(headers.size() + 5);
|
||||||
const std::array<std::string_view, 2> headersToSkip = {
|
const std::array<std::string_view, 2> headersToSkip = {
|
||||||
{"host", "connection"}};
|
{"host", "connection"}};
|
||||||
headersToEncode.emplace_back(":method", req->methodString());
|
headersToEncode.emplace_back(":method", req->methodString());
|
||||||
@ -726,8 +859,8 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
|||||||
for (auto &[key, value] : headersToEncode)
|
for (auto &[key, value] : headersToEncode)
|
||||||
LOG_TRACE << " " << key << ": " << value;
|
LOG_TRACE << " " << key << ": " << value;
|
||||||
int n = hpackTx.encode(headersToEncode,
|
int n = hpackTx.encode(headersToEncode,
|
||||||
frame.headerBlockFragment.data(),
|
encodedHeaders.data(),
|
||||||
frame.headerBlockFragment.size());
|
encodedHeaders.size());
|
||||||
if (n < 0)
|
if (n < 0)
|
||||||
{
|
{
|
||||||
LOG_TRACE << "Failed to encode headers. Internal error or header "
|
LOG_TRACE << "Failed to encode headers. Internal error or header "
|
||||||
@ -742,17 +875,54 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
|||||||
abort();
|
abort();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
frame.headerBlockFragment.resize(n);
|
|
||||||
frame.endHeaders = true;
|
|
||||||
|
|
||||||
|
encodedHeaders.resize(n);
|
||||||
|
LOG_TRACE << "Encoded headers size: " << encodedHeaders.size();
|
||||||
|
|
||||||
|
bool haveBody = req->body().length() > 0;
|
||||||
auto &stream = createStream(streamId);
|
auto &stream = createStream(streamId);
|
||||||
if (req->body().length() == 0)
|
bool needsContinuation = encodedHeaders.size() > maxFrameSize;
|
||||||
frame.endStream = true;
|
for (size_t i = 0; i < encodedHeaders.size(); i += maxFrameSize)
|
||||||
LOG_TRACE << "Sending headers frame";
|
{
|
||||||
auto f = serializeFrame(frame, streamId);
|
bool isFirst = i == 0;
|
||||||
LOG_TRACE << dump_hex_beautiful(f.peek(), f.readableBytes());
|
bool isLast = i + maxFrameSize >= encodedHeaders.size();
|
||||||
connPtr->send(f);
|
size_t dataSize = (std::min)(maxFrameSize, encodedHeaders.size() - i);
|
||||||
|
|
||||||
|
auto frame = [&]() -> H2Frame {
|
||||||
|
if (isFirst)
|
||||||
|
{
|
||||||
|
HeadersFrame frame;
|
||||||
|
frame.headerBlockFragment.resize(dataSize);
|
||||||
|
memcpy(frame.headerBlockFragment.data(),
|
||||||
|
encodedHeaders.data() + i,
|
||||||
|
dataSize);
|
||||||
|
frame.endHeaders = isLast;
|
||||||
|
frame.endStream = (!haveBody && isLast);
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
ContinuationFrame frame;
|
||||||
|
frame.headerBlockFragment.resize(dataSize);
|
||||||
|
assert(encodedHeaders.size() > i + dataSize);
|
||||||
|
memcpy(frame.headerBlockFragment.data(),
|
||||||
|
encodedHeaders.data() + i,
|
||||||
|
dataSize);
|
||||||
|
frame.endHeaders = (!haveBody && isLast);
|
||||||
|
return frame;
|
||||||
|
}();
|
||||||
|
|
||||||
|
auto f = serializeFrame(frame, streamId);
|
||||||
|
LOG_TRACE << "Sending " << (isFirst ? "HEADERS" : "CONTINUATION")
|
||||||
|
<< " frame:";
|
||||||
|
LOG_TRACE << dump_hex_beautiful(f.peek(), f.readableBytes());
|
||||||
|
connPtr->send(serializeFrame(frame, streamId));
|
||||||
|
}
|
||||||
|
if (needsContinuation && !haveBody)
|
||||||
|
{
|
||||||
|
DataFrame frame;
|
||||||
|
frame.endStream = true;
|
||||||
|
connPtr->send(serializeFrame(frame, streamId));
|
||||||
|
return;
|
||||||
|
}
|
||||||
stream.callback = std::move(callback);
|
stream.callback = std::move(callback);
|
||||||
stream.request = req;
|
stream.request = req;
|
||||||
|
|
||||||
@ -763,29 +933,41 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (req->body().length() > stream.avaliableTxWindow)
|
size_t bodySize = req->body().length();
|
||||||
{
|
bool sendEverything =
|
||||||
LOG_ERROR << "HTTP/2 body too large to fit in INITIAL_WINDOW_SIZE. Not "
|
bodySize <= stream.avaliableTxWindow && bodySize <= avaliableTxWindow;
|
||||||
"supported yet.";
|
size_t maxSendSize = bodySize;
|
||||||
abort();
|
maxSendSize = (std::min)(maxSendSize, stream.avaliableTxWindow);
|
||||||
return;
|
maxSendSize = (std::min)(maxSendSize, avaliableTxWindow);
|
||||||
}
|
|
||||||
|
|
||||||
DataFrame dataFrame;
|
DataFrame dataFrame;
|
||||||
for (size_t i = 0; i < req->body().length(); i += maxFrameSize)
|
size_t i;
|
||||||
|
for (i = 0; i < maxSendSize; i += maxFrameSize)
|
||||||
{
|
{
|
||||||
size_t readSize = (std::min)(maxFrameSize, req->body().length() - i);
|
size_t readSize = (std::min)(maxFrameSize, bodySize - i);
|
||||||
std::vector<uint8_t> buffer;
|
std::vector<uint8_t> buffer;
|
||||||
buffer.resize(readSize);
|
buffer.resize(readSize);
|
||||||
memcpy(buffer.data(), req->body().data() + i, readSize);
|
memcpy(buffer.data(), req->body().data() + i, readSize);
|
||||||
dataFrame.data = std::move(buffer);
|
dataFrame.data = std::move(buffer);
|
||||||
dataFrame.endStream = (i + maxFrameSize >= req->body().length());
|
dataFrame.endStream = (i + maxFrameSize >= bodySize);
|
||||||
LOG_TRACE << "Sending data frame: size=" << dataFrame.data.size()
|
LOG_TRACE << "Sending data frame: size=" << dataFrame.data.size()
|
||||||
<< " endStream=" << dataFrame.endStream;
|
<< " endStream=" << dataFrame.endStream;
|
||||||
connPtr->send(serializeFrame(dataFrame, streamId));
|
connPtr->send(serializeFrame(dataFrame, streamId));
|
||||||
|
|
||||||
stream.avaliableTxWindow -= dataFrame.data.size();
|
stream.avaliableTxWindow -= dataFrame.data.size();
|
||||||
avaliableRxWindow -= dataFrame.data.size();
|
avaliableTxWindow -= dataFrame.data.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sendEverything)
|
||||||
|
{
|
||||||
|
auto it = pendingDataSend.find(streamId);
|
||||||
|
if (it != pendingDataSend.end())
|
||||||
|
{
|
||||||
|
LOG_FATAL << "Stream id already in use! This should not happen";
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingDataSend.emplace(streamId, i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -806,6 +988,9 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
avaliableRxWindow += windowIncreaseSize;
|
avaliableRxWindow += windowIncreaseSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (msg->readableBytes() == 0)
|
||||||
|
break;
|
||||||
|
|
||||||
// FIXME: The code cannot distinguish between a out-of-data and
|
// FIXME: The code cannot distinguish between a out-of-data and
|
||||||
// unsupported frame type. We need to fix this as it should be handled
|
// unsupported frame type. We need to fix this as it should be handled
|
||||||
// differently.
|
// differently.
|
||||||
@ -825,6 +1010,7 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
}
|
}
|
||||||
auto &frame = *frameOpt;
|
auto &frame = *frameOpt;
|
||||||
|
|
||||||
|
// special case for PING and GOAWAY. These are all global frames
|
||||||
if (std::holds_alternative<GoAwayFrame>(frame))
|
if (std::holds_alternative<GoAwayFrame>(frame))
|
||||||
{
|
{
|
||||||
auto &f = std::get<GoAwayFrame>(frame);
|
auto &f = std::get<GoAwayFrame>(frame);
|
||||||
@ -840,18 +1026,13 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
{
|
{
|
||||||
if (streamId > f.lastStreamId)
|
if (streamId > f.lastStreamId)
|
||||||
{
|
{
|
||||||
streamFinished(streamId,
|
responseErrored(streamId, ReqResult::BadResponse);
|
||||||
ReqResult::BadResponse,
|
|
||||||
StreamCloseErrorCode::RefusedStream);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: Should be half-closed but transport doesn't support it yet
|
// TODO: Should be half-closed but transport doesn't support it yet
|
||||||
connPtr->shutdown();
|
connPtr->shutdown();
|
||||||
}
|
}
|
||||||
|
else if (std::holds_alternative<PingFrame>(frame))
|
||||||
// special case for PING frame. It is the only frame that is not
|
|
||||||
// associated with a stream
|
|
||||||
if (std::holds_alternative<PingFrame>(frame))
|
|
||||||
{
|
{
|
||||||
auto &f = std::get<PingFrame>(frame);
|
auto &f = std::get<PingFrame>(frame);
|
||||||
if (f.ack)
|
if (f.ack)
|
||||||
@ -867,6 +1048,32 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we are expecting a CONTINUATION frame, we should not receive
|
||||||
|
// HEADERS or CONTINUATION from other streams
|
||||||
|
if (expectngContinuationStreamId != 0 &&
|
||||||
|
(std::holds_alternative<HeadersFrame>(frame) ||
|
||||||
|
(std::holds_alternative<ContinuationFrame>(frame) &&
|
||||||
|
streamId != expectngContinuationStreamId)))
|
||||||
|
{
|
||||||
|
LOG_TRACE << "Protocol error: unexpected HEADERS or "
|
||||||
|
"CONTINUATION frame";
|
||||||
|
killConnection(streamId,
|
||||||
|
StreamCloseErrorCode::ProtocolError,
|
||||||
|
"Expecting CONTINUATION frame for stream " +
|
||||||
|
std::to_string(expectngContinuationStreamId));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (std::holds_alternative<PushPromiseFrame>(frame))
|
||||||
|
{
|
||||||
|
LOG_TRACE << "Push promise frame received. Not supported yet. "
|
||||||
|
"Connection will die";
|
||||||
|
killConnection(streamId,
|
||||||
|
StreamCloseErrorCode::ProtocolError,
|
||||||
|
"Push promise not supported");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (streamId != 0)
|
if (streamId != 0)
|
||||||
{
|
{
|
||||||
handleFrameForStream(frame, streamId, flags);
|
handleFrameForStream(frame, streamId, flags);
|
||||||
@ -878,6 +1085,14 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
{
|
{
|
||||||
auto &f = std::get<WindowUpdateFrame>(frame);
|
auto &f = std::get<WindowUpdateFrame>(frame);
|
||||||
avaliableTxWindow += f.windowSizeIncrement;
|
avaliableTxWindow += f.windowSizeIncrement;
|
||||||
|
|
||||||
|
// HACK: Notify stream we have more window size available
|
||||||
|
auto it = pendingDataSend.begin();
|
||||||
|
if (it == pendingDataSend.end())
|
||||||
|
continue;
|
||||||
|
auto hackFrame = f;
|
||||||
|
hackFrame.windowSizeIncrement = 0;
|
||||||
|
handleFrameForStream(hackFrame, it->first, 0);
|
||||||
}
|
}
|
||||||
else if (std::holds_alternative<SettingsFrame>(frame))
|
else if (std::holds_alternative<SettingsFrame>(frame))
|
||||||
{
|
{
|
||||||
@ -888,7 +1103,7 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
{
|
{
|
||||||
hpackRx.setMaxTableSize(value);
|
hpackRx.setMaxTableSize(value);
|
||||||
}
|
}
|
||||||
if (key == (uint16_t)H2SettingsKey::MaxConcurrentStreams)
|
else if (key == (uint16_t)H2SettingsKey::MaxConcurrentStreams)
|
||||||
{
|
{
|
||||||
// Note: MAX_CONCURRENT_STREAMS can be 0, which means
|
// Note: MAX_CONCURRENT_STREAMS can be 0, which means
|
||||||
// the client is not allowed to send any request. I doubt
|
// the client is not allowed to send any request. I doubt
|
||||||
@ -927,27 +1142,6 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
LOG_TRACE << "Unsupported settings key: " << key;
|
LOG_TRACE << "Unsupported settings key: " << key;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!serverSettingsReceived)
|
|
||||||
{
|
|
||||||
LOG_TRACE
|
|
||||||
<< "Server settings received. Sending our own WindowUpdate";
|
|
||||||
|
|
||||||
WindowUpdateFrame windowUpdateFrame;
|
|
||||||
windowUpdateFrame.windowSizeIncrement = windowIncreaseSize;
|
|
||||||
connPtr->send(serializeFrame(windowUpdateFrame, 0));
|
|
||||||
avaliableRxWindow = initialWindowSize;
|
|
||||||
|
|
||||||
serverSettingsReceived = true;
|
|
||||||
while (!bufferedRequests.empty() &&
|
|
||||||
streams.size() < maxConcurrentStreams)
|
|
||||||
{
|
|
||||||
auto &[req, cb] = bufferedRequests.front();
|
|
||||||
sendRequestInLoop(req, std::move(cb));
|
|
||||||
bufferedRequests.pop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Somehow nghttp2 wants us to send ACK after sending our
|
// Somehow nghttp2 wants us to send ACK after sending our
|
||||||
// preferences??
|
// preferences??
|
||||||
if (flags == 1)
|
if (flags == 1)
|
||||||
@ -961,12 +1155,30 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
{
|
{
|
||||||
// Should never show up on stream 0
|
// Should never show up on stream 0
|
||||||
LOG_FATAL << "Protocol error: HEADERS frame on stream 0";
|
LOG_FATAL << "Protocol error: HEADERS frame on stream 0";
|
||||||
errorCallback(ReqResult::BadResponse);
|
killConnection(streamId,
|
||||||
|
StreamCloseErrorCode::ProtocolError,
|
||||||
|
"HEADERS frame on stream 0");
|
||||||
}
|
}
|
||||||
else if (std::holds_alternative<DataFrame>(frame))
|
else if (std::holds_alternative<DataFrame>(frame))
|
||||||
{
|
{
|
||||||
LOG_FATAL << "Protocol error: DATA frame on stream 0";
|
LOG_FATAL << "Protocol error: DATA frame on stream 0";
|
||||||
errorCallback(ReqResult::BadResponse);
|
killConnection(streamId,
|
||||||
|
StreamCloseErrorCode::ProtocolError,
|
||||||
|
"DATA frame on stream 0");
|
||||||
|
}
|
||||||
|
else if (std::holds_alternative<ContinuationFrame>(frame))
|
||||||
|
{
|
||||||
|
LOG_FATAL << "Protocol error: CONTINUATION frame on stream 0";
|
||||||
|
killConnection(streamId,
|
||||||
|
StreamCloseErrorCode::ProtocolError,
|
||||||
|
"CONTINUATION frame on stream 0");
|
||||||
|
}
|
||||||
|
else if (std::holds_alternative<RstStreamFrame>(frame))
|
||||||
|
{
|
||||||
|
LOG_FATAL << "Protocol error: RST_STREAM frame on stream 0";
|
||||||
|
killConnection(streamId,
|
||||||
|
StreamCloseErrorCode::ProtocolError,
|
||||||
|
"RST_STREAM frame on stream 0");
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@ -975,6 +1187,65 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Http2Transport::parseAndApplyHeaders(internal::H2Stream &stream,
|
||||||
|
const void *data,
|
||||||
|
size_t size)
|
||||||
|
{
|
||||||
|
hpack::HPacker::KeyValueVector headers;
|
||||||
|
int n = hpackRx.decode((const uint8_t *)data, size, headers);
|
||||||
|
auto streamId = stream.streamId;
|
||||||
|
if (n < 0)
|
||||||
|
{
|
||||||
|
LOG_TRACE << "Failed to decode headers";
|
||||||
|
killConnection(streamId,
|
||||||
|
StreamCloseErrorCode::CompressionError,
|
||||||
|
"Failed to decode headers");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto &[key, value] : headers)
|
||||||
|
LOG_TRACE << " " << key << ": " << value;
|
||||||
|
assert(stream.response == nullptr);
|
||||||
|
stream.response = std::make_shared<HttpResponseImpl>();
|
||||||
|
for (const auto &[key, value] : headers)
|
||||||
|
{
|
||||||
|
// TODO: Filter more pseudo headers
|
||||||
|
if (key == "content-length")
|
||||||
|
{
|
||||||
|
auto sz = stosz(value);
|
||||||
|
if (!sz)
|
||||||
|
{
|
||||||
|
responseErrored(streamId, ReqResult::BadResponse);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
stream.contentLength = std::move(sz);
|
||||||
|
}
|
||||||
|
if (key == ":status")
|
||||||
|
{
|
||||||
|
auto status = stosz(value);
|
||||||
|
if (!status)
|
||||||
|
{
|
||||||
|
responseErrored(streamId, ReqResult::BadResponse);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// TODO: Validate status code
|
||||||
|
stream.response->setStatusCode((HttpStatusCode)*status);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anti request smuggling. We look for \r or \n in the header
|
||||||
|
// name or value. If we find one, we abort the stream.
|
||||||
|
if (key.find_first_of("\r\n") != std::string::npos ||
|
||||||
|
value.find_first_of("\r\n") != std::string::npos)
|
||||||
|
{
|
||||||
|
responseErrored(streamId, ReqResult::BadResponse);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.response->addHeader(key, value);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
Http2Transport::Http2Transport(trantor::TcpConnectionPtr connPtr,
|
Http2Transport::Http2Transport(trantor::TcpConnectionPtr connPtr,
|
||||||
size_t *bytesSent,
|
size_t *bytesSent,
|
||||||
size_t *bytesReceived)
|
size_t *bytesReceived)
|
||||||
@ -1011,6 +1282,16 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
|
|||||||
|
|
||||||
if (std::holds_alternative<HeadersFrame>(frame))
|
if (std::holds_alternative<HeadersFrame>(frame))
|
||||||
{
|
{
|
||||||
|
if ((flags & (uint8_t)H2HeadersFlags::EndHeaders) == 0)
|
||||||
|
{
|
||||||
|
auto &f = std::get<HeadersFrame>(frame);
|
||||||
|
headerBufferRx.append((char *)f.headerBlockFragment.data(),
|
||||||
|
f.headerBlockFragment.size());
|
||||||
|
expectngContinuationStreamId = streamId;
|
||||||
|
stream.state = StreamState::ExpectingContinuation;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (stream.state != StreamState::ExpectingHeaders)
|
if (stream.state != StreamState::ExpectingHeaders)
|
||||||
{
|
{
|
||||||
killConnection(streamId,
|
killConnection(streamId,
|
||||||
@ -1019,79 +1300,51 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto &f = std::get<HeadersFrame>(frame);
|
auto &f = std::get<HeadersFrame>(frame);
|
||||||
LOG_TRACE << "Headers frame received: size="
|
// This function handles error itself
|
||||||
<< f.headerBlockFragment.size();
|
if (!parseAndApplyHeaders(stream,
|
||||||
hpack::HPacker::KeyValueVector headers;
|
f.headerBlockFragment.data(),
|
||||||
int n = hpackRx.decode(f.headerBlockFragment.data(),
|
f.headerBlockFragment.size()))
|
||||||
f.headerBlockFragment.size(),
|
|
||||||
headers);
|
|
||||||
if (n < 0)
|
|
||||||
{
|
|
||||||
LOG_TRACE << "Failed to decode headers";
|
|
||||||
killConnection(streamId,
|
|
||||||
StreamCloseErrorCode::CompressionError,
|
|
||||||
"Failed to decode headers");
|
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
for (auto &[key, value] : headers)
|
|
||||||
LOG_TRACE << " " << key << ": " << value;
|
|
||||||
it->second.response = std::make_shared<HttpResponseImpl>();
|
|
||||||
for (const auto &[key, value] : headers)
|
|
||||||
{
|
|
||||||
// TODO: Filter more pseudo headers
|
|
||||||
if (key == "content-length")
|
|
||||||
{
|
|
||||||
auto sz = stosz(value);
|
|
||||||
if (!sz)
|
|
||||||
{
|
|
||||||
killConnection(streamId,
|
|
||||||
StreamCloseErrorCode::ProtocolError,
|
|
||||||
"Invalid content-length");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
it->second.contentLength = std::move(sz);
|
|
||||||
}
|
|
||||||
if (key == ":status")
|
|
||||||
{
|
|
||||||
// TODO: Validate status code
|
|
||||||
it->second.response->setStatusCode(
|
|
||||||
(drogon::HttpStatusCode)std::stoi(value));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Anti request smuggling. We look for \r or \n in the header
|
|
||||||
// name or value. If we find one, we abort the stream.
|
|
||||||
if (key.find_first_of("\r\n") != std::string::npos ||
|
|
||||||
value.find_first_of("\r\n") != std::string::npos)
|
|
||||||
{
|
|
||||||
streamFinished(streamId,
|
|
||||||
ReqResult::BadResponse,
|
|
||||||
StreamCloseErrorCode::ProtocolError);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
it->second.response->addHeader(key, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((flags & (uint8_t)H2HeadersFlags::EndHeaders) == 0)
|
|
||||||
{
|
|
||||||
LOG_TRACE << "We don't support CONTINUATION frames yet!";
|
|
||||||
stream.state = StreamState::ExpectingContinuation;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
// There is no body in the response.
|
// There is no body in the response.
|
||||||
if ((flags & (uint8_t)H2HeadersFlags::EndStream))
|
if ((flags & (uint8_t)H2HeadersFlags::EndStream))
|
||||||
{
|
{
|
||||||
stream.state = StreamState::Finished;
|
stream.state = StreamState::Finished;
|
||||||
streamFinished(stream);
|
responseSuccess(stream);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
stream.state = StreamState::ExpectingData;
|
stream.state = StreamState::ExpectingData;
|
||||||
}
|
}
|
||||||
|
else if (std::holds_alternative<ContinuationFrame>(frame))
|
||||||
|
{
|
||||||
|
auto &f = std::get<ContinuationFrame>(frame);
|
||||||
|
if (stream.state != StreamState::ExpectingContinuation)
|
||||||
|
{
|
||||||
|
killConnection(streamId,
|
||||||
|
StreamCloseErrorCode::ProtocolError,
|
||||||
|
"Unexpected continuation frame");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
headerBufferRx.append((char *)f.headerBlockFragment.data(),
|
||||||
|
f.headerBlockFragment.size());
|
||||||
|
bool endHeaders = (flags & (uint8_t)H2HeadersFlags::EndHeaders) != 0;
|
||||||
|
if (endHeaders)
|
||||||
|
{
|
||||||
|
stream.state = StreamState::ExpectingData;
|
||||||
|
expectngContinuationStreamId = 0;
|
||||||
|
bool ok = parseAndApplyHeaders(stream,
|
||||||
|
headerBufferRx.peek(),
|
||||||
|
headerBufferRx.readableBytes());
|
||||||
|
headerBufferRx.retrieveAll();
|
||||||
|
if (!ok)
|
||||||
|
LOG_TRACE << "Failed to parse headers in continuation frame";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
else if (std::holds_alternative<DataFrame>(frame))
|
else if (std::holds_alternative<DataFrame>(frame))
|
||||||
{
|
{
|
||||||
auto &f = std::get<DataFrame>(frame);
|
auto &f = std::get<DataFrame>(frame);
|
||||||
// TODO: Make sure this logic fits RFC
|
|
||||||
if (avaliableRxWindow < f.data.size())
|
if (avaliableRxWindow < f.data.size())
|
||||||
{
|
{
|
||||||
killConnection(streamId,
|
killConnection(streamId,
|
||||||
@ -1126,15 +1379,13 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
|
|||||||
stream.body.readableBytes() != *stream.contentLength)
|
stream.body.readableBytes() != *stream.contentLength)
|
||||||
{
|
{
|
||||||
LOG_TRACE << "Content-length mismatch";
|
LOG_TRACE << "Content-length mismatch";
|
||||||
streamFinished(streamId,
|
responseErrored(streamId, ReqResult::BadResponse);
|
||||||
ReqResult::BadResponse,
|
|
||||||
StreamCloseErrorCode::ProtocolError);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// TODO: Optmize setting body
|
// TODO: Optmize setting body
|
||||||
std::string body(stream.body.peek(), stream.body.readableBytes());
|
std::string body(stream.body.peek(), stream.body.readableBytes());
|
||||||
stream.response->setBody(std::move(body));
|
stream.response->setBody(std::move(body));
|
||||||
streamFinished(stream);
|
responseSuccess(stream);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1150,6 +1401,55 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
|
|||||||
{
|
{
|
||||||
auto &f = std::get<WindowUpdateFrame>(frame);
|
auto &f = std::get<WindowUpdateFrame>(frame);
|
||||||
stream.avaliableTxWindow += f.windowSizeIncrement;
|
stream.avaliableTxWindow += f.windowSizeIncrement;
|
||||||
|
if (avaliableTxWindow == 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto it = pendingDataSend.find(streamId);
|
||||||
|
if (it == pendingDataSend.end())
|
||||||
|
return;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
size_t sendOffset = it->second;
|
||||||
|
assert(stream.request != nullptr);
|
||||||
|
assert(stream.request->body().length() > sendOffset);
|
||||||
|
size_t maxSendSize = stream.request->body().length() - sendOffset;
|
||||||
|
maxSendSize = (std::min)(maxSendSize, stream.avaliableTxWindow);
|
||||||
|
maxSendSize = (std::min)(maxSendSize, avaliableTxWindow);
|
||||||
|
bool sendEverything =
|
||||||
|
maxSendSize == stream.request->body().length() - sendOffset;
|
||||||
|
for (i = 0; i < maxSendSize; i += maxFrameSize)
|
||||||
|
{
|
||||||
|
size_t readSize =
|
||||||
|
(std::min)(maxFrameSize,
|
||||||
|
stream.request->body().length() - sendOffset - i);
|
||||||
|
std::vector<uint8_t> buffer;
|
||||||
|
buffer.resize(readSize);
|
||||||
|
memcpy(buffer.data(),
|
||||||
|
stream.request->body().data() + sendOffset + i,
|
||||||
|
readSize);
|
||||||
|
DataFrame dataFrame;
|
||||||
|
dataFrame.data = std::move(buffer);
|
||||||
|
dataFrame.endStream =
|
||||||
|
(i + maxFrameSize >=
|
||||||
|
stream.request->body().length() - sendOffset);
|
||||||
|
LOG_TRACE << "Sending data frame: size=" << dataFrame.data.size()
|
||||||
|
<< " endStream=" << dataFrame.endStream;
|
||||||
|
connPtr->send(serializeFrame(dataFrame, streamId));
|
||||||
|
|
||||||
|
stream.avaliableTxWindow -= dataFrame.data.size();
|
||||||
|
avaliableTxWindow -= dataFrame.data.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sendEverything)
|
||||||
|
pendingDataSend.erase(it);
|
||||||
|
else
|
||||||
|
it->second = sendOffset + i;
|
||||||
|
}
|
||||||
|
else if (std::holds_alternative<RstStreamFrame>(frame))
|
||||||
|
{
|
||||||
|
auto &f = std::get<RstStreamFrame>(frame);
|
||||||
|
LOG_TRACE << "RST_STREAM frame received: errorCode=" << f.errorCode;
|
||||||
|
responseErrored(streamId, ReqResult::BadResponse);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@ -1170,7 +1470,7 @@ internal::H2Stream &Http2Transport::createStream(int32_t streamId)
|
|||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Http2Transport::streamFinished(internal::H2Stream &stream)
|
void Http2Transport::responseSuccess(internal::H2Stream &stream)
|
||||||
{
|
{
|
||||||
assert(stream.request != nullptr);
|
assert(stream.request != nullptr);
|
||||||
assert(stream.callback);
|
assert(stream.callback);
|
||||||
@ -1182,6 +1482,10 @@ void Http2Transport::streamFinished(internal::H2Stream &stream)
|
|||||||
respCallback(stream.response, {stream.request, stream.callback}, connPtr);
|
respCallback(stream.response, {stream.request, stream.callback}, connPtr);
|
||||||
streams.erase(it);
|
streams.erase(it);
|
||||||
|
|
||||||
|
auto it2 = pendingDataSend.find(streamId);
|
||||||
|
if (it2 != pendingDataSend.end())
|
||||||
|
pendingDataSend.erase(it2);
|
||||||
|
|
||||||
if (bufferedRequests.empty())
|
if (bufferedRequests.empty())
|
||||||
return;
|
return;
|
||||||
auto &[req, cb] = bufferedRequests.front();
|
auto &[req, cb] = bufferedRequests.front();
|
||||||
@ -1189,15 +1493,17 @@ void Http2Transport::streamFinished(internal::H2Stream &stream)
|
|||||||
bufferedRequests.pop();
|
bufferedRequests.pop();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Http2Transport::streamFinished(int32_t streamId,
|
void Http2Transport::responseErrored(int32_t streamId, ReqResult result)
|
||||||
ReqResult result,
|
|
||||||
StreamCloseErrorCode errorCode)
|
|
||||||
{
|
{
|
||||||
auto it = streams.find(streamId);
|
auto it = streams.find(streamId);
|
||||||
assert(it != streams.end());
|
assert(it != streams.end());
|
||||||
it->second.callback(result, nullptr);
|
it->second.callback(result, nullptr);
|
||||||
streams.erase(it);
|
streams.erase(it);
|
||||||
|
|
||||||
|
auto it2 = pendingDataSend.find(streamId);
|
||||||
|
if (it2 != pendingDataSend.end())
|
||||||
|
pendingDataSend.erase(it2);
|
||||||
|
|
||||||
if (bufferedRequests.empty())
|
if (bufferedRequests.empty())
|
||||||
return;
|
return;
|
||||||
auto &[req, cb] = bufferedRequests.front();
|
auto &[req, cb] = bufferedRequests.front();
|
||||||
@ -1218,6 +1524,8 @@ void Http2Transport::onError(ReqResult result)
|
|||||||
cb(result, nullptr);
|
cb(result, nullptr);
|
||||||
bufferedRequests.pop();
|
bufferedRequests.pop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pendingDataSend.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Http2Transport::killConnection(int32_t lastStreamId,
|
void Http2Transport::killConnection(int32_t lastStreamId,
|
||||||
@ -1239,5 +1547,7 @@ bool Http2Transport::handleConnectionClose()
|
|||||||
return false;
|
return false;
|
||||||
for (auto &[streamId, stream] : streams)
|
for (auto &[streamId, stream] : streams)
|
||||||
stream.callback(ReqResult::BadResponse, nullptr);
|
stream.callback(ReqResult::BadResponse, nullptr);
|
||||||
|
streams.clear();
|
||||||
|
pendingDataSend.clear();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
@ -78,12 +78,46 @@ struct PingFrame
|
|||||||
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ContinuationFrame
|
||||||
|
{
|
||||||
|
std::vector<uint8_t> headerBlockFragment;
|
||||||
|
bool endHeaders = false;
|
||||||
|
|
||||||
|
static std::optional<ContinuationFrame> parse(ByteStream &payload,
|
||||||
|
uint8_t flags);
|
||||||
|
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RstStreamFrame
|
||||||
|
{
|
||||||
|
uint32_t errorCode = 0;
|
||||||
|
|
||||||
|
static std::optional<RstStreamFrame> parse(ByteStream &payload,
|
||||||
|
uint8_t flags);
|
||||||
|
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PushPromiseFrame
|
||||||
|
{
|
||||||
|
uint8_t padLength = 0;
|
||||||
|
bool endHeaders = false;
|
||||||
|
int32_t promisedStreamId = 0;
|
||||||
|
std::vector<uint8_t> headerBlockFragment;
|
||||||
|
|
||||||
|
static std::optional<PushPromiseFrame> parse(ByteStream &payload,
|
||||||
|
uint8_t flags);
|
||||||
|
bool serialize(OByteStream &stream, uint8_t &flags) const;
|
||||||
|
};
|
||||||
|
|
||||||
using H2Frame = std::variant<SettingsFrame,
|
using H2Frame = std::variant<SettingsFrame,
|
||||||
WindowUpdateFrame,
|
WindowUpdateFrame,
|
||||||
HeadersFrame,
|
HeadersFrame,
|
||||||
GoAwayFrame,
|
GoAwayFrame,
|
||||||
DataFrame,
|
DataFrame,
|
||||||
PingFrame>;
|
PingFrame,
|
||||||
|
ContinuationFrame,
|
||||||
|
PushPromiseFrame,
|
||||||
|
RstStreamFrame>;
|
||||||
|
|
||||||
enum class StreamState
|
enum class StreamState
|
||||||
{
|
{
|
||||||
@ -138,11 +172,13 @@ class Http2Transport : public HttpTransport
|
|||||||
hpack::HPacker hpackTx;
|
hpack::HPacker hpackTx;
|
||||||
hpack::HPacker hpackRx;
|
hpack::HPacker hpackRx;
|
||||||
|
|
||||||
std::priority_queue<int32_t> usibleStreamIds;
|
|
||||||
int32_t currentStreamId = 1;
|
int32_t currentStreamId = 1;
|
||||||
std::unordered_map<int32_t, internal::H2Stream> streams;
|
std::unordered_map<int32_t, internal::H2Stream> streams;
|
||||||
bool serverSettingsReceived = false;
|
|
||||||
std::queue<std::pair<HttpRequestPtr, HttpReqCallback>> bufferedRequests;
|
std::queue<std::pair<HttpRequestPtr, HttpReqCallback>> bufferedRequests;
|
||||||
|
trantor::MsgBuffer headerBufferRx;
|
||||||
|
int32_t expectngContinuationStreamId = 0;
|
||||||
|
|
||||||
|
std::unordered_map<int32_t, size_t> pendingDataSend;
|
||||||
// TODO: Handle server-initiated stream creation
|
// TODO: Handle server-initiated stream creation
|
||||||
|
|
||||||
// HTTP/2 client-wide settings (can be changed by server)
|
// HTTP/2 client-wide settings (can be changed by server)
|
||||||
@ -160,10 +196,8 @@ class Http2Transport : public HttpTransport
|
|||||||
size_t avaliableRxWindow = 65535;
|
size_t avaliableRxWindow = 65535;
|
||||||
|
|
||||||
internal::H2Stream &createStream(int32_t streamId);
|
internal::H2Stream &createStream(int32_t streamId);
|
||||||
void streamFinished(internal::H2Stream &stream);
|
void responseSuccess(internal::H2Stream &stream);
|
||||||
void streamFinished(int32_t streamId,
|
void responseErrored(int32_t streamId, ReqResult result);
|
||||||
ReqResult result,
|
|
||||||
StreamCloseErrorCode errorCode);
|
|
||||||
|
|
||||||
int32_t nextStreamId()
|
int32_t nextStreamId()
|
||||||
{
|
{
|
||||||
@ -181,6 +215,10 @@ class Http2Transport : public HttpTransport
|
|||||||
StreamCloseErrorCode errorCode,
|
StreamCloseErrorCode errorCode,
|
||||||
std::string errorMsg = "");
|
std::string errorMsg = "");
|
||||||
|
|
||||||
|
bool parseAndApplyHeaders(internal::H2Stream &stream,
|
||||||
|
const void *data,
|
||||||
|
size_t len);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Http2Transport(trantor::TcpConnectionPtr connPtr,
|
Http2Transport(trantor::TcpConnectionPtr connPtr,
|
||||||
size_t *bytesSent,
|
size_t *bytesSent,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user