diff --git a/lib/httpx/plugins/compression/brotli.rb b/lib/httpx/plugins/compression/brotli.rb index f69f07f1..dc37e378 100644 --- a/lib/httpx/plugins/compression/brotli.rb +++ b/lib/httpx/plugins/compression/brotli.rb @@ -18,12 +18,13 @@ module HTTPX module Deflater module_function - def deflate(raw, buffer, chunk_size:) + def deflate(raw, buffer = "".b, chunk_size: 16_384) while (chunk = raw.read(chunk_size)) compressed = ::Brotli.deflate(chunk) buffer << compressed yield compressed if block_given? end + buffer end end diff --git a/lib/httpx/plugins/compression/deflate.rb b/lib/httpx/plugins/compression/deflate.rb index 6f871717..ed45a5a1 100644 --- a/lib/httpx/plugins/compression/deflate.rb +++ b/lib/httpx/plugins/compression/deflate.rb @@ -17,7 +17,7 @@ module HTTPX module Deflater module_function - def deflate(raw, buffer, chunk_size:) + def deflate(raw, buffer = "".b, chunk_size: 16_384) deflater = Zlib::Deflate.new while (chunk = raw.read(chunk_size)) compressed = deflater.deflate(chunk) @@ -27,6 +27,7 @@ module HTTPX last = deflater.finish buffer << last yield last if block_given? + buffer ensure deflater.close if deflater end diff --git a/lib/httpx/plugins/compression/gzip.rb b/lib/httpx/plugins/compression/gzip.rb index e6fc5d31..56e06a37 100644 --- a/lib/httpx/plugins/compression/gzip.rb +++ b/lib/httpx/plugins/compression/gzip.rb @@ -19,7 +19,7 @@ module HTTPX @compressed_chunk = "".b end - def deflate(raw, buffer, chunk_size:) + def deflate(raw, buffer = "".b, chunk_size: 16_384) gzip = Zlib::GzipWriter.new(self) begin @@ -38,6 +38,7 @@ module HTTPX buffer << compressed yield compressed if block_given? + buffer end private diff --git a/lib/httpx/plugins/grpc.rb b/lib/httpx/plugins/grpc.rb index db285dd0..d1dffd5f 100644 --- a/lib/httpx/plugins/grpc.rb +++ b/lib/httpx/plugins/grpc.rb @@ -35,9 +35,7 @@ module HTTPX DEADLINE = 60 HEADERS = { - # "accept-encoding" => "identity", - "grpc-accept-encoding" => "identity", - "content-type" => "application/grpc+proto", + "content-type" => "application/grpc", "grpc-timeout" => "#{DEADLINE}S", "te" => "trailers", "accept" => "application/grpc", @@ -54,7 +52,7 @@ module HTTPX # decodes a unary grpc response def unary(response) verify_status(response) - decode(response.to_s) + decode(response.to_s, encodings: response.headers.get("grpc-encoding"), encoders: response.encoders) end # lazy decodes a grpc stream response @@ -62,19 +60,34 @@ module HTTPX return enum_for(__method__, response) unless block_given? response.each do |frame| - yield decode(frame) + yield decode(frame, encodings: response.headers.get("grpc-encoding"), encoders: response.encoders) end end # encodes a single grpc message - def encode(bytes) - "".b << [0, bytes.bytesize].pack("CL>") << bytes + def encode(bytes, deflater:) + if deflater + compressed_flag = 1 + bytes = deflater.deflate(StringIO.new(bytes)) + else + compressed_flag = 0 + end + + "".b << [compressed_flag, bytes.bytesize].pack("CL>") << bytes.to_s end # decodes a single grpc message - def decode(message) - _compressed, size = message.unpack("CL>") - message.byteslice(5..size + 5 - 1) + def decode(message, encodings:, encoders:) + compressed, size = message.unpack("CL>") + data = message.byteslice(5..size + 5 - 1) + if compressed == 1 + encodings.reverse_each do |algo| + inflater = encoders.registry(algo).inflater(size) + data = inflater.inflate(data) + size = data.bytesize + end + end + data end # interprets the grpc call trailing metadata, and raises an @@ -125,6 +138,10 @@ module HTTPX @trailing_metadata = Hash[trailers] super end + + def encoders + @options.encodings + end end module InstanceMethods @@ -184,7 +201,10 @@ module HTTPX rpc_method = "/#{@options.grpc_service}#{rpc_method}" if @options.grpc_service uri.path = rpc_method - headers = HEADERS + headers = HEADERS.merge( + "grpc-accept-encoding" => ["identity", *@options.encodings.registry.keys], + ) + headers = headers.merge(metadata) if metadata body = if input.respond_to?(:each) diff --git a/sig/plugins/compression.rbs b/sig/plugins/compression.rbs index 294655da..148d22c1 100644 --- a/sig/plugins/compression.rbs +++ b/sig/plugins/compression.rbs @@ -6,8 +6,8 @@ module HTTPX type deflatable = _Reader | _ToS interface _Deflater - def deflate: (deflatable, _Writer, chunk_size: Integer) -> void - | (deflatable, _Writer, chunk_size: Integer) { (String) -> void } -> void + def deflate: (deflatable, ?_Writer, ?chunk_size: Integer) -> _ToS + | (deflatable, ?_Writer, ?chunk_size: Integer) { (String) -> void } -> _ToS end interface _Inflater diff --git a/test/support/requests/plugins/grpc.rb b/test/support/requests/plugins/grpc.rb index 31bef301..5f674c0b 100644 --- a/test/support/requests/plugins/grpc.rb +++ b/test/support/requests/plugins/grpc.rb @@ -14,21 +14,32 @@ module Requests assert call.metadata["k2"] == "v2" end - # stub = ::GRPC::ClientStub.new("localhost:#{server_port}", - # :this_channel_is_insecure) grpc = HTTPX.plugin(:grpc) # build service stub = grpc.build_stub("http://localhost:#{server_port}") result = stub.execute("an_rpc_method", "a_request", metadata: { k1: "v1", k2: "v2" }) - # stub = ::GRPC::ClientStub.new("localhost:#{server_port}", :this_channel_is_insecure) - # op = stub.request_response("an_rpc_method", "a_request", no_marshal, no_marshal, - # return_op: true, metadata: { k1: "v1", k2: "v2" }) - # op.start_call if run_start_call_first - # result = op.execute assert result == "a_reply" end + def test_plugin_grpc_compressed_response + no_marshal = proc { |x| x } + + server_port = run_request_response("A" * 2000, OK, marshal: no_marshal, + server_initial_md: { "grpc-internal-encoding-request" => "gzip" }) do |call| + assert call.remote_read == "a_request" + # assert call.metadata["k1"] == "v1" + # assert call.metadata["k2"] == "v2" + end + + grpc = HTTPX.plugin(:grpc) + # build service + stub = grpc.build_stub("http://localhost:#{server_port}") + result = stub.execute("an_rpc_method", "a_request") + + assert result == "A" * 2000 + end + def test_plugin_grpc_unary_protobuf server_port = run_rpc(TestService)