Compare commits

..

16 Commits

Author SHA1 Message Date
Martin Chang
52d0fdd25d format 2023-11-10 11:18:24 +08:00
Martin Chang
0eb2cdabe7 rfc compliant and overflow fix 2023-11-10 10:56:13 +08:00
Martin Chang
e12260a0b7 safer string parsing 2023-11-10 10:31:30 +08:00
Martin Chang
c82e8208ab remove unneeded delay 2023-11-10 10:25:13 +08:00
Martin Chang
e13a8b930f handle sending large body 2023-11-10 10:17:51 +08:00
Martin Chang
4a2eecf03d Handle sending CONTINUATION frame 2023-11-10 09:46:22 +08:00
Martin Chang
45a2b1d0d3 react to RST_STREAM 2023-11-10 09:36:04 +08:00
Martin Chang
d69699ceb2 add check for DATA, HEADERS and PUSH_PROMISE frames 2023-11-09 23:04:25 +08:00
Martin Chang
5f52a80358 able to parse RST_STREAM 2023-11-09 16:07:19 +08:00
Martin Chang
ab566b3524 kill connection when PUSH_PROMIS recv 2023-11-09 15:57:08 +08:00
Martin Chang
e565f38d7a fix unable to post any data 2023-11-09 15:45:01 +08:00
Martin Chang
5d434577ff fix untrue error message 2023-11-09 15:42:37 +08:00
Martin Chang
2eef512537 parse incomming headers in continuation frame 2023-11-09 15:05:17 +08:00
Martin Chang
847b580cf5 wip: store headers if we need more frames 2023-11-09 11:52:55 +08:00
Martin Chang
83606eb5a6 slight optimization 2023-11-09 11:24:22 +08:00
Martin Chang
f5c4863ad0 Sementics cleanup 2023-11-09 11:22:48 +08:00
3 changed files with 498 additions and 150 deletions

View File

@ -16,7 +16,7 @@ int main()
{
trantor::Logger::setLogLevel(trantor::Logger::kTrace);
{
auto client = HttpClient::newHttpClient("https://clehaxze.tw",
auto client = HttpClient::newHttpClient("https://clehaxze.tw:8844",
nullptr,
false,
false);

View File

@ -18,7 +18,11 @@ static std::optional<size_t> stosz(const std::string &str)
{
try
{
return std::stoull(str);
size_t idx = 0;
size_t res = std::stoull(str, &idx);
if (idx != str.size())
return std::nullopt;
return res;
}
catch (const std::exception &)
{
@ -107,7 +111,7 @@ struct ByteStream
uint32_t readU24BE()
{
assert(offset <= length - 3);
assert(length >= 3 && offset <= length - 3);
uint32_t res =
ptr[offset] << 16 | ptr[offset + 1] << 8 | ptr[offset + 2];
offset += 3;
@ -116,7 +120,7 @@ struct ByteStream
uint32_t readU32BE()
{
assert(offset <= length - 4);
assert(length >= 4 && offset <= length - 4);
uint32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
ptr[offset + 2] << 8 | ptr[offset + 3];
offset += 4;
@ -125,7 +129,7 @@ struct ByteStream
std::pair<bool, int32_t> readBI32BE()
{
assert(offset <= length - 4);
assert(length >= 4 && offset <= length - 4);
int32_t res = ptr[offset] << 24 | ptr[offset + 1] << 16 |
ptr[offset + 2] << 8 | ptr[offset + 3];
offset += 4;
@ -137,7 +141,7 @@ struct ByteStream
uint16_t readU16BE()
{
assert(offset <= length - 2);
assert(length >= 2 && offset <= length - 2);
uint16_t res = ptr[offset] << 8 | ptr[offset + 1];
offset += 2;
return res;
@ -145,13 +149,13 @@ struct ByteStream
uint8_t readU8()
{
assert(offset <= length - 1);
assert(length >= 1 && offset <= length - 1);
return ptr[offset++];
}
void read(uint8_t *buffer, size_t size)
{
assert(offset <= length - size || size == 0);
assert((length >= size && offset <= length - size) || size == 0);
memcpy(buffer, ptr + offset, size);
offset += size;
}
@ -171,7 +175,7 @@ struct ByteStream
void skip(size_t n)
{
assert(offset <= length - n || n == 0);
assert((length >= n && offset <= length - n) || n == 0);
offset += n;
}
@ -227,6 +231,7 @@ struct OByteStream
{
assert(value <= 0xffffff);
assert(offset <= buffer.readableBytes() - 3);
assert(buffer.writableBytes() >= 3);
auto ptr = (uint8_t *)buffer.peek() + offset;
ptr[0] = value >> 16;
ptr[1] = value >> 8;
@ -236,6 +241,7 @@ struct OByteStream
void overwriteU8(size_t offset, uint8_t value)
{
assert(offset <= buffer.readableBytes() - 1);
assert(buffer.writableBytes() >= 1);
auto ptr = (uint8_t *)buffer.peek() + offset;
ptr[0] = value;
}
@ -325,11 +331,25 @@ std::optional<HeadersFrame> HeadersFrame::parse(ByteStream &payload,
bool padded = flags & (uint8_t)H2HeadersFlags::Padded;
bool priority = flags & (uint8_t)H2HeadersFlags::Priority;
if (payload.size() == 0)
{
LOG_TRACE << "Header size cannot be 0";
return std::nullopt;
}
HeadersFrame frame;
if (padded)
{
frame.padLength = payload.readU8();
}
size_t minSize = frame.padLength + (priority ? 5 : 0);
if (payload.size() < minSize)
{
LOG_TRACE << "Invalid headers frame length";
return std::nullopt;
}
if (priority)
{
auto [exclusive, streamDependency] = payload.readBI32BE();
@ -346,6 +366,7 @@ std::optional<HeadersFrame> HeadersFrame::parse(ByteStream &payload,
frame.endStream = true;
}
assert(payload.remaining() >= frame.padLength);
int64_t payloadSize = payload.remaining() - frame.padLength;
if (payloadSize < 0)
{
@ -421,11 +442,20 @@ std::optional<DataFrame> DataFrame::parse(ByteStream &payload, uint8_t flags)
{
frame.padLength = payload.readU8();
}
size_t minSize = frame.padLength;
if (payload.size() < minSize)
{
LOG_TRACE << "Invalid data frame length";
return std::nullopt;
}
if (endStream)
{
frame.endStream = true;
}
assert(payload.remaining() >= frame.padLength);
size_t payloadSize = payload.remaining() - frame.padLength;
if (payloadSize < 0)
{
@ -447,7 +477,7 @@ std::optional<DataFrame> DataFrame::parse(ByteStream &payload, uint8_t flags)
bool DataFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
flags = (endStream ? (uint8_t)H2DataFlags::EndStream : 0x0);
stream.write(data.data(), data.size());
if (padLength > 0)
{
@ -477,6 +507,95 @@ bool PingFrame::serialize(OByteStream &stream, uint8_t &flags) const
return true;
}
std::optional<ContinuationFrame> ContinuationFrame::parse(ByteStream &payload,
uint8_t flags)
{
bool endHeaders = flags & (uint8_t)H2HeadersFlags::EndHeaders;
ContinuationFrame frame;
if (endHeaders)
{
frame.endHeaders = true;
}
frame.headerBlockFragment.resize(payload.remaining());
payload.read(frame.headerBlockFragment.data(),
frame.headerBlockFragment.size());
return frame;
}
bool ContinuationFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
if (endHeaders)
flags |= (uint8_t)H2HeadersFlags::EndHeaders;
stream.write(headerBlockFragment.data(), headerBlockFragment.size());
return true;
}
std::optional<PushPromiseFrame> PushPromiseFrame::parse(ByteStream &payload,
uint8_t flags)
{
bool endHeaders = flags & (uint8_t)H2HeadersFlags::EndHeaders;
bool padded = flags & (uint8_t)H2HeadersFlags::Padded;
PushPromiseFrame frame;
if (padded)
frame.padLength = payload.readU8();
if (endHeaders)
frame.endHeaders = true;
size_t minSize = frame.padLength + 4;
if (payload.size() < minSize)
{
LOG_TRACE << "Invalid push promise frame length";
return std::nullopt;
}
auto [_, promisedStreamId] = payload.readBI32BE();
frame.promisedStreamId = promisedStreamId;
assert(payload.remaining() >= frame.padLength);
frame.headerBlockFragment.resize(payload.remaining() - frame.padLength);
payload.read(frame.headerBlockFragment.data(),
frame.headerBlockFragment.size());
payload.skip(frame.padLength);
return frame;
}
bool PushPromiseFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
if (endHeaders)
flags |= (uint8_t)H2HeadersFlags::EndHeaders;
assert(promisedStreamId > 0);
stream.writeU32BE(promisedStreamId);
stream.write(headerBlockFragment.data(), headerBlockFragment.size());
if (padLength > 0)
{
flags |= (uint8_t)H2HeadersFlags::Padded;
stream.writeU8(padLength);
}
return true;
}
std::optional<RstStreamFrame> RstStreamFrame::parse(ByteStream &payload,
uint8_t flags)
{
if (payload.size() != 4)
{
LOG_TRACE << "Invalid RST_STREAM frame length";
return std::nullopt;
}
RstStreamFrame frame;
frame.errorCode = payload.readU32BE();
return frame;
}
bool RstStreamFrame::serialize(OByteStream &stream, uint8_t &flags) const
{
flags = 0x0;
stream.writeU32BE(errorCode);
return true;
}
} // namespace drogon::internal
// Print the HEX and ASCII representation of the buffer side by side
@ -568,6 +687,24 @@ static trantor::MsgBuffer serializeFrame(const H2Frame &frame, int32_t streamId)
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::Ping;
}
else if (std::holds_alternative<ContinuationFrame>(frame))
{
const auto &f = std::get<ContinuationFrame>(frame);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::Continuation;
}
else if (std::holds_alternative<PushPromiseFrame>(frame))
{
const auto &f = std::get<PushPromiseFrame>(frame);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::PushPromise;
}
else if (std::holds_alternative<RstStreamFrame>(frame))
{
const auto &f = std::get<RstStreamFrame>(frame);
ok = f.serialize(buffer, flags);
type = (uint8_t)H2FrameType::RstStream;
}
else
{
LOG_ERROR << "Unsupported frame type";
@ -646,6 +783,12 @@ static std::tuple<std::optional<H2Frame>, uint32_t, uint8_t, bool> parseH2Frame(
frame = DataFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::Ping)
frame = PingFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::Continuation)
frame = ContinuationFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::PushPromise)
frame = PushPromiseFrame::parse(payload, flags);
else if (type == (uint8_t)H2FrameType::RstStream)
frame = RstStreamFrame::parse(payload, flags);
else
{
LOG_WARN << "Unsupported H2 frame type: " << (int)type;
@ -670,13 +813,7 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
HttpReqCallback &&callback)
{
connPtr->getLoop()->assertInLoopThread();
if (!serverSettingsReceived)
{
bufferedRequests.push({req, std::move(callback)});
return;
}
if (streams.size() >= maxConcurrentStreams)
if (streams.size() + 1 >= maxConcurrentStreams)
{
LOG_TRACE << "Too many streams in flight. Buffering request";
bufferedRequests.push({req, std::move(callback)});
@ -697,15 +834,11 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
}
auto headers = req->headers();
HeadersFrame frame;
frame.padLength = 0;
frame.exclusive = false;
frame.streamDependency = 0;
frame.weight = 0;
frame.headerBlockFragment.resize(maxCompressiedHeaderSize);
std::vector<uint8_t> encodedHeaders(maxCompressiedHeaderSize);
LOG_TRACE << "Sending HTTP/2 headers: size=" << headers.size();
hpack::HPacker::KeyValueVector headersToEncode;
headersToEncode.reserve(headers.size() + 5);
const std::array<std::string_view, 2> headersToSkip = {
{"host", "connection"}};
headersToEncode.emplace_back(":method", req->methodString());
@ -726,8 +859,8 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
for (auto &[key, value] : headersToEncode)
LOG_TRACE << " " << key << ": " << value;
int n = hpackTx.encode(headersToEncode,
frame.headerBlockFragment.data(),
frame.headerBlockFragment.size());
encodedHeaders.data(),
encodedHeaders.size());
if (n < 0)
{
LOG_TRACE << "Failed to encode headers. Internal error or header "
@ -742,17 +875,54 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
abort();
return;
}
frame.headerBlockFragment.resize(n);
frame.endHeaders = true;
encodedHeaders.resize(n);
LOG_TRACE << "Encoded headers size: " << encodedHeaders.size();
bool haveBody = req->body().length() > 0;
auto &stream = createStream(streamId);
if (req->body().length() == 0)
frame.endStream = true;
LOG_TRACE << "Sending headers frame";
auto f = serializeFrame(frame, streamId);
LOG_TRACE << dump_hex_beautiful(f.peek(), f.readableBytes());
connPtr->send(f);
bool needsContinuation = encodedHeaders.size() > maxFrameSize;
for (size_t i = 0; i < encodedHeaders.size(); i += maxFrameSize)
{
bool isFirst = i == 0;
bool isLast = i + maxFrameSize >= encodedHeaders.size();
size_t dataSize = (std::min)(maxFrameSize, encodedHeaders.size() - i);
auto frame = [&]() -> H2Frame {
if (isFirst)
{
HeadersFrame frame;
frame.headerBlockFragment.resize(dataSize);
memcpy(frame.headerBlockFragment.data(),
encodedHeaders.data() + i,
dataSize);
frame.endHeaders = isLast;
frame.endStream = (!haveBody && isLast);
return frame;
}
ContinuationFrame frame;
frame.headerBlockFragment.resize(dataSize);
assert(encodedHeaders.size() > i + dataSize);
memcpy(frame.headerBlockFragment.data(),
encodedHeaders.data() + i,
dataSize);
frame.endHeaders = (!haveBody && isLast);
return frame;
}();
auto f = serializeFrame(frame, streamId);
LOG_TRACE << "Sending " << (isFirst ? "HEADERS" : "CONTINUATION")
<< " frame:";
LOG_TRACE << dump_hex_beautiful(f.peek(), f.readableBytes());
connPtr->send(serializeFrame(frame, streamId));
}
if (needsContinuation && !haveBody)
{
DataFrame frame;
frame.endStream = true;
connPtr->send(serializeFrame(frame, streamId));
return;
}
stream.callback = std::move(callback);
stream.request = req;
@ -763,29 +933,41 @@ void Http2Transport::sendRequestInLoop(const HttpRequestPtr &req,
return;
}
if (req->body().length() > stream.avaliableTxWindow)
{
LOG_ERROR << "HTTP/2 body too large to fit in INITIAL_WINDOW_SIZE. Not "
"supported yet.";
abort();
return;
}
size_t bodySize = req->body().length();
bool sendEverything =
bodySize <= stream.avaliableTxWindow && bodySize <= avaliableTxWindow;
size_t maxSendSize = bodySize;
maxSendSize = (std::min)(maxSendSize, stream.avaliableTxWindow);
maxSendSize = (std::min)(maxSendSize, avaliableTxWindow);
DataFrame dataFrame;
for (size_t i = 0; i < req->body().length(); i += maxFrameSize)
size_t i;
for (i = 0; i < maxSendSize; i += maxFrameSize)
{
size_t readSize = (std::min)(maxFrameSize, req->body().length() - i);
size_t readSize = (std::min)(maxFrameSize, bodySize - i);
std::vector<uint8_t> buffer;
buffer.resize(readSize);
memcpy(buffer.data(), req->body().data() + i, readSize);
dataFrame.data = std::move(buffer);
dataFrame.endStream = (i + maxFrameSize >= req->body().length());
dataFrame.endStream = (i + maxFrameSize >= bodySize);
LOG_TRACE << "Sending data frame: size=" << dataFrame.data.size()
<< " endStream=" << dataFrame.endStream;
connPtr->send(serializeFrame(dataFrame, streamId));
stream.avaliableTxWindow -= dataFrame.data.size();
avaliableRxWindow -= dataFrame.data.size();
avaliableTxWindow -= dataFrame.data.size();
}
if (!sendEverything)
{
auto it = pendingDataSend.find(streamId);
if (it != pendingDataSend.end())
{
LOG_FATAL << "Stream id already in use! This should not happen";
abort();
}
pendingDataSend.emplace(streamId, i);
}
}
@ -806,6 +988,9 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
avaliableRxWindow += windowIncreaseSize;
}
if (msg->readableBytes() == 0)
break;
// FIXME: The code cannot distinguish between a out-of-data and
// unsupported frame type. We need to fix this as it should be handled
// differently.
@ -825,6 +1010,7 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
}
auto &frame = *frameOpt;
// special case for PING and GOAWAY. These are all global frames
if (std::holds_alternative<GoAwayFrame>(frame))
{
auto &f = std::get<GoAwayFrame>(frame);
@ -840,18 +1026,13 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
{
if (streamId > f.lastStreamId)
{
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::RefusedStream);
responseErrored(streamId, ReqResult::BadResponse);
}
}
// TODO: Should be half-closed but transport doesn't support it yet
connPtr->shutdown();
}
// special case for PING frame. It is the only frame that is not
// associated with a stream
if (std::holds_alternative<PingFrame>(frame))
else if (std::holds_alternative<PingFrame>(frame))
{
auto &f = std::get<PingFrame>(frame);
if (f.ack)
@ -867,6 +1048,32 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
continue;
}
// If we are expecting a CONTINUATION frame, we should not receive
// HEADERS or CONTINUATION from other streams
if (expectngContinuationStreamId != 0 &&
(std::holds_alternative<HeadersFrame>(frame) ||
(std::holds_alternative<ContinuationFrame>(frame) &&
streamId != expectngContinuationStreamId)))
{
LOG_TRACE << "Protocol error: unexpected HEADERS or "
"CONTINUATION frame";
killConnection(streamId,
StreamCloseErrorCode::ProtocolError,
"Expecting CONTINUATION frame for stream " +
std::to_string(expectngContinuationStreamId));
return;
}
if (std::holds_alternative<PushPromiseFrame>(frame))
{
LOG_TRACE << "Push promise frame received. Not supported yet. "
"Connection will die";
killConnection(streamId,
StreamCloseErrorCode::ProtocolError,
"Push promise not supported");
return;
}
if (streamId != 0)
{
handleFrameForStream(frame, streamId, flags);
@ -878,6 +1085,14 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
{
auto &f = std::get<WindowUpdateFrame>(frame);
avaliableTxWindow += f.windowSizeIncrement;
// HACK: Notify stream we have more window size available
auto it = pendingDataSend.begin();
if (it == pendingDataSend.end())
continue;
auto hackFrame = f;
hackFrame.windowSizeIncrement = 0;
handleFrameForStream(hackFrame, it->first, 0);
}
else if (std::holds_alternative<SettingsFrame>(frame))
{
@ -888,7 +1103,7 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
{
hpackRx.setMaxTableSize(value);
}
if (key == (uint16_t)H2SettingsKey::MaxConcurrentStreams)
else if (key == (uint16_t)H2SettingsKey::MaxConcurrentStreams)
{
// Note: MAX_CONCURRENT_STREAMS can be 0, which means
// the client is not allowed to send any request. I doubt
@ -927,27 +1142,6 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
LOG_TRACE << "Unsupported settings key: " << key;
}
}
if (!serverSettingsReceived)
{
LOG_TRACE
<< "Server settings received. Sending our own WindowUpdate";
WindowUpdateFrame windowUpdateFrame;
windowUpdateFrame.windowSizeIncrement = windowIncreaseSize;
connPtr->send(serializeFrame(windowUpdateFrame, 0));
avaliableRxWindow = initialWindowSize;
serverSettingsReceived = true;
while (!bufferedRequests.empty() &&
streams.size() < maxConcurrentStreams)
{
auto &[req, cb] = bufferedRequests.front();
sendRequestInLoop(req, std::move(cb));
bufferedRequests.pop();
}
}
// Somehow nghttp2 wants us to send ACK after sending our
// preferences??
if (flags == 1)
@ -961,12 +1155,30 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
{
// Should never show up on stream 0
LOG_FATAL << "Protocol error: HEADERS frame on stream 0";
errorCallback(ReqResult::BadResponse);
killConnection(streamId,
StreamCloseErrorCode::ProtocolError,
"HEADERS frame on stream 0");
}
else if (std::holds_alternative<DataFrame>(frame))
{
LOG_FATAL << "Protocol error: DATA frame on stream 0";
errorCallback(ReqResult::BadResponse);
killConnection(streamId,
StreamCloseErrorCode::ProtocolError,
"DATA frame on stream 0");
}
else if (std::holds_alternative<ContinuationFrame>(frame))
{
LOG_FATAL << "Protocol error: CONTINUATION frame on stream 0";
killConnection(streamId,
StreamCloseErrorCode::ProtocolError,
"CONTINUATION frame on stream 0");
}
else if (std::holds_alternative<RstStreamFrame>(frame))
{
LOG_FATAL << "Protocol error: RST_STREAM frame on stream 0";
killConnection(streamId,
StreamCloseErrorCode::ProtocolError,
"RST_STREAM frame on stream 0");
}
else
{
@ -975,6 +1187,65 @@ void Http2Transport::onRecvMessage(const trantor::TcpConnectionPtr &,
}
}
bool Http2Transport::parseAndApplyHeaders(internal::H2Stream &stream,
const void *data,
size_t size)
{
hpack::HPacker::KeyValueVector headers;
int n = hpackRx.decode((const uint8_t *)data, size, headers);
auto streamId = stream.streamId;
if (n < 0)
{
LOG_TRACE << "Failed to decode headers";
killConnection(streamId,
StreamCloseErrorCode::CompressionError,
"Failed to decode headers");
return false;
}
for (auto &[key, value] : headers)
LOG_TRACE << " " << key << ": " << value;
assert(stream.response == nullptr);
stream.response = std::make_shared<HttpResponseImpl>();
for (const auto &[key, value] : headers)
{
// TODO: Filter more pseudo headers
if (key == "content-length")
{
auto sz = stosz(value);
if (!sz)
{
responseErrored(streamId, ReqResult::BadResponse);
return false;
}
stream.contentLength = std::move(sz);
}
if (key == ":status")
{
auto status = stosz(value);
if (!status)
{
responseErrored(streamId, ReqResult::BadResponse);
return false;
}
// TODO: Validate status code
stream.response->setStatusCode((HttpStatusCode)*status);
continue;
}
// Anti request smuggling. We look for \r or \n in the header
// name or value. If we find one, we abort the stream.
if (key.find_first_of("\r\n") != std::string::npos ||
value.find_first_of("\r\n") != std::string::npos)
{
responseErrored(streamId, ReqResult::BadResponse);
return false;
}
stream.response->addHeader(key, value);
}
return true;
}
Http2Transport::Http2Transport(trantor::TcpConnectionPtr connPtr,
size_t *bytesSent,
size_t *bytesReceived)
@ -1011,6 +1282,16 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
if (std::holds_alternative<HeadersFrame>(frame))
{
if ((flags & (uint8_t)H2HeadersFlags::EndHeaders) == 0)
{
auto &f = std::get<HeadersFrame>(frame);
headerBufferRx.append((char *)f.headerBlockFragment.data(),
f.headerBlockFragment.size());
expectngContinuationStreamId = streamId;
stream.state = StreamState::ExpectingContinuation;
return;
}
if (stream.state != StreamState::ExpectingHeaders)
{
killConnection(streamId,
@ -1019,79 +1300,51 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
return;
}
auto &f = std::get<HeadersFrame>(frame);
LOG_TRACE << "Headers frame received: size="
<< f.headerBlockFragment.size();
hpack::HPacker::KeyValueVector headers;
int n = hpackRx.decode(f.headerBlockFragment.data(),
f.headerBlockFragment.size(),
headers);
if (n < 0)
{
LOG_TRACE << "Failed to decode headers";
killConnection(streamId,
StreamCloseErrorCode::CompressionError,
"Failed to decode headers");
// This function handles error itself
if (!parseAndApplyHeaders(stream,
f.headerBlockFragment.data(),
f.headerBlockFragment.size()))
return;
}
for (auto &[key, value] : headers)
LOG_TRACE << " " << key << ": " << value;
it->second.response = std::make_shared<HttpResponseImpl>();
for (const auto &[key, value] : headers)
{
// TODO: Filter more pseudo headers
if (key == "content-length")
{
auto sz = stosz(value);
if (!sz)
{
killConnection(streamId,
StreamCloseErrorCode::ProtocolError,
"Invalid content-length");
return;
}
it->second.contentLength = std::move(sz);
}
if (key == ":status")
{
// TODO: Validate status code
it->second.response->setStatusCode(
(drogon::HttpStatusCode)std::stoi(value));
continue;
}
// Anti request smuggling. We look for \r or \n in the header
// name or value. If we find one, we abort the stream.
if (key.find_first_of("\r\n") != std::string::npos ||
value.find_first_of("\r\n") != std::string::npos)
{
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::ProtocolError);
return;
}
it->second.response->addHeader(key, value);
}
if ((flags & (uint8_t)H2HeadersFlags::EndHeaders) == 0)
{
LOG_TRACE << "We don't support CONTINUATION frames yet!";
stream.state = StreamState::ExpectingContinuation;
abort();
}
// There is no body in the response.
if ((flags & (uint8_t)H2HeadersFlags::EndStream))
{
stream.state = StreamState::Finished;
streamFinished(stream);
responseSuccess(stream);
return;
}
stream.state = StreamState::ExpectingData;
}
else if (std::holds_alternative<ContinuationFrame>(frame))
{
auto &f = std::get<ContinuationFrame>(frame);
if (stream.state != StreamState::ExpectingContinuation)
{
killConnection(streamId,
StreamCloseErrorCode::ProtocolError,
"Unexpected continuation frame");
return;
}
headerBufferRx.append((char *)f.headerBlockFragment.data(),
f.headerBlockFragment.size());
bool endHeaders = (flags & (uint8_t)H2HeadersFlags::EndHeaders) != 0;
if (endHeaders)
{
stream.state = StreamState::ExpectingData;
expectngContinuationStreamId = 0;
bool ok = parseAndApplyHeaders(stream,
headerBufferRx.peek(),
headerBufferRx.readableBytes());
headerBufferRx.retrieveAll();
if (!ok)
LOG_TRACE << "Failed to parse headers in continuation frame";
return;
}
}
else if (std::holds_alternative<DataFrame>(frame))
{
auto &f = std::get<DataFrame>(frame);
// TODO: Make sure this logic fits RFC
if (avaliableRxWindow < f.data.size())
{
killConnection(streamId,
@ -1126,15 +1379,13 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
stream.body.readableBytes() != *stream.contentLength)
{
LOG_TRACE << "Content-length mismatch";
streamFinished(streamId,
ReqResult::BadResponse,
StreamCloseErrorCode::ProtocolError);
responseErrored(streamId, ReqResult::BadResponse);
return;
}
// TODO: Optmize setting body
std::string body(stream.body.peek(), stream.body.readableBytes());
stream.response->setBody(std::move(body));
streamFinished(stream);
responseSuccess(stream);
return;
}
@ -1150,6 +1401,55 @@ void Http2Transport::handleFrameForStream(const internal::H2Frame &frame,
{
auto &f = std::get<WindowUpdateFrame>(frame);
stream.avaliableTxWindow += f.windowSizeIncrement;
if (avaliableTxWindow == 0)
return;
auto it = pendingDataSend.find(streamId);
if (it == pendingDataSend.end())
return;
size_t i = 0;
size_t sendOffset = it->second;
assert(stream.request != nullptr);
assert(stream.request->body().length() > sendOffset);
size_t maxSendSize = stream.request->body().length() - sendOffset;
maxSendSize = (std::min)(maxSendSize, stream.avaliableTxWindow);
maxSendSize = (std::min)(maxSendSize, avaliableTxWindow);
bool sendEverything =
maxSendSize == stream.request->body().length() - sendOffset;
for (i = 0; i < maxSendSize; i += maxFrameSize)
{
size_t readSize =
(std::min)(maxFrameSize,
stream.request->body().length() - sendOffset - i);
std::vector<uint8_t> buffer;
buffer.resize(readSize);
memcpy(buffer.data(),
stream.request->body().data() + sendOffset + i,
readSize);
DataFrame dataFrame;
dataFrame.data = std::move(buffer);
dataFrame.endStream =
(i + maxFrameSize >=
stream.request->body().length() - sendOffset);
LOG_TRACE << "Sending data frame: size=" << dataFrame.data.size()
<< " endStream=" << dataFrame.endStream;
connPtr->send(serializeFrame(dataFrame, streamId));
stream.avaliableTxWindow -= dataFrame.data.size();
avaliableTxWindow -= dataFrame.data.size();
}
if (sendEverything)
pendingDataSend.erase(it);
else
it->second = sendOffset + i;
}
else if (std::holds_alternative<RstStreamFrame>(frame))
{
auto &f = std::get<RstStreamFrame>(frame);
LOG_TRACE << "RST_STREAM frame received: errorCode=" << f.errorCode;
responseErrored(streamId, ReqResult::BadResponse);
}
else
{
@ -1170,7 +1470,7 @@ internal::H2Stream &Http2Transport::createStream(int32_t streamId)
return stream;
}
void Http2Transport::streamFinished(internal::H2Stream &stream)
void Http2Transport::responseSuccess(internal::H2Stream &stream)
{
assert(stream.request != nullptr);
assert(stream.callback);
@ -1182,6 +1482,10 @@ void Http2Transport::streamFinished(internal::H2Stream &stream)
respCallback(stream.response, {stream.request, stream.callback}, connPtr);
streams.erase(it);
auto it2 = pendingDataSend.find(streamId);
if (it2 != pendingDataSend.end())
pendingDataSend.erase(it2);
if (bufferedRequests.empty())
return;
auto &[req, cb] = bufferedRequests.front();
@ -1189,15 +1493,17 @@ void Http2Transport::streamFinished(internal::H2Stream &stream)
bufferedRequests.pop();
}
void Http2Transport::streamFinished(int32_t streamId,
ReqResult result,
StreamCloseErrorCode errorCode)
void Http2Transport::responseErrored(int32_t streamId, ReqResult result)
{
auto it = streams.find(streamId);
assert(it != streams.end());
it->second.callback(result, nullptr);
streams.erase(it);
auto it2 = pendingDataSend.find(streamId);
if (it2 != pendingDataSend.end())
pendingDataSend.erase(it2);
if (bufferedRequests.empty())
return;
auto &[req, cb] = bufferedRequests.front();
@ -1218,6 +1524,8 @@ void Http2Transport::onError(ReqResult result)
cb(result, nullptr);
bufferedRequests.pop();
}
pendingDataSend.clear();
}
void Http2Transport::killConnection(int32_t lastStreamId,
@ -1239,5 +1547,7 @@ bool Http2Transport::handleConnectionClose()
return false;
for (auto &[streamId, stream] : streams)
stream.callback(ReqResult::BadResponse, nullptr);
streams.clear();
pendingDataSend.clear();
return true;
}

View File

@ -78,12 +78,46 @@ struct PingFrame
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
struct ContinuationFrame
{
std::vector<uint8_t> headerBlockFragment;
bool endHeaders = false;
static std::optional<ContinuationFrame> parse(ByteStream &payload,
uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
struct RstStreamFrame
{
uint32_t errorCode = 0;
static std::optional<RstStreamFrame> parse(ByteStream &payload,
uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
struct PushPromiseFrame
{
uint8_t padLength = 0;
bool endHeaders = false;
int32_t promisedStreamId = 0;
std::vector<uint8_t> headerBlockFragment;
static std::optional<PushPromiseFrame> parse(ByteStream &payload,
uint8_t flags);
bool serialize(OByteStream &stream, uint8_t &flags) const;
};
using H2Frame = std::variant<SettingsFrame,
WindowUpdateFrame,
HeadersFrame,
GoAwayFrame,
DataFrame,
PingFrame>;
PingFrame,
ContinuationFrame,
PushPromiseFrame,
RstStreamFrame>;
enum class StreamState
{
@ -138,11 +172,13 @@ class Http2Transport : public HttpTransport
hpack::HPacker hpackTx;
hpack::HPacker hpackRx;
std::priority_queue<int32_t> usibleStreamIds;
int32_t currentStreamId = 1;
std::unordered_map<int32_t, internal::H2Stream> streams;
bool serverSettingsReceived = false;
std::queue<std::pair<HttpRequestPtr, HttpReqCallback>> bufferedRequests;
trantor::MsgBuffer headerBufferRx;
int32_t expectngContinuationStreamId = 0;
std::unordered_map<int32_t, size_t> pendingDataSend;
// TODO: Handle server-initiated stream creation
// HTTP/2 client-wide settings (can be changed by server)
@ -160,10 +196,8 @@ class Http2Transport : public HttpTransport
size_t avaliableRxWindow = 65535;
internal::H2Stream &createStream(int32_t streamId);
void streamFinished(internal::H2Stream &stream);
void streamFinished(int32_t streamId,
ReqResult result,
StreamCloseErrorCode errorCode);
void responseSuccess(internal::H2Stream &stream);
void responseErrored(int32_t streamId, ReqResult result);
int32_t nextStreamId()
{
@ -181,6 +215,10 @@ class Http2Transport : public HttpTransport
StreamCloseErrorCode errorCode,
std::string errorMsg = "");
bool parseAndApplyHeaders(internal::H2Stream &stream,
const void *data,
size_t len);
public:
Http2Transport(trantor::TcpConnectionPtr connPtr,
size_t *bytesSent,