Skip to content

Commit e8f679b

Browse files
authored
feat: Add Sb-Forwarded-For header and IP-based rate limiting (#2295)
## What kind of change does this PR introduce? This PR adds support for IP address forwarding using a new header, `Sb-Forwarded-For`, optionally gated by `GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED`. When this feature is enabled, both `utilities.GetIPAddress` and rate limiting will use the first value of the `Sb-Forwarded-For` header as the IP address/rate limiting key. If the feature is disabled or the `Sb-Forwarded-For` header contains an invalid value, Auth will fall back to existing behavior. ## What is the current behavior? There are currently two paths along which users are likely to use IP address information. The first is IP tracking (e.g., logging, MFA challenge validation, and CAPTCHA challenge validation). The second is rate limiting. Both of these follow slightly different logical paths, relying on the `X-Forwarded-For` header explicitly in the former case and a separate rate limiting key header in the latter. The presence of these two paths results in some friction for users. `X-Forwarded-For` can be (and frequently is) rewritten by proxies or otherwise spoofed, and there is no guarantee that a rate limiting key in the rate limit header is an IP address. ## What is the new behavior? The API uses a new middleware, `sbff.Middleware`, that parses the `Sb-Forwarded-For` header and inserts it into the request context if `GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED` is true. Consumers of the `Sb-Forwarded-For` header can use `sbff.GetIPAddress` to retrieve the parsed IP address. `utilities.GetIPAddress` will prefer the result of `sbff.GetIPAddress` as the end-user IP address if the feature is enabled and the `Sb-Forwarded-For` header contains a value value. Similarly, Auth will use the end user IP address as determined by `sbff.GetIPAddress` as the rate limiting key under the same circumstances. If the feature is not enabled or the `Sb-Forwarded-For` header is absent or otherwise invalid, Auth will default to existing/legacy behavior.
1 parent c553b10 commit e8f679b

9 files changed

Lines changed: 613 additions & 7 deletions

File tree

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,12 @@ Enforce reauthentication on password update.
888888

889889
Use this to enable/disable anonymous sign-ins.
890890

891+
### IP address forwarding
892+
893+
`GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED` - `bool`
894+
895+
Enable IP address forwarding using the `Sb-Forwarded-For` HTTP request header. When enabled, Auth will parse the first value of this header as an IP address and use it for IP address tracking and rate limiting. Make sure this header is fully trusted before enabling this feature by only passing it from trustworthy clients or proxies.
896+
891897
## Endpoints
892898

893899
Auth exposes the following endpoints:

internal/api/api.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/supabase/auth/internal/mailer/templatemailer"
2020
"github.com/supabase/auth/internal/models"
2121
"github.com/supabase/auth/internal/observability"
22+
"github.com/supabase/auth/internal/sbff"
2223
"github.com/supabase/auth/internal/storage"
2324
"github.com/supabase/auth/internal/tokens"
2425
"github.com/supabase/auth/internal/utilities"
@@ -152,8 +153,17 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
152153
r := newRouter()
153154
r.UseBypass(observability.AddRequestID(globalConfig))
154155
r.UseBypass(logger)
155-
r.UseBypass(xffmw.Handler)
156156
r.UseBypass(recoverer)
157+
r.UseBypass(
158+
sbff.Middleware(
159+
&globalConfig.Security,
160+
func(r *http.Request, err error) {
161+
log := observability.GetLogEntry(r).Entry
162+
log.WithField("error", err.Error()).Warn("error processing Sb-Forwarded-For")
163+
},
164+
),
165+
)
166+
r.UseBypass(xffmw.Handler)
157167

158168
if globalConfig.API.MaxRequestDuration > 0 {
159169
r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration))

internal/api/middleware.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/supabase/auth/internal/api/shared"
2121
"github.com/supabase/auth/internal/models"
2222
"github.com/supabase/auth/internal/observability"
23+
"github.com/supabase/auth/internal/sbff"
2324
"github.com/supabase/auth/internal/security"
2425
"github.com/supabase/auth/internal/utilities"
2526

@@ -61,7 +62,7 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error {
6162

6263
var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered")
6364

64-
func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
65+
func (a *API) performRateLimitingWithHeader(lmt *limiter.Limiter, req *http.Request) error {
6566
limitHeader := a.config.RateLimitHeader
6667

6768
// If no rate limit header was set, ignore rate limiting
@@ -112,6 +113,18 @@ func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error
112113
return nil
113114
}
114115

116+
func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
117+
if sbffAddr, ok := sbff.GetIPAddress(req); ok {
118+
if err := tollbooth.LimitByKeys(lmt, []string{sbffAddr}); err != nil {
119+
return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
120+
}
121+
122+
return nil
123+
}
124+
125+
return a.performRateLimitingWithHeader(lmt, req)
126+
}
127+
115128
func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
116129
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
117130
return req.Context(), a.performRateLimiting(lmt, req)

internal/api/middleware_test.go

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/stretchr/testify/suite"
2020
"github.com/supabase/auth/internal/api/apierrors"
2121
"github.com/supabase/auth/internal/conf"
22+
"github.com/supabase/auth/internal/sbff"
2223
"github.com/supabase/auth/internal/storage"
2324
)
2425

@@ -415,7 +416,166 @@ func TestTimeoutResponseWriter(t *testing.T) {
415416
require.Equal(t, w1.Result(), w2.Result())
416417
}
417418

418-
func (ts *MiddlewareTestSuite) TestPerformRateLimiting() {
419+
func (ts *MiddlewareTestSuite) TestPerformRateLimitingWithSBFF() {
420+
origRateLimitHeader := ts.Config.RateLimitHeader
421+
origSBFFEnabled := ts.Config.Security.SbForwardedForEnabled
422+
423+
defer func() {
424+
ts.Config.RateLimitHeader = origRateLimitHeader
425+
ts.Config.Security.SbForwardedForEnabled = origSBFFEnabled
426+
}()
427+
428+
ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting"
429+
ts.Config.Security.SbForwardedForEnabled = true
430+
431+
type headerSet struct {
432+
rateLimiting string
433+
sbForwardedFor string
434+
}
435+
436+
testCases := []struct {
437+
name string
438+
headerValues []headerSet
439+
expErr error
440+
}{
441+
{
442+
name: "multiple SBFF values, single rate limiting value",
443+
headerValues: []headerSet{
444+
{
445+
sbForwardedFor: "192.168.1.100",
446+
rateLimiting: "60.60.60.60",
447+
},
448+
{
449+
sbForwardedFor: "192.168.1.200",
450+
rateLimiting: "60.60.60.60",
451+
},
452+
},
453+
expErr: nil,
454+
},
455+
{
456+
name: "single SBFF value, multiple rate limiting values",
457+
headerValues: []headerSet{
458+
{
459+
sbForwardedFor: "192.168.1.100",
460+
rateLimiting: "60.60.60.60",
461+
},
462+
{
463+
sbForwardedFor: "192.168.1.100",
464+
rateLimiting: "70.70.70.70",
465+
},
466+
},
467+
expErr: apierrors.NewTooManyRequestsError(
468+
apierrors.ErrorCodeOverRequestRateLimit,
469+
"Request rate limit reached",
470+
),
471+
},
472+
{
473+
name: "no SBFF value, multiple rate limiting values",
474+
headerValues: []headerSet{
475+
{
476+
sbForwardedFor: "",
477+
rateLimiting: "60.60.60.60",
478+
},
479+
{
480+
sbForwardedFor: "",
481+
rateLimiting: "70.70.70.70",
482+
},
483+
},
484+
expErr: nil,
485+
},
486+
{
487+
name: "no SBFF value, single rate limiting value",
488+
headerValues: []headerSet{
489+
{
490+
sbForwardedFor: "",
491+
rateLimiting: "60.60.60.60",
492+
},
493+
{
494+
sbForwardedFor: "",
495+
rateLimiting: "60.60.60.60",
496+
},
497+
},
498+
expErr: apierrors.NewTooManyRequestsError(
499+
apierrors.ErrorCodeOverRequestRateLimit,
500+
"Request rate limit reached",
501+
),
502+
},
503+
{
504+
name: "invalid SBFF value, multiple rate limiting values",
505+
headerValues: []headerSet{
506+
{
507+
sbForwardedFor: "invalid",
508+
rateLimiting: "60.60.60.60",
509+
},
510+
{
511+
sbForwardedFor: "invalid",
512+
rateLimiting: "70.70.70.70",
513+
},
514+
},
515+
expErr: nil,
516+
},
517+
{
518+
name: "invalid SBFF value, single rate limiting value",
519+
headerValues: []headerSet{
520+
{
521+
sbForwardedFor: "invalid",
522+
rateLimiting: "60.60.60.60",
523+
},
524+
{
525+
sbForwardedFor: "invalid",
526+
rateLimiting: "60.60.60.60",
527+
},
528+
},
529+
expErr: apierrors.NewTooManyRequestsError(
530+
apierrors.ErrorCodeOverRequestRateLimit,
531+
"Request rate limit reached",
532+
),
533+
},
534+
}
535+
536+
// This test uses the SBFF middleware to inject the Sb-Forwarded-For IP address value, then
537+
// wraps a handler that calls performRateLimiting and stores the error value.
538+
for _, tc := range testCases {
539+
lmt := tollbooth.NewLimiter(
540+
1,
541+
&limiter.ExpirableOptions{
542+
DefaultExpirationTTL: time.Hour,
543+
},
544+
)
545+
546+
var obsErr error
547+
548+
var handler http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) {
549+
obsErr = ts.API.performRateLimiting(lmt, r)
550+
}
551+
552+
errCallback := func(r *http.Request, err error) {
553+
}
554+
555+
middleware := sbff.Middleware(&ts.Config.Security, errCallback)
556+
557+
wrappedHandler := middleware(handler)
558+
559+
for _, h := range tc.headerValues {
560+
r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil)
561+
562+
if h.rateLimiting != "" {
563+
r.Header.Set(ts.Config.RateLimitHeader, h.rateLimiting)
564+
}
565+
566+
if h.sbForwardedFor != "" {
567+
r.Header.Set(sbff.HeaderName, h.sbForwardedFor)
568+
}
569+
570+
wrappedHandler.ServeHTTP(nil, r)
571+
}
572+
573+
require.ErrorIs(ts.T(), obsErr, tc.expErr)
574+
}
575+
576+
}
577+
578+
func (ts *MiddlewareTestSuite) TestPerformRateLimitingWithHeader() {
419579
ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting"
420580

421581
tests := []struct {

internal/conf/configuration.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ type SecurityConfiguration struct {
731731
RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"`
732732
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
733733
ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"`
734+
SbForwardedForEnabled bool `json:"sb_forwarded_for_enabled" split_words:"true" default:"false"`
734735

735736
DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"`
736737
}

internal/sbff/sbff.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package sbff
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"net/http"
8+
"strings"
9+
10+
"github.com/supabase/auth/internal/conf"
11+
)
12+
13+
// HeaderName is the Sb-Forwarded-For header name. It is all lowercase here as HTTP header names
14+
// are not case-sensitive.
15+
const HeaderName = "sb-forwarded-for"
16+
17+
var (
18+
ctxKeySBFF = &struct{}{}
19+
20+
ErrHeaderNotFound = errors.New("Sb-Forwarded-For header not found")
21+
ErrHeaderInvalid = errors.New("invalid Sb-Forwarded-For header value")
22+
)
23+
24+
func parseSBFFHeader(headerVal string) (string, error) {
25+
values := strings.SplitN(headerVal, ",", 2)
26+
key := strings.TrimSpace(values[0])
27+
if ipAddr := net.ParseIP(key); ipAddr != nil {
28+
return ipAddr.String(), nil
29+
}
30+
31+
return "", ErrHeaderInvalid
32+
}
33+
34+
// GetIPAddress returns the value of the IP address in Sb-Forwarded-For as defined by
35+
// SBForwardedForMiddleware. If no value is present in the request context, this function will
36+
// return ("", false).
37+
func GetIPAddress(r *http.Request) (addr string, found bool) {
38+
if ipAddr, ok := r.Context().Value(ctxKeySBFF).(string); ok && ipAddr != "" {
39+
return ipAddr, true
40+
}
41+
42+
return "", false
43+
}
44+
45+
// withIPAddress parses the Sb-Forwarded-For header and adds the leftmost value to the
46+
// request context if it is a valid IP address, then returns a new request with modified context.
47+
// If the leftmost value is not a valid IP address or the header is not set, this function returns
48+
// an error.
49+
func withIPAddress(r *http.Request) (*http.Request, error) {
50+
headerVal := r.Header.Get(HeaderName)
51+
if headerVal == "" {
52+
return nil, ErrHeaderNotFound
53+
}
54+
55+
parsedIPAddr, err := parseSBFFHeader(headerVal)
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
ctx := r.Context()
61+
newCtx := context.WithValue(ctx, ctxKeySBFF, parsedIPAddr)
62+
out := r.WithContext(newCtx)
63+
64+
return out, nil
65+
}
66+
67+
// Middleware returns a middleware function that parses the Sb-Forwarded-For header
68+
// and adds the leftmost header value to the request context if GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED
69+
// is true and the value is a valid IP address.
70+
func Middleware(cfg *conf.SecurityConfiguration, errCallback func(*http.Request, error)) func(http.Handler) http.Handler {
71+
out := func(next http.Handler) http.Handler {
72+
handlerFunc := func(rw http.ResponseWriter, r *http.Request) {
73+
if !cfg.SbForwardedForEnabled {
74+
next.ServeHTTP(rw, r)
75+
return
76+
}
77+
78+
reqWithSBFF, err := withIPAddress(r)
79+
switch {
80+
case err == nil:
81+
next.ServeHTTP(rw, reqWithSBFF)
82+
case errors.Is(err, ErrHeaderNotFound):
83+
next.ServeHTTP(rw, r)
84+
default:
85+
errCallback(r, err)
86+
next.ServeHTTP(rw, r)
87+
}
88+
}
89+
90+
return http.HandlerFunc(handlerFunc)
91+
}
92+
93+
return out
94+
}

0 commit comments

Comments
 (0)