diff --git a/cmd/api/config.go b/cmd/api/config.go index 9818fa01..4a2f6885 100644 --- a/cmd/api/config.go +++ b/cmd/api/config.go @@ -51,6 +51,7 @@ type HTTPConfig struct { RateLimInterval string `default:"1s"` MaxRequestPerInterval uint64 `default:"10"` + APIKey string `default:""` // if client passes the key it will not be affected by rate limiter } // GatewayConfig contains configuration for the Gateway. diff --git a/cmd/api/main.go b/cmd/api/main.go index bbfac351..611ac712 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -465,6 +465,7 @@ func createAPIServer( httpConfig.MaxRequestPerInterval, rateLimInterval, supportedChainIDs, + httpConfig.APIKey, ) if err != nil { return nil, fmt.Errorf("configuring router: %s", err) diff --git a/docker/deployed/mainnet/api/config.json b/docker/deployed/mainnet/api/config.json index 789a8292..0af2ba27 100644 --- a/docker/deployed/mainnet/api/config.json +++ b/docker/deployed/mainnet/api/config.json @@ -4,6 +4,7 @@ "Port": "8080", "RateLimInterval": "1s", "MaxRequestPerInterval": 10, + "ApiKey" : "${HTTP_RATE_LIMITER_API_KEY}", "TLSCert": "${VALIDATOR_TLS_CERT}", "TLSKey": "${VALIDATOR_TLS_KEY}" }, diff --git a/docker/deployed/staging/api/config.json b/docker/deployed/staging/api/config.json index 420059ad..d516576d 100644 --- a/docker/deployed/staging/api/config.json +++ b/docker/deployed/staging/api/config.json @@ -3,6 +3,7 @@ "Port": "8080", "RateLimInterval": "1s", "MaxRequestPerInterval": 10, + "ApiKey" : "${HTTP_RATE_LIMITER_API_KEY}", "TLSCert": "${VALIDATOR_TLS_CERT}", "TLSKey": "${VALIDATOR_TLS_KEY}" }, diff --git a/docker/deployed/testnet/api/config.json b/docker/deployed/testnet/api/config.json index e37e7676..a12ffbdd 100644 --- a/docker/deployed/testnet/api/config.json +++ b/docker/deployed/testnet/api/config.json @@ -4,6 +4,7 @@ "Port": "8080", "RateLimInterval": "1s", "MaxRequestPerInterval": 10, + "ApiKey" : "${HTTP_RATE_LIMITER_API_KEY}", "TLSCert": "${VALIDATOR_TLS_CERT}", "TLSKey": "${VALIDATOR_TLS_KEY}" }, @@ -51,25 +52,6 @@ "ChainStackCollectFrequency": "15m" }, "Chains": [ - { - "Name": "Ethereum Goerli", - "ChainID": 5, - "Registry": { - "EthEndpoint": "wss://eth-goerli.alchemyapi.io/v2/${VALIDATOR_ALCHEMY_ETHEREUM_GOERLI_API_KEY}", - "ContractAddress": "0xDA8EA22d092307874f30A1F277D1388dca0BA97a" - }, - "EventFeed": { - "ChainAPIBackoff": "15s", - "NewBlockPollFreq": "10s", - "MinBlockDepth": 1, - "PersistEvents": true - }, - "EventProcessor": { - "BlockFailedExecutionBackoff": "10s", - "DedupExecutedTxns": true - }, - "HashCalculationStep": 150 - }, { "Name": "Ethereum Sepolia", "ChainID": 11155111, diff --git a/internal/router/middlewares/ratelim.go b/internal/router/middlewares/ratelim.go index dd11f083..6369074b 100644 --- a/internal/router/middlewares/ratelim.go +++ b/internal/router/middlewares/ratelim.go @@ -4,10 +4,12 @@ import ( "fmt" "net" "net/http" + "strconv" "strings" "time" "github.com/gorilla/mux" + "github.com/sethvargo/go-limiter" "github.com/sethvargo/go-limiter/httplimit" "github.com/sethvargo/go-limiter/memorystore" ) @@ -22,6 +24,7 @@ type RateLimiterConfig struct { type RateLimiterRouteConfig struct { MaxRPI uint64 Interval time.Duration + APIKey string } // RateLimitController creates a new middleware to rate limit requests. @@ -47,7 +50,7 @@ func RateLimitController(cfg RateLimiterConfig) (mux.MiddlewareFunc, error) { }, nil } -func createRateLimiter(cfg RateLimiterRouteConfig, kf httplimit.KeyFunc) (*httplimit.Middleware, error) { +func createRateLimiter(cfg RateLimiterRouteConfig, kf httplimit.KeyFunc) (*middleware, error) { defaultStore, err := memorystore.New(&memorystore.Config{ Tokens: cfg.MaxRPI, Interval: cfg.Interval, @@ -55,11 +58,12 @@ func createRateLimiter(cfg RateLimiterRouteConfig, kf httplimit.KeyFunc) (*httpl if err != nil { return nil, fmt.Errorf("creating default memory: %s", err) } - m, err := httplimit.NewMiddleware(defaultStore, kf) - if err != nil { - return nil, fmt.Errorf("creating default httplimiter: %s", err) - } - return m, nil + + return &middleware{ + store: defaultStore, + keyFunc: kf, + apiKey: cfg.APIKey, + }, nil } func extractClientIP(r *http.Request) (string, error) { @@ -77,3 +81,62 @@ func extractClientIP(r *http.Request) (string, error) { } return ip, nil } + +type middleware struct { + store limiter.Store + keyFunc httplimit.KeyFunc + + // clients with key are not affected by rate limiter + apiKey string +} + +// Handle returns the HTTP handler as a middleware. This handler calls Take() on +// the store and sets the common rate limiting headers. If the take is +// successful, the remaining middleware is called. If take is unsuccessful, the +// middleware chain is halted and the function renders a 429 to the caller with +// metadata about when it's safe to retry. +func (m *middleware) Handle(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Call the key function - if this fails, it's an internal server error. + key, err := m.keyFunc(r) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + // skip rate limiting checks if api key is provided + if key := r.Header.Get("Api-Key"); key != "" && m.apiKey != "" { + if strings.EqualFold(key, m.apiKey) { + next.ServeHTTP(w, r) + return + } + } + + // Take from the store. + limit, remaining, reset, ok, err := m.store.Take(ctx, key) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + resetTime := time.Unix(0, int64(reset)).UTC().Format(time.RFC1123) + + // Set headers (we do this regardless of whether the request is permitted). + w.Header().Set("X-RateLimit-Limit", strconv.FormatUint(limit, 10)) + w.Header().Set("X-RateLimit-Remaining", strconv.FormatUint(remaining, 10)) + w.Header().Set("X-RateLimit-Reset", resetTime) + + // Fail if there were no tokens remaining. + if !ok { + w.Header().Set("Retry-After", resetTime) + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + return + } + + // If we got this far, we're allowed to continue, so call the next middleware + // in the stack to continue processing. + next.ServeHTTP(w, r) + }) +} diff --git a/internal/router/middlewares/ratelim_test.go b/internal/router/middlewares/ratelim_test.go index c3da0c52..5cc03f16 100644 --- a/internal/router/middlewares/ratelim_test.go +++ b/internal/router/middlewares/ratelim_test.go @@ -20,6 +20,7 @@ func TestLimit1IP(t *testing.T) { callRPS int limitRPS int forwardedFor bool + allow bool } tests := []testCase{ @@ -28,6 +29,9 @@ func TestLimit1IP(t *testing.T) { {name: "success", callRPS: 100, limitRPS: 500, forwardedFor: false}, {name: "block-me", callRPS: 1000, limitRPS: 500, forwardedFor: false}, + + {name: "allow-me", callRPS: 1000, limitRPS: 500, forwardedFor: false, allow: true}, + {name: "forwarded-allow-me", callRPS: 1000, limitRPS: 500, forwardedFor: true, allow: true}, } for _, tc := range tests { @@ -41,20 +45,27 @@ func TestLimit1IP(t *testing.T) { Interval: time.Second, }, } - rlcm, err := RateLimitController(cfg) - require.NoError(t, err) - rlc := rlcm(dummyHandler{}) ctx := context.Background() r, err := http.NewRequestWithContext(ctx, "", "", nil) require.NoError(t, err) + ip := uuid.NewString() if tc.forwardedFor { - r.Header.Set("X-Forwarded-For", uuid.NewString()) + r.Header.Set("X-Forwarded-For", ip) } else { - r.RemoteAddr = uuid.NewString() + ":1234" + r.RemoteAddr = ip + ":1234" + } + + if tc.allow { + r.Header.Set("Api-Key", "MYSECRETKEY") + cfg.Default.APIKey = "MYSECRETKEY" } + rlcm, err := RateLimitController(cfg) + require.NoError(t, err) + rlc := rlcm(dummyHandler{}) + res := httptest.NewRecorder() // Verify that after some seconds making requests with the configured @@ -62,7 +73,7 @@ func TestLimit1IP(t *testing.T) { // - If callRPS < limitRPS, we never get a 429. // - If callRPS > limitRPS, we eventually should see a 429. assertFunc := require.Eventually - if tc.callRPS < tc.limitRPS { + if tc.callRPS < tc.limitRPS || tc.allow { assertFunc = require.Never } assertFunc(t, func() bool { diff --git a/internal/router/router.go b/internal/router/router.go index 29f07d6f..17546c21 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -19,6 +19,7 @@ func ConfiguredRouter( maxRPI uint64, rateLimInterval time.Duration, supportedChainIDs []tableland.ChainID, + apiKey string, ) (*Router, error) { // General router configuration. router := newRouter() @@ -28,6 +29,7 @@ func ConfiguredRouter( Default: middlewares.RateLimiterRouteConfig{ MaxRPI: maxRPI, Interval: rateLimInterval, + APIKey: apiKey, }, } rateLim, err := middlewares.RateLimitController(cfg) diff --git a/tests/fullstack/fullstack.go b/tests/fullstack/fullstack.go index 1932a4ce..1a4304ae 100644 --- a/tests/fullstack/fullstack.go +++ b/tests/fullstack/fullstack.go @@ -131,7 +131,7 @@ func CreateFullStack(t *testing.T, deps Deps) FullStack { require.NoError(t, err) } - router, err := router.ConfiguredRouter(gatewayService, 10, time.Second, []tableland.ChainID{ChainID}) + router, err := router.ConfiguredRouter(gatewayService, 10, time.Second, []tableland.ChainID{ChainID}, "") require.NoError(t, err) server := httptest.NewServer(router.Handler())