diff --git a/go.mod b/go.mod index 04b9346ab0..6d1da3695e 100644 --- a/go.mod +++ b/go.mod @@ -135,6 +135,7 @@ require ( github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/soheilhy/cmux v0.1.5 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/supranational/blst v0.3.13 // indirect github.com/tidwall/btree v1.6.0 // indirect diff --git a/go.sum b/go.sum index 0df54ea1d0..272af6adfe 100644 --- a/go.sum +++ b/go.sum @@ -1222,6 +1222,8 @@ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVs github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= +github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= @@ -1492,6 +1494,7 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210220033124-5f55cee0dc0d/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= diff --git a/protocol/lavasession/consumer_types.go b/protocol/lavasession/consumer_types.go index 644e8b330d..d1e88a3494 100644 --- a/protocol/lavasession/consumer_types.go +++ b/protocol/lavasession/consumer_types.go @@ -41,7 +41,7 @@ func (list EndpointInfoList) Swap(i, j int) { const ( AllowInsecureConnectionToProvidersFlag = "allow-insecure-provider-dialing" - AllowGRPCCompressionFlag = "enable-application-level-compression" + AllowGRPCCompressionFlag = "enable-grpc-compression" MaximumStreamsOverASingleConnectionFlag = "maximum-streams-per-connection" DefaultMaximumStreamsOverASingleConnection = 100 WeightMultiplierForStaticProviders = 10 diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 7a786bc8e5..348846f48d 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -485,7 +485,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77 } lavasession.AllowGRPCCompressionForConsumerProviderCommunication = viper.GetBool(lavasession.AllowGRPCCompressionFlag) if lavasession.AllowGRPCCompressionForConsumerProviderCommunication { - utils.LavaFormatInfo("AllowGRPCCompressionForConsumerProviderCommunication is set to true, messages will be compressed", utils.Attribute{Key: lavasession.AllowGRPCCompressionFlag, Value: lavasession.AllowGRPCCompressionForConsumerProviderCommunication}) + utils.LavaFormatInfo("gRPC compression enabled, relay messages will use gzip compression", utils.Attribute{Key: lavasession.AllowGRPCCompressionFlag, Value: lavasession.AllowGRPCCompressionForConsumerProviderCommunication}) } var rpcEndpoints []*lavasession.RPCEndpoint @@ -689,7 +689,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77 cmdRPCConsumer.Flags().Uint(common.MaximumConcurrentProvidersFlagName, 3, "max number of concurrent providers to communicate with") cmdRPCConsumer.MarkFlagRequired(common.GeolocationFlag) cmdRPCConsumer.Flags().Bool(lavasession.AllowInsecureConnectionToProvidersFlag, false, "allow insecure provider-dialing. used for development and testing") - cmdRPCConsumer.Flags().Bool(lavasession.AllowGRPCCompressionFlag, false, "allow messages to be compressed when communicating between the consumer and provider") + cmdRPCConsumer.Flags().Bool(lavasession.AllowGRPCCompressionFlag, false, "enable gzip compression for gRPC messages between consumer and provider (reduces bandwidth, adds CPU overhead)") cmdRPCConsumer.Flags().Uint64Var(&lavasession.MaximumStreamsOverASingleConnection, lavasession.MaximumStreamsOverASingleConnectionFlag, lavasession.DefaultMaximumStreamsOverASingleConnection, "maximum number of parallel streams over a single provider connection") cmdRPCConsumer.Flags().Bool(common.TestModeFlagName, false, "test mode causes rpcconsumer to send dummy data and print all of the metadata in it's listeners") cmdRPCConsumer.Flags().String(performance.PprofAddressFlagName, "", "pprof server address, used for code profiling") diff --git a/protocol/rpcconsumer/rpcconsumer_server.go b/protocol/rpcconsumer/rpcconsumer_server.go index 324c7b5c02..01f944603e 100644 --- a/protocol/rpcconsumer/rpcconsumer_server.go +++ b/protocol/rpcconsumer/rpcconsumer_server.go @@ -1413,11 +1413,9 @@ func (rpccs *RPCConsumerServer) relayInner(ctx context.Context, singleConsumerSe common.LAVA_LB_UNIQUE_ID_HEADER: singleConsumerSession.EndpointConnection.GetLbUniqueId(), }) - // Add custom header to indicate compression support if flag is enabled - compressionEnabled := lavasession.AllowGRPCCompressionForConsumerProviderCommunication - if compressionEnabled { - metadataAdd.Set(common.LavaCompressionSupportHeader, "true") - } + // Note: gRPC compression is handled automatically by the gRPC layer + // when AllowGRPCCompressionForConsumerProviderCommunication is enabled + // via grpc.UseCompressor(gzip.Name) in ConnectGRPCClient utils.LavaFormatTrace("Sending relay to provider", utils.LogAttr("GUID", ctx), @@ -1443,23 +1441,7 @@ func (rpccs *RPCConsumerServer) relayInner(ctx context.Context, singleConsumerSe reply, err = endpointClient.Relay(connectCtx, relayRequest, grpc.Header(&responseHeader), grpc.Trailer(&relayResult.ProviderTrailer)) relayLatency = time.Since(relaySentTime) - // Check if response is compressed and decompress if needed - appLevelCompressed := false - if lavaCompressionValues := responseHeader.Get(common.LavaCompressionHeader); len(lavaCompressionValues) > 0 { - appLevelCompressed = lavaCompressionValues[0] == common.LavaCompressionGzip - } - - if reply != nil && reply.Data != nil && appLevelCompressed { - decompressedData, decompressErr := common.DecompressData(reply.Data) - if decompressErr != nil { - utils.LavaFormatError("Failed to decompress response", decompressErr, - utils.LogAttr("GUID", ctx), - utils.LogAttr("providerName", providerPublicAddress), - ) - return nil, 0, decompressErr, false - } - reply.Data = decompressedData - } + // Note: gRPC decompression is handled automatically by the gRPC layer providerUniqueId := relayResult.ProviderTrailer.Get(chainlib.RpcProviderUniqueIdHeader) if len(providerUniqueId) > 0 { diff --git a/protocol/rpcprovider/grpc_compression_test.go b/protocol/rpcprovider/grpc_compression_test.go new file mode 100644 index 0000000000..c3c6964334 --- /dev/null +++ b/protocol/rpcprovider/grpc_compression_test.go @@ -0,0 +1,545 @@ +package rpcprovider + +import ( + "context" + "fmt" + "net" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/lavanet/lava/v5/protocol/lavasession" + pairingtypes "github.com/lavanet/lava/v5/x/pairing/types" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/encoding/gzip" + "google.golang.org/grpc/stats" +) + +// createTestEndpoint creates an RPCProviderEndpoint for testing +func createTestEndpoint(addr string) *lavasession.RPCProviderEndpoint { + return &lavasession.RPCProviderEndpoint{ + ChainID: "ETH1", + ApiInterface: "jsonrpc", + NetworkAddress: lavasession.NetworkAddressData{ + Address: addr, + }, + } +} + +// compressionStatsHandler tracks gRPC compression statistics +type compressionStatsHandler struct { + outPayloadCompressedBytes atomic.Int64 + outPayloadUncompressedBytes atomic.Int64 + inPayloadCompressedBytes atomic.Int64 + inPayloadUncompressedBytes atomic.Int64 +} + +func (h *compressionStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { + return ctx +} + +func (h *compressionStatsHandler) HandleRPC(ctx context.Context, s stats.RPCStats) { + switch st := s.(type) { + case *stats.OutPayload: + h.outPayloadCompressedBytes.Add(int64(st.CompressedLength)) + h.outPayloadUncompressedBytes.Add(int64(st.Length)) + case *stats.InPayload: + h.inPayloadCompressedBytes.Add(int64(st.CompressedLength)) + h.inPayloadUncompressedBytes.Add(int64(st.Length)) + } +} + +func (h *compressionStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { + return ctx +} + +func (h *compressionStatsHandler) HandleConn(ctx context.Context, s stats.ConnStats) {} + +func (h *compressionStatsHandler) GetCompressionRatio() float64 { + uncompressed := h.inPayloadUncompressedBytes.Load() + compressed := h.inPayloadCompressedBytes.Load() + if uncompressed == 0 { + return 1.0 + } + return float64(compressed) / float64(uncompressed) +} + +func (h *compressionStatsHandler) Reset() { + h.outPayloadCompressedBytes.Store(0) + h.outPayloadUncompressedBytes.Store(0) + h.inPayloadCompressedBytes.Store(0) + h.inPayloadUncompressedBytes.Store(0) +} + +// TestGRPCCompressionEnabled verifies that gRPC compression actually compresses data +func TestGRPCCompressionEnabled(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create provider listener + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + // Large, highly compressible response data (JSON-like repeated pattern) + largeResponseData := []byte(strings.Repeat(`{"blockNumber":"0x123456","result":"success","data":"`, 1000) + + strings.Repeat("a]", 5000) + `"}`) + + // Register mock receiver that returns large compressible data + receiver := &mockRelayReceiver{ + relayFunc: func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + return &pairingtypes.RelayReply{Data: largeResponseData}, nil + }, + } + err := pl.RegisterReceiver(receiver, createTestEndpoint(addr)) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Create stats handler to track compression + statsHandler := &compressionStatsHandler{} + + // Connect WITH compression enabled + conn, err := grpc.DialContext(ctx, addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name)), + grpc.WithStatsHandler(statsHandler), + ) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + // Make relay request + reply, err := client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: []byte("test request"), + ApiInterface: "jsonrpc", + }, + RelaySession: &pairingtypes.RelaySession{ + SpecId: "ETH1", + }, + }) + require.NoError(t, err) + require.NotNil(t, reply) + require.Equal(t, largeResponseData, reply.Data) + + // Verify compression occurred - compressed should be smaller than uncompressed + inCompressed := statsHandler.inPayloadCompressedBytes.Load() + inUncompressed := statsHandler.inPayloadUncompressedBytes.Load() + + t.Logf("Response size - Compressed: %d bytes, Uncompressed: %d bytes, Ratio: %.2f%%", + inCompressed, inUncompressed, float64(inCompressed)/float64(inUncompressed)*100) + + // For highly compressible data, we expect significant compression (at least 50% reduction) + require.Greater(t, inUncompressed, inCompressed, + "Compressed payload should be smaller than uncompressed") + require.Less(t, float64(inCompressed)/float64(inUncompressed), 0.5, + "Expected at least 50%% compression for repetitive JSON data") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +// TestGRPCCompressionDisabled verifies that without compression flag, data is not compressed +func TestGRPCCompressionDisabled(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + // Same large response data + largeResponseData := []byte(strings.Repeat(`{"blockNumber":"0x123456","result":"success"}`, 1000)) + + receiver := &mockRelayReceiver{ + relayFunc: func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + return &pairingtypes.RelayReply{Data: largeResponseData}, nil + }, + } + err := pl.RegisterReceiver(receiver, createTestEndpoint(addr)) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + statsHandler := &compressionStatsHandler{} + + // Connect WITHOUT compression (no grpc.UseCompressor) + conn, err := grpc.DialContext(ctx, addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + grpc.WithStatsHandler(statsHandler), + ) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + reply, err := client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: []byte("test request"), + ApiInterface: "jsonrpc", + }, + RelaySession: &pairingtypes.RelaySession{ + SpecId: "ETH1", + }, + }) + require.NoError(t, err) + require.NotNil(t, reply) + + // Without compression, compressed == uncompressed (or very close due to protobuf overhead) + inCompressed := statsHandler.inPayloadCompressedBytes.Load() + inUncompressed := statsHandler.inPayloadUncompressedBytes.Load() + + t.Logf("Response size (no compression) - Wire: %d bytes, Logical: %d bytes", + inCompressed, inUncompressed) + + // Without compression enabled, sizes should be approximately equal + require.InDelta(t, inCompressed, inUncompressed, float64(inUncompressed)*0.05, + "Without compression, wire size should equal logical size (within 5%%)") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +// TestGRPCCompressionBidirectional verifies compression works in both directions +func TestGRPCCompressionBidirectional(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + // Large response + largeResponseData := []byte(strings.Repeat("response_data_", 5000)) + + var receivedRequestSize int + receiver := &mockRelayReceiver{ + relayFunc: func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + receivedRequestSize = len(request.RelayData.Data) + return &pairingtypes.RelayReply{Data: largeResponseData}, nil + }, + } + err := pl.RegisterReceiver(receiver, createTestEndpoint(addr)) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + statsHandler := &compressionStatsHandler{} + + conn, err := grpc.DialContext(ctx, addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name)), + grpc.WithStatsHandler(statsHandler), + ) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + // Large, compressible request + largeRequestData := []byte(strings.Repeat("request_data_", 5000)) + + reply, err := client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: largeRequestData, + ApiInterface: "jsonrpc", + }, + RelaySession: &pairingtypes.RelaySession{ + SpecId: "ETH1", + }, + }) + require.NoError(t, err) + require.NotNil(t, reply) + + // Verify provider received the full uncompressed request + require.Equal(t, len(largeRequestData), receivedRequestSize, + "Provider should receive uncompressed request data") + + // Verify response was received correctly + require.Equal(t, largeResponseData, reply.Data) + + // Check outgoing (request) compression + outCompressed := statsHandler.outPayloadCompressedBytes.Load() + outUncompressed := statsHandler.outPayloadUncompressedBytes.Load() + + t.Logf("Request - Compressed: %d bytes, Uncompressed: %d bytes, Ratio: %.2f%%", + outCompressed, outUncompressed, float64(outCompressed)/float64(outUncompressed)*100) + + require.Greater(t, outUncompressed, outCompressed, + "Outgoing request should be compressed") + + // Check incoming (response) compression + inCompressed := statsHandler.inPayloadCompressedBytes.Load() + inUncompressed := statsHandler.inPayloadUncompressedBytes.Load() + + t.Logf("Response - Compressed: %d bytes, Uncompressed: %d bytes, Ratio: %.2f%%", + inCompressed, inUncompressed, float64(inCompressed)/float64(inUncompressed)*100) + + require.Greater(t, inUncompressed, inCompressed, + "Incoming response should be compressed") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +// TestGRPCCompressionWithConnectGRPCClient tests compression via the actual ConnectGRPCClient function +func TestGRPCCompressionWithConnectGRPCClient(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + largeResponseData := []byte(strings.Repeat(`{"result":"ok","data":"test"}`, 2000)) + + receiver := &mockRelayReceiver{ + relayFunc: func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + return &pairingtypes.RelayReply{Data: largeResponseData}, nil + }, + } + err := pl.RegisterReceiver(receiver, createTestEndpoint(addr)) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Test with compression enabled using the actual lavasession.ConnectGRPCClient + connCtx, connCancel := context.WithTimeout(ctx, 5*time.Second) + defer connCancel() + + // allowInsecure=true, skipTLS=true, allowCompression=true + conn, err := lavasession.ConnectGRPCClient(connCtx, addr, true, true, true) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + reply, err := client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: []byte("test"), + ApiInterface: "jsonrpc", + }, + RelaySession: &pairingtypes.RelaySession{ + SpecId: "ETH1", + }, + }) + require.NoError(t, err) + require.NotNil(t, reply) + require.Equal(t, largeResponseData, reply.Data, + "Data should be correctly transmitted with compression enabled via ConnectGRPCClient") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +// TestGRPCCompressionSmallPayload verifies behavior with small payloads +func TestGRPCCompressionSmallPayload(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + // Small response - compression may not help much + smallResponseData := []byte(`{"ok":true}`) + + receiver := &mockRelayReceiver{ + relayFunc: func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + return &pairingtypes.RelayReply{Data: smallResponseData}, nil + }, + } + err := pl.RegisterReceiver(receiver, createTestEndpoint(addr)) + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + statsHandler := &compressionStatsHandler{} + + conn, err := grpc.DialContext(ctx, addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name)), + grpc.WithStatsHandler(statsHandler), + ) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + reply, err := client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: []byte("x"), + ApiInterface: "jsonrpc", + }, + RelaySession: &pairingtypes.RelaySession{ + SpecId: "ETH1", + }, + }) + require.NoError(t, err) + require.NotNil(t, reply) + require.Equal(t, smallResponseData, reply.Data) + + // For small payloads, compression still works but may have overhead + // Just verify data integrity - compression ratio may not be favorable + t.Logf("Small payload - Compressed: %d bytes, Uncompressed: %d bytes", + statsHandler.inPayloadCompressedBytes.Load(), + statsHandler.inPayloadUncompressedBytes.Load()) + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +// BenchmarkGRPCWithCompression benchmarks relay with compression +func BenchmarkGRPCWithCompression(b *testing.B) { + addr := getAvailablePortForBenchmark(b) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + + largeResponseData := []byte(strings.Repeat(`{"result":"benchmark_data"}`, 1000)) + + receiver := &mockRelayReceiver{ + relayFunc: func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + return &pairingtypes.RelayReply{Data: largeResponseData}, nil + }, + } + _ = pl.RegisterReceiver(receiver, createTestEndpoint(addr)) + + time.Sleep(100 * time.Millisecond) + + conn, _ := grpc.DialContext(ctx, addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name)), + ) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: []byte("bench"), + ApiInterface: "jsonrpc", + }, + RelaySession: &pairingtypes.RelaySession{ + SpecId: "ETH1", + }, + }) + } + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +// BenchmarkGRPCWithoutCompression benchmarks relay without compression +func BenchmarkGRPCWithoutCompression(b *testing.B) { + addr := getAvailablePortForBenchmark(b) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + + largeResponseData := []byte(strings.Repeat(`{"result":"benchmark_data"}`, 1000)) + + receiver := &mockRelayReceiver{ + relayFunc: func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + return &pairingtypes.RelayReply{Data: largeResponseData}, nil + }, + } + _ = pl.RegisterReceiver(receiver, createTestEndpoint(addr)) + + time.Sleep(100 * time.Millisecond) + + conn, _ := grpc.DialContext(ctx, addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + // No compression + ) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: []byte("bench"), + ApiInterface: "jsonrpc", + }, + RelaySession: &pairingtypes.RelaySession{ + SpecId: "ETH1", + }, + }) + } + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func getAvailablePortForBenchmark(b *testing.B) string { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + b.Fatal(err) + } + tcpAddr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + b.Fatal("expected TCP address") + } + port := tcpAddr.Port + listener.Close() + return fmt.Sprintf("127.0.0.1:%d", port) +} diff --git a/protocol/rpcprovider/provider_listener.go b/protocol/rpcprovider/provider_listener.go index 0fd653856b..db83453fff 100644 --- a/protocol/rpcprovider/provider_listener.go +++ b/protocol/rpcprovider/provider_listener.go @@ -2,35 +2,56 @@ package rpcprovider import ( "context" - "errors" - "fmt" + "crypto/tls" + "net" "net/http" "strings" "sync" "github.com/gogo/status" - "github.com/improbable-eng/grpc-web/go/grpcweb" "github.com/lavanet/lava/v5/protocol/chainlib" - "github.com/lavanet/lava/v5/protocol/common" "github.com/lavanet/lava/v5/protocol/lavaprotocol/protocolerrors" "github.com/lavanet/lava/v5/protocol/lavasession" "github.com/lavanet/lava/v5/utils" pairingtypes "github.com/lavanet/lava/v5/x/pairing/types" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" + "github.com/soheilhy/cmux" grpc "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding/gzip" // Register gzip compressor + "google.golang.org/grpc/health" + healthgrpc "google.golang.org/grpc/health/grpc_health_v1" ) +// Ensure gzip compressor is registered +var _ = gzip.Name + const ( HealthCheckURLPathFlagName = "health-check-url-path" HealthCheckURLPathFlagDefault = "/lava/health" ) +// isShutdownError returns true if the error is expected during graceful shutdown +func isShutdownError(err error) bool { + if err == nil { + return false + } + if err == http.ErrServerClosed || err == cmux.ErrListenerClosed || err == net.ErrClosed { + return true + } + // Check error message for common shutdown patterns + errStr := err.Error() + return strings.Contains(errStr, "use of closed network connection") || + strings.Contains(errStr, "server closed") || + strings.Contains(errStr, "mux: listener closed") +} + type ProviderListener struct { networkAddress string relayServer *relayServer - httpServer http.Server + grpcServer *grpc.Server + httpServer *http.Server + healthServer *health.Server + cmux cmux.CMux } func (pl *ProviderListener) Key() string { @@ -47,69 +68,107 @@ func (pl *ProviderListener) RegisterReceiver(existingReceiver RelayReceiver, end return utils.LavaFormatError("double_receiver_setup receiver already defined on this address with the same chainID and apiInterface", nil, utils.Attribute{Key: "chainID", Value: endpoint.ChainID}, utils.Attribute{Key: "apiInterface", Value: endpoint.ApiInterface}) } pl.relayServer.relayReceivers[listen_endpoint.Key()] = &relayReceiverWrapper{relayReceiver: &existingReceiver, enabled: true} + // Mark service as healthy when receiver is registered + serviceName := endpoint.ChainID + "-" + endpoint.ApiInterface + pl.healthServer.SetServingStatus(serviceName, healthgrpc.HealthCheckResponse_SERVING) utils.LavaFormatInfo("[++] Provider Listening on Address", utils.Attribute{Key: "chainID", Value: endpoint.ChainID}, utils.Attribute{Key: "apiInterface", Value: endpoint.ApiInterface}, utils.Attribute{Key: "Address", Value: endpoint.NetworkAddress}) return nil } func (pl *ProviderListener) Shutdown(shutdownCtx context.Context) error { - if err := pl.httpServer.Shutdown(shutdownCtx); err != nil { - utils.LavaFormatFatal("Provider failed to shutdown", err) - } + pl.healthServer.Shutdown() + pl.httpServer.Shutdown(shutdownCtx) + pl.grpcServer.GracefulStop() return nil } func NewProviderListener(ctx context.Context, networkAddress lavasession.NetworkAddressData, healthCheckPath string) *ProviderListener { pl := &ProviderListener{networkAddress: networkAddress.Address} - // GRPC lis := chainlib.GetListenerWithRetryGrpc("tcp", networkAddress.Address) - opts := []grpc.ServerOption{ - grpc.MaxRecvMsgSize(1024 * 1024 * 512), // setting receive size to 512mb for large debug responses - grpc.MaxSendMsgSize(1024 * 1024 * 512), // setting send size to 512mb for large debug responses - } - grpcServer := grpc.NewServer(opts...) - wrappedServer := grpcweb.WrapServer(grpcServer) - handler := func(resp http.ResponseWriter, req *http.Request) { - // Set CORS headers - resp.Header().Set("Access-Control-Allow-Origin", "*") - resp.Header().Set("Access-Control-Allow-Headers", fmt.Sprintf("Content-Type, x-grpc-web, %s", common.LAVA_CONSUMER_PROCESS_GUID)) - - if req.URL.Path == healthCheckPath && req.Method == http.MethodGet { - resp.WriteHeader(http.StatusOK) - resp.Write([]byte("Healthy")) - return - } - wrappedServer.ServeHTTP(resp, req) + // Wrap with TLS if enabled + if !networkAddress.DisableTLS { + tlsConfig := lavasession.GetTlsConfig(networkAddress) + lis = tls.NewListener(lis, tlsConfig) + } else { + utils.LavaFormatInfo("Running with disabled TLS configuration") } - pl.httpServer = http.Server{ - Handler: h2c.NewHandler(http.HandlerFunc(handler), &http2.Server{}), - } + // Create connection multiplexer to handle both HTTP and gRPC on same port + mux := cmux.New(lis) + // Match HTTP/1.1 first for health checks (fast prefix match) + httpListener := mux.Match(cmux.HTTP1Fast()) + // Everything else goes to gRPC (avoids expensive header parsing) + grpcListener := mux.Match(cmux.Any()) + pl.cmux = mux - var serveExecutor func() error - if networkAddress.DisableTLS { - utils.LavaFormatInfo("Running with disabled TLS configuration") - serveExecutor = func() error { - return pl.httpServer.Serve(lis) - } - } else { - pl.httpServer.TLSConfig = lavasession.GetTlsConfig(networkAddress) - serveExecutor = func() error { - return pl.httpServer.ServeTLS(lis, "", "") - } + // Build gRPC server + opts := []grpc.ServerOption{ + grpc.MaxRecvMsgSize(1024 * 1024 * 512), // 512MB for large debug responses + grpc.MaxSendMsgSize(1024 * 1024 * 512), // 512MB for large debug responses } + grpcServer := grpc.NewServer(opts...) + pl.grpcServer = grpcServer + + // Register gRPC health checking service + healthServer := health.NewServer() + healthgrpc.RegisterHealthServer(grpcServer, healthServer) + healthServer.SetServingStatus("", healthgrpc.HealthCheckResponse_SERVING) + pl.healthServer = healthServer + // Register relay server relayServer := &relayServer{relayReceivers: map[string]*relayReceiverWrapper{}} pl.relayServer = relayServer pairingtypes.RegisterRelayerServer(grpcServer, relayServer) + + // Create HTTP server for health checks + httpMux := http.NewServeMux() + // Only register health check handler if path is provided + if healthCheckPath != "" { + httpMux.HandleFunc(healthCheckPath, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Healthy")) + } else { + w.WriteHeader(http.StatusMethodNotAllowed) + } + }) + } + httpServer := &http.Server{Handler: httpMux} + pl.httpServer = httpServer + + // Start servers go func() { utils.LavaFormatInfo("New provider listener active", utils.Attribute{Key: "address", Value: networkAddress}) - if err := serveExecutor(); !errors.Is(err, http.ErrServerClosed) { - utils.LavaFormatFatal("provider failed to serve", err, utils.Attribute{Key: "Address", Value: lis.Addr().String()}) + if err := grpcServer.Serve(grpcListener); err != nil { + // Ignore expected shutdown errors + if isShutdownError(err) { + return + } + utils.LavaFormatFatal("gRPC server failed", err, utils.Attribute{Key: "Address", Value: networkAddress.Address}) + } + }() + + go func() { + if err := httpServer.Serve(httpListener); err != nil { + // Ignore expected shutdown errors + if isShutdownError(err) { + return + } + utils.LavaFormatFatal("HTTP health server failed", err, utils.Attribute{Key: "Address", Value: networkAddress.Address}) + } + }() + + go func() { + if err := mux.Serve(); err != nil { + if !isShutdownError(err) { + utils.LavaFormatError("cmux serve error", err) + } } utils.LavaFormatInfo("listener closed server", utils.Attribute{Key: "address", Value: networkAddress}) }() + return pl } diff --git a/protocol/rpcprovider/provider_listener_test.go b/protocol/rpcprovider/provider_listener_test.go new file mode 100644 index 0000000000..330e9a3db3 --- /dev/null +++ b/protocol/rpcprovider/provider_listener_test.go @@ -0,0 +1,750 @@ +package rpcprovider + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/lavanet/lava/v5/protocol/lavasession" + pairingtypes "github.com/lavanet/lava/v5/x/pairing/types" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + healthgrpc "google.golang.org/grpc/health/grpc_health_v1" +) + +// getAvailablePort finds an available port for testing +func getAvailablePort(t *testing.T) string { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tcpAddr, ok := listener.Addr().(*net.TCPAddr) + require.True(t, ok, "expected TCP address") + port := tcpAddr.Port + listener.Close() + return fmt.Sprintf("127.0.0.1:%d", port) +} + +// mockRelayReceiver implements RelayReceiver for testing +type mockRelayReceiver struct { + relayFunc func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) + probeFunc func(ctx context.Context, probeReq *pairingtypes.ProbeRequest) (*pairingtypes.ProbeReply, error) + subscribeFunc func(request *pairingtypes.RelayRequest, srv pairingtypes.Relayer_RelaySubscribeServer) error +} + +func (m *mockRelayReceiver) Relay(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + if m.relayFunc != nil { + return m.relayFunc(ctx, request) + } + return &pairingtypes.RelayReply{Data: []byte("mock relay response")}, nil +} + +func (m *mockRelayReceiver) Probe(ctx context.Context, probeReq *pairingtypes.ProbeRequest) (*pairingtypes.ProbeReply, error) { + if m.probeFunc != nil { + return m.probeFunc(ctx, probeReq) + } + return &pairingtypes.ProbeReply{Guid: probeReq.Guid, LatestBlock: 12345}, nil +} + +func (m *mockRelayReceiver) RelaySubscribe(request *pairingtypes.RelayRequest, srv pairingtypes.Relayer_RelaySubscribeServer) error { + if m.subscribeFunc != nil { + return m.subscribeFunc(request, srv) + } + return nil +} + +func TestNewProviderListener(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + require.Equal(t, addr, pl.Key()) + require.NotNil(t, pl.grpcServer) + require.NotNil(t, pl.httpServer) + require.NotNil(t, pl.healthServer) + require.NotNil(t, pl.relayServer) + + // Give time for listeners to start + time.Sleep(100 * time.Millisecond) + + // Shutdown + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + err := pl.Shutdown(shutdownCtx) + require.NoError(t, err) +} + +func TestProviderListener_HTTPHealthCheck(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + healthPath := "/lava/health" + pl := NewProviderListener(ctx, networkAddr, healthPath) + require.NotNil(t, pl) + + // Give time for listeners to start + time.Sleep(100 * time.Millisecond) + + // Test HTTP health check + resp, err := http.Get(fmt.Sprintf("http://%s%s", addr, healthPath)) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "Healthy", string(body)) + + // Shutdown + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + err = pl.Shutdown(shutdownCtx) + require.NoError(t, err) +} + +func TestProviderListener_HTTPHealthCheckMethodNotAllowed(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + healthPath := "/lava/health" + pl := NewProviderListener(ctx, networkAddr, healthPath) + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + // Test POST to health check (should return 405) + resp, err := http.Post(fmt.Sprintf("http://%s%s", addr, healthPath), "text/plain", nil) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_GRPCHealthService(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + // Connect via gRPC + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + // Check gRPC health service + healthClient := healthgrpc.NewHealthClient(conn) + resp, err := healthClient.Check(ctx, &healthgrpc.HealthCheckRequest{Service: ""}) + require.NoError(t, err) + require.Equal(t, healthgrpc.HealthCheckResponse_SERVING, resp.Status) + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_RegisterReceiver(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + // Create mock receiver + mockReceiver := &mockRelayReceiver{} + + endpoint := &lavasession.RPCProviderEndpoint{ + ChainID: "ETH1", + ApiInterface: "jsonrpc", + NetworkAddress: lavasession.NetworkAddressData{ + Address: addr, + }, + } + + // Register receiver + err := pl.RegisterReceiver(mockReceiver, endpoint) + require.NoError(t, err) + + // Verify health status is set for the service + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + healthClient := healthgrpc.NewHealthClient(conn) + resp, err := healthClient.Check(ctx, &healthgrpc.HealthCheckRequest{Service: "ETH1-jsonrpc"}) + require.NoError(t, err) + require.Equal(t, healthgrpc.HealthCheckResponse_SERVING, resp.Status) + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_RegisterReceiverDuplicate(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + mockReceiver := &mockRelayReceiver{} + endpoint := &lavasession.RPCProviderEndpoint{ + ChainID: "ETH1", + ApiInterface: "jsonrpc", + NetworkAddress: lavasession.NetworkAddressData{ + Address: addr, + }, + } + + // First registration should succeed + err := pl.RegisterReceiver(mockReceiver, endpoint) + require.NoError(t, err) + + // Second registration with same endpoint should fail + err = pl.RegisterReceiver(mockReceiver, endpoint) + require.Error(t, err) + require.Contains(t, err.Error(), "double_receiver_setup") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_GRPCProbe(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + // Register a mock receiver + expectedLatestBlock := int64(99999) + mockReceiver := &mockRelayReceiver{ + probeFunc: func(ctx context.Context, probeReq *pairingtypes.ProbeRequest) (*pairingtypes.ProbeReply, error) { + return &pairingtypes.ProbeReply{ + Guid: probeReq.Guid, + LatestBlock: expectedLatestBlock, + }, nil + }, + } + + endpoint := &lavasession.RPCProviderEndpoint{ + ChainID: "ETH1", + ApiInterface: "jsonrpc", + NetworkAddress: lavasession.NetworkAddressData{ + Address: addr, + }, + } + + err := pl.RegisterReceiver(mockReceiver, endpoint) + require.NoError(t, err) + + // Connect via gRPC and call Probe + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + probeResp, err := client.Probe(ctx, &pairingtypes.ProbeRequest{ + Guid: 12345, + SpecId: "ETH1", + ApiInterface: "jsonrpc", + }) + require.NoError(t, err) + require.Equal(t, uint64(12345), probeResp.Guid) + require.Equal(t, expectedLatestBlock, probeResp.LatestBlock) + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_GRPCRelay(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + // Register a mock receiver + expectedData := []byte("test relay response data") + mockReceiver := &mockRelayReceiver{ + relayFunc: func(ctx context.Context, request *pairingtypes.RelayRequest) (*pairingtypes.RelayReply, error) { + return &pairingtypes.RelayReply{ + Data: expectedData, + LatestBlock: 100, + }, nil + }, + } + + endpoint := &lavasession.RPCProviderEndpoint{ + ChainID: "ETH1", + ApiInterface: "jsonrpc", + NetworkAddress: lavasession.NetworkAddressData{ + Address: addr, + }, + } + + err := pl.RegisterReceiver(mockReceiver, endpoint) + require.NoError(t, err) + + // Connect via gRPC and call Relay + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + relayResp, err := client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + ApiInterface: "jsonrpc", + }, + RelaySession: &pairingtypes.RelaySession{ + SpecId: "ETH1", + }, + }) + require.NoError(t, err) + require.Equal(t, expectedData, relayResp.Data) + require.Equal(t, int64(100), relayResp.LatestBlock) + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_UnhandledReceiver(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + // Don't register any receiver, just try to call + + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + // Try to probe unregistered chain + _, err = client.Probe(ctx, &pairingtypes.ProbeRequest{ + Guid: 12345, + SpecId: "UNKNOWN", + ApiInterface: "jsonrpc", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "unhandled relay receiver") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_RelayInvalidRequest(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + // Call with nil RelayData + _, err = client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: nil, + RelaySession: &pairingtypes.RelaySession{}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "internal fields are nil") + + // Call with nil RelaySession + _, err = client.Relay(ctx, &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{}, + RelaySession: nil, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "internal fields are nil") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_ConcurrentRequests(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + // Register receiver + mockReceiver := &mockRelayReceiver{} + endpoint := &lavasession.RPCProviderEndpoint{ + ChainID: "ETH1", + ApiInterface: "jsonrpc", + NetworkAddress: lavasession.NetworkAddressData{ + Address: addr, + }, + } + err := pl.RegisterReceiver(mockReceiver, endpoint) + require.NoError(t, err) + + // Make concurrent gRPC and HTTP requests + const numRequests = 50 + errChan := make(chan error, numRequests*2) + + // gRPC requests + for i := 0; i < numRequests; i++ { + go func() { + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + errChan <- err + return + } + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + _, err = client.Probe(ctx, &pairingtypes.ProbeRequest{ + Guid: 12345, + SpecId: "ETH1", + ApiInterface: "jsonrpc", + }) + errChan <- err + }() + } + + // HTTP health check requests + for i := 0; i < numRequests; i++ { + go func() { + resp, err := http.Get(fmt.Sprintf("http://%s/lava/health", addr)) + if err != nil { + errChan <- err + return + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + errChan <- fmt.Errorf("unexpected status: %d", resp.StatusCode) + return + } + errChan <- nil + }() + } + + // Collect results + successCount := 0 + for i := 0; i < numRequests*2; i++ { + err := <-errChan + if err == nil { + successCount++ + } + } + + require.Equal(t, numRequests*2, successCount, "All concurrent requests should succeed") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_CustomHealthPath(t *testing.T) { + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + customPath := "/custom/health/check" + pl := NewProviderListener(ctx, networkAddr, customPath) + require.NotNil(t, pl) + + time.Sleep(100 * time.Millisecond) + + // Custom path should work + resp, err := http.Get(fmt.Sprintf("http://%s%s", addr, customPath)) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Default path should not work (404) + resp2, err := http.Get(fmt.Sprintf("http://%s/lava/health", addr)) + require.NoError(t, err) + defer resp2.Body.Close() + require.Equal(t, http.StatusNotFound, resp2.StatusCode) + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestProviderListener_TLSConfiguration(t *testing.T) { + // This test verifies TLS configuration is applied correctly + // We test with DisableTLS=false but without actual certs, so we expect it to fail to connect + // This validates the TLS path is taken + + addr := getAvailablePort(t) + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: false, // TLS enabled, but no certs + KeyPem: "", + CertPem: "", + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Note: This may fail during NewProviderListener if GetTlsConfig returns nil + // or succeed and then fail on connection. Either way, we're testing TLS path. + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + if pl == nil { + // Expected if TLS config fails + return + } + + time.Sleep(100 * time.Millisecond) + + // HTTP without TLS should fail + _, err := http.Get(fmt.Sprintf("http://%s/lava/health", addr)) + // Either connection refused or TLS handshake error is expected + if err == nil { + t.Log("HTTP connection succeeded - TLS might not be enforced") + } + + // HTTPS should work if we skip verification (for testing) + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + resp, err := client.Get(fmt.Sprintf("https://%s/lava/health", addr)) + if err == nil { + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + } + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func TestRelayServer_FindReceiver(t *testing.T) { + rs := &relayServer{relayReceivers: map[string]*relayReceiverWrapper{}} + + var mockReceiver RelayReceiver = &mockRelayReceiver{} + + // Add receiver + endpoint := lavasession.RPCEndpoint{ChainID: "ETH1", ApiInterface: "jsonrpc"} + rs.relayReceivers[endpoint.Key()] = &relayReceiverWrapper{ + relayReceiver: &mockReceiver, + enabled: true, + } + + // Find existing receiver + found, err := rs.findReceiver("jsonrpc", "ETH1") + require.NoError(t, err) + require.NotNil(t, found) + + // Find non-existent receiver + _, err = rs.findReceiver("jsonrpc", "UNKNOWN") + require.Error(t, err) + require.Contains(t, err.Error(), "unhandled relay receiver") + + // Find disabled receiver + rs.relayReceivers[endpoint.Key()].enabled = false + _, err = rs.findReceiver("jsonrpc", "ETH1") + require.Error(t, err) + require.Contains(t, err.Error(), "disabled") +} + +func BenchmarkProviderListener_HTTPHealthCheck(b *testing.B) { + addr := fmt.Sprintf("127.0.0.1:%d", 30000+b.N%1000) + listener, err := net.Listen("tcp", addr) + if err != nil { + b.Skipf("Could not bind to port: %v", err) + } + listener.Close() + + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + if pl == nil { + b.Fatal("Failed to create provider listener") + } + + time.Sleep(100 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp, err := http.Get(fmt.Sprintf("http://%s/lava/health", addr)) + if err != nil { + b.Fatal(err) + } + resp.Body.Close() + } + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} + +func BenchmarkProviderListener_GRPCProbe(b *testing.B) { + addr := fmt.Sprintf("127.0.0.1:%d", 31000+b.N%1000) + listener, err := net.Listen("tcp", addr) + if err != nil { + b.Skipf("Could not bind to port: %v", err) + } + listener.Close() + + networkAddr := lavasession.NetworkAddressData{ + Address: addr, + DisableTLS: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pl := NewProviderListener(ctx, networkAddr, "/lava/health") + if pl == nil { + b.Fatal("Failed to create provider listener") + } + + time.Sleep(100 * time.Millisecond) + + mockReceiver := &mockRelayReceiver{} + endpoint := &lavasession.RPCProviderEndpoint{ + ChainID: "ETH1", + ApiInterface: "jsonrpc", + NetworkAddress: lavasession.NetworkAddressData{ + Address: addr, + }, + } + pl.RegisterReceiver(mockReceiver, endpoint) + + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + client := pairingtypes.NewRelayerClient(conn) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := client.Probe(ctx, &pairingtypes.ProbeRequest{ + Guid: uint64(i), + SpecId: "ETH1", + ApiInterface: "jsonrpc", + }) + if err != nil { + b.Fatal(err) + } + } + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + pl.Shutdown(shutdownCtx) +} diff --git a/protocol/rpcprovider/rpcprovider_server.go b/protocol/rpcprovider/rpcprovider_server.go index 054f23b45f..f0d3a536cd 100644 --- a/protocol/rpcprovider/rpcprovider_server.go +++ b/protocol/rpcprovider/rpcprovider_server.go @@ -290,11 +290,6 @@ func (rpcps *RPCProviderServer) Relay(ctx context.Context, request *pairingtypes utils.Attribute{Key: "requestBlock", Value: request.RelayData.GetRequestBlock()}, ) - // Check if consumer supports compression via custom header - md, _ := metadata.FromIncomingContext(ctx) - compressionSupport := md.Get(common.LavaCompressionSupportHeader) - consumerSupportsCompression := len(compressionSupport) > 0 && compressionSupport[0] == "true" - // Init relay var relaySession *lavasession.SingleProviderSession var consumerAddress sdk.AccAddress @@ -430,28 +425,8 @@ func (rpcps *RPCProviderServer) Relay(ctx context.Context, request *pairingtypes rpcps.metrics.SetEndToEndLatency(endToEndLatencyMs) } - // Application-level compression if consumer supports it - if reply != nil && reply.Data != nil { - // Only compress if consumer explicitly supports it via custom header - // Don't rely on grpc-accept-encoding which is always present - if consumerSupportsCompression { - originalSize := len(reply.Data) - compressedData, wasCompressed, compressErr := common.CompressData(reply.Data, common.CompressionThreshold) - - if compressErr != nil { - utils.LavaFormatWarning("Failed to compress response, sending uncompressed", - compressErr, - utils.LogAttr("GUID", ctx), - utils.LogAttr("originalSize", originalSize), - ) - } else if wasCompressed { - reply.Data = compressedData - - // Set header to indicate manual compression - grpc.SetHeader(ctx, metadata.Pairs(common.LavaCompressionHeader, common.LavaCompressionGzip)) - } - } - } + // Note: gRPC compression is handled automatically by the gRPC layer + // when the client uses grpc.UseCompressor(gzip.Name) utils.LavaFormatInfo("Done handling relay request from consumer", utils.Attribute{Key: "GUID", Value: ctx}, @@ -463,7 +438,6 @@ func (rpcps *RPCProviderServer) Relay(ctx context.Context, request *pairingtypes utils.Attribute{Key: "request.cu", Value: request.RelaySession.CuSum}, utils.Attribute{Key: "relay_timeout", Value: common.GetRemainingTimeoutFromContext(ctx)}, utils.Attribute{Key: "timeTaken", Value: processingTime}, - utils.Attribute{Key: "consumerSupportsCompression", Value: consumerSupportsCompression}, ) return reply, rpcps.handleRelayErrorStatus(err) } diff --git a/protocol/rpcsmartrouter/rpcsmartrouter.go b/protocol/rpcsmartrouter/rpcsmartrouter.go index b7d22a7fac..fd8f792cdb 100644 --- a/protocol/rpcsmartrouter/rpcsmartrouter.go +++ b/protocol/rpcsmartrouter/rpcsmartrouter.go @@ -689,7 +689,7 @@ rpcsmartrouter smartrouter_examples/full_smartrouter_example.yml --cache-be "127 } lavasession.AllowGRPCCompressionForConsumerProviderCommunication = viper.GetBool(lavasession.AllowGRPCCompressionFlag) if lavasession.AllowGRPCCompressionForConsumerProviderCommunication { - utils.LavaFormatInfo("AllowGRPCCompressionForConsumerProviderCommunication is set to true, messages will be compressed", utils.Attribute{Key: lavasession.AllowGRPCCompressionFlag, Value: lavasession.AllowGRPCCompressionForConsumerProviderCommunication}) + utils.LavaFormatInfo("gRPC compression enabled, relay messages will use gzip compression", utils.Attribute{Key: lavasession.AllowGRPCCompressionFlag, Value: lavasession.AllowGRPCCompressionForConsumerProviderCommunication}) } var rpcEndpoints []*lavasession.RPCEndpoint @@ -992,7 +992,7 @@ rpcsmartrouter smartrouter_examples/full_smartrouter_example.yml --cache-be "127 cmdRPCSmartRouter.Flags().Uint(common.MaximumConcurrentProvidersFlagName, 3, "max number of concurrent providers to communicate with") cmdRPCSmartRouter.MarkFlagRequired(common.GeolocationFlag) cmdRPCSmartRouter.Flags().Bool(lavasession.AllowInsecureConnectionToProvidersFlag, false, "allow insecure provider-dialing. used for development and testing") - cmdRPCSmartRouter.Flags().Bool(lavasession.AllowGRPCCompressionFlag, false, "allow messages to be compressed when communicating between the consumer and provider") + cmdRPCSmartRouter.Flags().Bool(lavasession.AllowGRPCCompressionFlag, false, "enable gzip compression for gRPC messages between consumer and provider (reduces bandwidth, adds CPU overhead)") cmdRPCSmartRouter.Flags().Uint64Var(&lavasession.MaximumStreamsOverASingleConnection, lavasession.MaximumStreamsOverASingleConnectionFlag, lavasession.DefaultMaximumStreamsOverASingleConnection, "maximum number of parallel streams over a single provider connection") cmdRPCSmartRouter.Flags().Bool(common.TestModeFlagName, false, "test mode causes rpcconsumer to send dummy data and print all of the metadata in it's listeners") cmdRPCSmartRouter.Flags().String(performance.PprofAddressFlagName, "", "pprof server address, used for code profiling") diff --git a/protocol/rpcsmartrouter/rpcsmartrouter_compression_test.go b/protocol/rpcsmartrouter/rpcsmartrouter_compression_test.go deleted file mode 100644 index 7debcef6ea..0000000000 --- a/protocol/rpcsmartrouter/rpcsmartrouter_compression_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package rpcsmartrouter - -import ( - "context" - "strings" - "testing" - "time" - - "github.com/lavanet/lava/v5/protocol/chainlib/extensionslib" - "github.com/lavanet/lava/v5/protocol/common" - "github.com/lavanet/lava/v5/protocol/lavasession" - "github.com/lavanet/lava/v5/protocol/metrics" - "github.com/lavanet/lava/v5/utils/rand" - "github.com/lavanet/lava/v5/utils/sigs" - pairingtypes "github.com/lavanet/lava/v5/x/pairing/types" - spectypes "github.com/lavanet/lava/v5/x/spec/types" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" -) - -// TestSmartRouterCompressionHeaderLogic tests that smart router adds compression header when flag is enabled -func TestSmartRouterCompressionHeaderLogic(t *testing.T) { - tests := []struct { - name string - flagEnabled bool - expectHeader bool - }{ - { - name: "Flag enabled - header should be added", - flagEnabled: true, - expectHeader: true, - }, - { - name: "Flag disabled - no header", - flagEnabled: false, - expectHeader: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Simulate smart router creating metadata (same logic as in relayInner) - metadataAdd := metadata.New(map[string]string{ - common.IP_FORWARDING_HEADER_NAME: "test-token", - common.LAVA_CONSUMER_PROCESS_GUID: "12345", - common.LAVA_LB_UNIQUE_ID_HEADER: "67890", - }) - - // Simulate the compression header logic from relayInner - if tt.flagEnabled { - metadataAdd.Set(common.LavaCompressionSupportHeader, "true") - } - - // Verify header presence - headerValues := metadataAdd.Get(common.LavaCompressionSupportHeader) - if tt.expectHeader { - require.Len(t, headerValues, 1, "Should have compression support header") - require.Equal(t, "true", headerValues[0], "Header value should be 'true'") - } else { - require.Empty(t, headerValues, "Should not have compression support header") - } - }) - } -} - -// TestSmartRouterHandlesDecompressionError tests smart router handles invalid compressed data -func TestSmartRouterHandlesDecompressionError(t *testing.T) { - rand.InitRandomSeed() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - _, providerAccount := sigs.GenerateFloatingKey() - providerPublicAddress := providerAccount.String() - consumeSK, consumerAccount := sigs.GenerateFloatingKey() - - // Enable compression - lavasession.AllowGRPCCompressionForConsumerProviderCommunication = true - - // Invalid compressed data - invalidCompressedData := []byte("this is not valid gzip data") - - relayerMock := NewMockRelayerClient(ctrl) - relayerMock.EXPECT().Relay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx context.Context, in *pairingtypes.RelayRequest, opts ...grpc.CallOption) (*pairingtypes.RelayReply, error) { - // Set compression header (claiming it's compressed) - for _, opt := range opts { - if headerOpt, ok := opt.(grpc.HeaderCallOption); ok { - *headerOpt.HeaderAddr = metadata.Pairs(common.LavaCompressionHeader, common.LavaCompressionGzip) - } - } - - // Return invalid data - reply := &pairingtypes.RelayReply{ - Data: invalidCompressedData, - FinalizedBlocksHashes: []byte(`{"0":"hash0"}`), - } - return reply, nil - }, - ).Times(1) - - rpcSmartRouterServer, chainParser := createRpcSmartRouter( - t, ctrl, context.Background(), consumeSK, consumerAccount, - providerPublicAddress, relayerMock, "LAV1", - spectypes.APIInterfaceTendermintRPC, 100, 1, "lava", - ) - - chainMsg, err := chainParser.ParseMsg("", []byte(`{"jsonrpc":"2.0","method":"status","params":[],"id":1}`), "", nil, extensionslib.ExtensionInfo{}) - require.NoError(t, err) - - singleConsumerSession := &lavasession.SingleConsumerSession{ - EndpointConnection: &lavasession.EndpointConnection{Client: relayerMock}, - } - singleConsumerSession.Parent = &lavasession.ConsumerSessionsWithProvider{ - PublicLavaAddress: providerPublicAddress, - PairingEpoch: 100, - } - - relayResult := &common.RelayResult{ - ProviderInfo: common.ProviderInfo{ProviderAddress: providerPublicAddress}, - Request: &pairingtypes.RelayRequest{ - RelayData: &pairingtypes.RelayPrivateData{RequestBlock: 0}, - RelaySession: &pairingtypes.RelaySession{}, - }, - } - - // Call relayInner - should fail with decompression error - _, err, needsBackoff := rpcSmartRouterServer.relayInner( - context.Background(), - singleConsumerSession, - relayResult, - 30*time.Second, - chainMsg, - "test-token", - &metrics.RelayMetrics{}, - ) - - // Should return an error due to invalid gzip data - require.Error(t, err, "Should fail with decompression error") - require.Contains(t, err.Error(), "failed to create gzip reader", "Error should mention gzip") - require.False(t, needsBackoff, "Should not need backoff for decompression error") -} - -// TestSmartRouterCompressionRoundTrip tests the complete compression/decompression cycle -func TestSmartRouterCompressionRoundTrip(t *testing.T) { - // This test verifies the data integrity through compression and decompression - originalData := []byte(strings.Repeat("test data for compression ", 50000)) - - // Compress (as provider would) - compressedData, wasCompressed, err := common.CompressData(originalData, common.CompressionThreshold) - require.NoError(t, err) - require.True(t, wasCompressed, "Data should be compressed") - require.Less(t, len(compressedData), len(originalData), "Compressed should be smaller") - - // Decompress (as smart router would) - decompressedData, err := common.DecompressData(compressedData) - require.NoError(t, err) - - // Verify data integrity - require.Equal(t, originalData, decompressedData, "Decompressed data should match original") -} - -// TestSmartRouterNoCompressionWhenBelowThreshold tests that small responses aren't compressed -func TestSmartRouterNoCompressionWhenBelowThreshold(t *testing.T) { - smallData := []byte(strings.Repeat("a", 1000)) - - // Try to compress - result, wasCompressed, err := common.CompressData(smallData, common.CompressionThreshold) - require.NoError(t, err) - require.False(t, wasCompressed, "Small data should not be compressed") - require.Equal(t, smallData, result, "Should return original data when not compressed") -} - -// TestSmartRouterUncompressedResponseHandling tests that uncompressed responses work normally -func TestSmartRouterUncompressedResponseHandling(t *testing.T) { - // Simulate smart router receiving response without compression header - responseData := []byte(`{"jsonrpc":"2.0","result":{"status":"ok"},"id":1}`) - responseHeader := metadata.MD{} // Empty metadata - no compression header - - // Check for compression header - lavaCompressionValues := responseHeader.Get(common.LavaCompressionHeader) - shouldDecompress := len(lavaCompressionValues) > 0 && lavaCompressionValues[0] == common.LavaCompressionGzip - - // Verify no decompression is attempted - require.False(t, shouldDecompress, "Should not attempt decompression when header is missing") - - // Simulate the actual code logic - var finalData []byte - if shouldDecompress { - // This branch should NOT be taken - t.Fatal("Should not attempt to decompress when header is missing") - } else { - // This branch should be taken - use data as-is - finalData = responseData - } - - // Verify data is unchanged - require.Equal(t, responseData, finalData, "Data should be used as-is when not compressed") -} diff --git a/protocol/rpcsmartrouter/rpcsmartrouter_server.go b/protocol/rpcsmartrouter/rpcsmartrouter_server.go index f7d34704c2..1b93980962 100644 --- a/protocol/rpcsmartrouter/rpcsmartrouter_server.go +++ b/protocol/rpcsmartrouter/rpcsmartrouter_server.go @@ -1366,10 +1366,9 @@ func (rpcss *RPCSmartRouterServer) relayInner(ctx context.Context, singleConsume common.LAVA_LB_UNIQUE_ID_HEADER: singleConsumerSession.EndpointConnection.GetLbUniqueId(), }) - // Add compression support header if enabled - if lavasession.AllowGRPCCompressionForConsumerProviderCommunication { - metadataAdd.Set(common.LavaCompressionSupportHeader, "true") - } + // Note: gRPC compression is handled automatically by the gRPC layer + // when AllowGRPCCompressionForConsumerProviderCommunication is enabled + // via grpc.UseCompressor(gzip.Name) in ConnectGRPCClient utils.LavaFormatTrace("Sending relay to provider", utils.LogAttr("GUID", ctx), @@ -1393,17 +1392,7 @@ func (rpcss *RPCSmartRouterServer) relayInner(ctx context.Context, singleConsume reply, err = endpointClient.Relay(connectCtx, relayRequest, grpc.Header(&responseHeader), grpc.Trailer(&relayResult.ProviderTrailer)) relayLatency = time.Since(relaySentTime) - // Decompress response if compressed - if reply != nil && reply.Data != nil { - if lavaCompressionValues := responseHeader.Get(common.LavaCompressionHeader); len(lavaCompressionValues) > 0 && lavaCompressionValues[0] == common.LavaCompressionGzip { - decompressedData, decompressErr := common.DecompressData(reply.Data) - if decompressErr != nil { - utils.LavaFormatError("Failed to decompress response", decompressErr, utils.LogAttr("GUID", ctx)) - return nil, 0, decompressErr, false - } - reply.Data = decompressedData - } - } + // Note: gRPC decompression is handled automatically by the gRPC layer providerUniqueId := relayResult.ProviderTrailer.Get(chainlib.RpcProviderUniqueIdHeader) if len(providerUniqueId) > 0 {