diff --git a/api/client/BUILD.bazel b/api/client/BUILD.bazel index f5ddc9bb6c81..131d74c4f11a 100644 --- a/api/client/BUILD.bazel +++ b/api/client/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "client.go", "errors.go", "options.go", + "transport.go", ], importpath = "github.com/OffchainLabs/prysm/v6/api/client", visibility = ["//visibility:public"], @@ -14,7 +15,13 @@ go_library( go_test( name = "go_default_test", - srcs = ["client_test.go"], + srcs = [ + "client_test.go", + "transport_test.go", + ], embed = [":go_default_library"], - deps = ["//testing/require:go_default_library"], + deps = [ + "//testing/assert:go_default_library", + "//testing/require:go_default_library", + ], ) diff --git a/api/client/transport.go b/api/client/transport.go new file mode 100644 index 000000000000..af29e1168cda --- /dev/null +++ b/api/client/transport.go @@ -0,0 +1,25 @@ +package client + +import "net/http" + +// CustomHeadersTransport adds custom headers to each request +type CustomHeadersTransport struct { + base http.RoundTripper + headers map[string][]string +} + +func NewCustomHeadersTransport(base http.RoundTripper, headers map[string][]string) *CustomHeadersTransport { + return &CustomHeadersTransport{ + base: base, + headers: headers, + } +} + +func (t *CustomHeadersTransport) RoundTrip(req *http.Request) (*http.Response, error) { + for header, values := range t.headers { + for _, value := range values { + req.Header.Add(header, value) + } + } + return t.base.RoundTrip(req) +} diff --git a/api/client/transport_test.go b/api/client/transport_test.go new file mode 100644 index 000000000000..0a2eca3103f0 --- /dev/null +++ b/api/client/transport_test.go @@ -0,0 +1,25 @@ +package client + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/OffchainLabs/prysm/v6/testing/assert" + "github.com/OffchainLabs/prysm/v6/testing/require" +) + +type noopTransport struct{} + +func (*noopTransport) RoundTrip(*http.Request) (*http.Response, error) { + return nil, nil +} + +func TestRoundTrip(t *testing.T) { + tr := &CustomHeadersTransport{base: &noopTransport{}, headers: map[string][]string{"key1": []string{"value1", "value2"}, "key2": []string{"value3"}}} + req := httptest.NewRequest("GET", "http://foo", nil) + _, err := tr.RoundTrip(req) + require.NoError(t, err) + assert.DeepEqual(t, []string{"value1", "value2"}, req.Header.Values("key1")) + assert.DeepEqual(t, []string{"value3"}, req.Header.Values("key2")) +} diff --git a/changelog/radek_rest-custom-headers.md b/changelog/radek_rest-custom-headers.md new file mode 100644 index 000000000000..e0ca1b18da12 --- /dev/null +++ b/changelog/radek_rest-custom-headers.md @@ -0,0 +1,3 @@ +### Added + +- Allow custom headers in validator client HTTP requests. \ No newline at end of file diff --git a/cmd/validator/flags/flags.go b/cmd/validator/flags/flags.go index 51414da2c01e..003ad0d51784 100644 --- a/cmd/validator/flags/flags.go +++ b/cmd/validator/flags/flags.go @@ -45,6 +45,13 @@ var ( Usage: "Beacon node REST API provider endpoint.", Value: "http://127.0.0.1:3500", } + // BeaconRESTApiHeaders defines a list of headers to send with all HTTP requests to the beacon node. + BeaconRESTApiHeaders = &cli.StringFlag{ + Name: "beacon-rest-api-headers", + Usage: `Comma-separated list of key value pairs to pass as headers for all HTTP calls to the beacon node. + To provide multiple values for the same key, specify the same key for each value. + Example: --grpc-headers=key1=value1,key1=value2,key2=value3`, + } // CertFlag defines a flag for the node's TLS certificate. CertFlag = &cli.StringFlag{ Name: "tls-cert", diff --git a/cmd/validator/main.go b/cmd/validator/main.go index b7617eb7b914..3a212cfe2581 100644 --- a/cmd/validator/main.go +++ b/cmd/validator/main.go @@ -51,6 +51,7 @@ func startNode(ctx *cli.Context) error { var appFlags = []cli.Flag{ flags.BeaconRPCProviderFlag, flags.BeaconRESTApiProviderFlag, + flags.BeaconRESTApiHeaders, flags.CertFlag, flags.GraffitiFlag, flags.DisablePenaltyRewardLogFlag, diff --git a/cmd/validator/usage.go b/cmd/validator/usage.go index 00c6c13a6810..4030c0221004 100644 --- a/cmd/validator/usage.go +++ b/cmd/validator/usage.go @@ -93,6 +93,7 @@ var appHelpFlagGroups = []flagGroup{ Flags: []cli.Flag{ flags.CertFlag, flags.BeaconRPCProviderFlag, + flags.BeaconRESTApiHeaders, flags.EnableRPCFlag, flags.RPCHost, flags.RPCPort, diff --git a/validator/accounts/cli_manager.go b/validator/accounts/cli_manager.go index ed4851d3034b..ad51f5dc99a3 100644 --- a/validator/accounts/cli_manager.go +++ b/validator/accounts/cli_manager.go @@ -84,6 +84,7 @@ func (acm *CLIManager) prepareBeaconClients(ctx context.Context) (*iface.Validat conn := validatorHelpers.NewNodeConnection( grpcConn, acm.beaconApiEndpoint, + nil, acm.beaconApiTimeout, ) diff --git a/validator/client/service.go b/validator/client/service.go index 22487b386bfa..e0e9f9686862 100644 --- a/validator/client/service.go +++ b/validator/client/service.go @@ -6,6 +6,7 @@ import ( "strings" "time" + api "github.com/OffchainLabs/prysm/v6/api/client" eventClient "github.com/OffchainLabs/prysm/v6/api/client/event" grpcutil "github.com/OffchainLabs/prysm/v6/api/grpc" "github.com/OffchainLabs/prysm/v6/async/event" @@ -79,6 +80,7 @@ type Config struct { BeaconNodeGRPCEndpoint string BeaconNodeCert string BeaconApiEndpoint string + BeaconApiHeaders map[string][]string BeaconApiTimeout time.Duration Graffiti string GraffitiStruct *graffiti.Graffiti @@ -142,6 +144,7 @@ func NewValidatorService(ctx context.Context, cfg *Config) (*ValidatorService, e s.conn = validatorHelpers.NewNodeConnection( grpcConn, cfg.BeaconApiEndpoint, + cfg.BeaconApiHeaders, cfg.BeaconApiTimeout, ) @@ -185,8 +188,9 @@ func (v *ValidatorService) Start() { return } + headersTransport := api.NewCustomHeadersTransport(http.DefaultTransport, v.conn.GetBeaconApiHeaders()) restHandler := beaconApi.NewBeaconApiRestHandler( - http.Client{Timeout: v.conn.GetBeaconApiTimeout(), Transport: otelhttp.NewTransport(http.DefaultTransport)}, + http.Client{Timeout: v.conn.GetBeaconApiTimeout(), Transport: otelhttp.NewTransport(headersTransport)}, hosts[0], ) diff --git a/validator/helpers/node_connection.go b/validator/helpers/node_connection.go index 8f2de45947d8..c9bf4c3dab3b 100644 --- a/validator/helpers/node_connection.go +++ b/validator/helpers/node_connection.go @@ -10,6 +10,7 @@ import ( type NodeConnection interface { GetGrpcClientConn() *grpc.ClientConn GetBeaconApiUrl() string + GetBeaconApiHeaders() map[string][]string GetBeaconApiTimeout() time.Duration dummy() } @@ -17,6 +18,7 @@ type NodeConnection interface { type nodeConnection struct { grpcClientConn *grpc.ClientConn beaconApiUrl string + beaconApiHeaders map[string][]string beaconApiTimeout time.Duration } @@ -28,16 +30,21 @@ func (c *nodeConnection) GetBeaconApiUrl() string { return c.beaconApiUrl } +func (c *nodeConnection) GetBeaconApiHeaders() map[string][]string { + return c.beaconApiHeaders +} + func (c *nodeConnection) GetBeaconApiTimeout() time.Duration { return c.beaconApiTimeout } func (*nodeConnection) dummy() {} -func NewNodeConnection(grpcConn *grpc.ClientConn, beaconApiUrl string, beaconApiTimeout time.Duration) NodeConnection { +func NewNodeConnection(grpcConn *grpc.ClientConn, beaconApiUrl string, beaconApiHeaders map[string][]string, beaconApiTimeout time.Duration) NodeConnection { conn := &nodeConnection{} conn.grpcClientConn = grpcConn conn.beaconApiUrl = beaconApiUrl + conn.beaconApiHeaders = beaconApiHeaders conn.beaconApiTimeout = beaconApiTimeout return conn } diff --git a/validator/node/node.go b/validator/node/node.go index 48f811d65e05..652b9946a3c3 100644 --- a/validator/node/node.go +++ b/validator/node/node.go @@ -433,6 +433,7 @@ func (c *ValidatorClient) registerValidatorService(cliCtx *cli.Context) error { BeaconNodeGRPCEndpoint: cliCtx.String(flags.BeaconRPCProviderFlag.Name), BeaconNodeCert: cliCtx.String(flags.CertFlag.Name), BeaconApiEndpoint: cliCtx.String(flags.BeaconRESTApiProviderFlag.Name), + BeaconApiHeaders: parseBeaconApiHeaders(cliCtx.String(flags.BeaconRESTApiHeaders.Name)), BeaconApiTimeout: time.Second * 30, Graffiti: g.ParseHexGraffiti(cliCtx.String(flags.GraffitiFlag.Name)), GraffitiStruct: graffitiStruct, @@ -552,6 +553,7 @@ func (c *ValidatorClient) registerRPCService(cliCtx *cli.Context) error { GRPCHeaders: strings.Split(cliCtx.String(flags.GRPCHeadersFlag.Name), ","), BeaconNodeGRPCEndpoint: cliCtx.String(flags.BeaconRPCProviderFlag.Name), BeaconApiEndpoint: cliCtx.String(flags.BeaconRESTApiProviderFlag.Name), + BeaconAPIHeaders: parseBeaconApiHeaders(cliCtx.String(flags.BeaconRESTApiHeaders.Name)), BeaconApiTimeout: time.Second * 30, BeaconNodeCert: cliCtx.String(flags.CertFlag.Name), DB: c.db, @@ -636,3 +638,19 @@ func clearDB(ctx context.Context, dataDir string, force bool, isDatabaseMinimal return nil } + +func parseBeaconApiHeaders(rawHeaders string) map[string][]string { + result := make(map[string][]string) + pairs := strings.Split(rawHeaders, ",") + for _, pair := range pairs { + key, value, found := strings.Cut(pair, "=") + if !found { + // Skip malformed pairs + continue + } + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + result[key] = append(result[key], value) + } + return result +} diff --git a/validator/rpc/BUILD.bazel b/validator/rpc/BUILD.bazel index 64a690d69d80..f7e4b41fe1d5 100644 --- a/validator/rpc/BUILD.bazel +++ b/validator/rpc/BUILD.bazel @@ -23,6 +23,7 @@ go_library( ], deps = [ "//api:go_default_library", + "//api/client:go_default_library", "//api/grpc:go_default_library", "//api/pagination:go_default_library", "//api/server:go_default_library", diff --git a/validator/rpc/beacon.go b/validator/rpc/beacon.go index 596eef0ae906..d51261cd58e8 100644 --- a/validator/rpc/beacon.go +++ b/validator/rpc/beacon.go @@ -3,6 +3,7 @@ package rpc import ( "net/http" + api "github.com/OffchainLabs/prysm/v6/api/client" grpcutil "github.com/OffchainLabs/prysm/v6/api/grpc" ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1" "github.com/OffchainLabs/prysm/v6/validator/client" @@ -52,11 +53,13 @@ func (s *Server) registerBeaconClient() error { conn := validatorHelpers.NewNodeConnection( grpcConn, s.beaconApiEndpoint, + s.beaconApiHeaders, s.beaconApiTimeout, ) + headersTransport := api.NewCustomHeadersTransport(http.DefaultTransport, conn.GetBeaconApiHeaders()) restHandler := beaconApi.NewBeaconApiRestHandler( - http.Client{Timeout: s.beaconApiTimeout, Transport: otelhttp.NewTransport(http.DefaultTransport)}, + http.Client{Timeout: s.beaconApiTimeout, Transport: otelhttp.NewTransport(headersTransport)}, s.beaconApiEndpoint, ) diff --git a/validator/rpc/server.go b/validator/rpc/server.go index 5e9b4da70376..85322a910fd1 100644 --- a/validator/rpc/server.go +++ b/validator/rpc/server.go @@ -34,6 +34,7 @@ type Config struct { GRPCHeaders []string BeaconNodeGRPCEndpoint string BeaconApiEndpoint string + BeaconAPIHeaders map[string][]string BeaconApiTimeout time.Duration BeaconNodeCert string DB db.Database @@ -64,6 +65,7 @@ type Server struct { authTokenPath string beaconNodeCert string beaconApiEndpoint string + beaconApiHeaders map[string][]string beaconNodeEndpoint string healthClient ethpb.HealthClient nodeClient iface.NodeClient @@ -103,6 +105,7 @@ func NewServer(ctx context.Context, cfg *Config) *Server { wallet: cfg.Wallet, beaconApiTimeout: cfg.BeaconApiTimeout, beaconApiEndpoint: cfg.BeaconApiEndpoint, + beaconApiHeaders: cfg.BeaconAPIHeaders, beaconNodeEndpoint: cfg.BeaconNodeGRPCEndpoint, router: cfg.Router, }