Skip to content

Commit 1f06f58

Browse files
cstocktonChris Stockton
and
Chris Stockton
authored
feat: add an optional burstable rate limiter (#1924)
The existing rate limiter was moved to a separate package and renamed to IntervalLimiter. Added BurstLimiter which is a wrapper around the "golang.org/x/time/rate" package. The conf.Rate type now has a private `typ` field that indicates if it is a `"interval"` or `"burst"` rate limiter. If the config value is in the form of `"<burst>/<rate>"` we set it to `"burst"`, otherwise `"interval"`. The `conf.Rate.GetRateType()` method is then called from the `ratelimit.New` function to determine the underlying type of `ratelimit.Limiter` it returns. Then changed `api.NewLimiterOptions` to call `ratelimit.New` instead of creating a specific type of rate limiter. --------- Co-authored-by: Chris Stockton <[email protected]>
1 parent 50eb69b commit 1f06f58

12 files changed

+578
-199
lines changed

internal/api/options.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@ import (
66
"github.com/didip/tollbooth/v5"
77
"github.com/didip/tollbooth/v5/limiter"
88
"github.com/supabase/auth/internal/conf"
9+
"github.com/supabase/auth/internal/ratelimit"
910
)
1011

1112
type Option interface {
1213
apply(*API)
1314
}
1415

1516
type LimiterOptions struct {
16-
Email *RateLimiter
17-
Phone *RateLimiter
17+
Email ratelimit.Limiter
18+
Phone ratelimit.Limiter
1819

1920
Signups *limiter.Limiter
2021
AnonymousSignIns *limiter.Limiter
@@ -36,8 +37,9 @@ func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }
3637
func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
3738
o := &LimiterOptions{}
3839

39-
o.Email = newRateLimiter(gc.RateLimitEmailSent)
40-
o.Phone = newRateLimiter(gc.RateLimitSmsSent)
40+
o.Email = ratelimit.New(gc.RateLimitEmailSent)
41+
o.Phone = ratelimit.New(gc.RateLimitSmsSent)
42+
4143
o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60),
4244
&limiter.ExpirableOptions{
4345
DefaultExpirationTTL: time.Hour,

internal/api/options_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package api
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/supabase/auth/internal/conf"
8+
)
9+
10+
func TestNewLimiterOptions(t *testing.T) {
11+
cfg := &conf.GlobalConfiguration{}
12+
cfg.ApplyDefaults()
13+
14+
rl := NewLimiterOptions(cfg)
15+
assert.NotNil(t, rl.Email)
16+
assert.NotNil(t, rl.Phone)
17+
assert.NotNil(t, rl.Signups)
18+
assert.NotNil(t, rl.AnonymousSignIns)
19+
assert.NotNil(t, rl.Recover)
20+
assert.NotNil(t, rl.Resend)
21+
assert.NotNil(t, rl.MagicLink)
22+
assert.NotNil(t, rl.Otp)
23+
assert.NotNil(t, rl.Token)
24+
assert.NotNil(t, rl.Verify)
25+
assert.NotNil(t, rl.User)
26+
assert.NotNil(t, rl.FactorVerify)
27+
assert.NotNil(t, rl.FactorChallenge)
28+
assert.NotNil(t, rl.SSO)
29+
assert.NotNil(t, rl.SAMLAssertion)
30+
}

internal/api/ratelimits.go

-49
This file was deleted.

internal/api/ratelimits_test.go

-125
This file was deleted.

internal/conf/rate.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,28 @@ import (
99

1010
const defaultOverTime = time.Hour
1111

12+
const (
13+
BurstRateType = "burst"
14+
IntervalRateType = "interval"
15+
)
16+
1217
type Rate struct {
1318
Events float64 `json:"events,omitempty"`
1419
OverTime time.Duration `json:"over_time,omitempty"`
20+
typ string
1521
}
1622

17-
func (r *Rate) EventsPerSecond() float64 {
18-
d := r.OverTime
19-
if d == 0 {
20-
d = defaultOverTime
23+
func (r *Rate) GetRateType() string {
24+
if r.typ == "" {
25+
return IntervalRateType
2126
}
22-
return r.Events / d.Seconds()
27+
return r.typ
2328
}
2429

2530
// Decode is used by envconfig to parse the env-config string to a Rate value.
2631
func (r *Rate) Decode(value string) error {
2732
if f, err := strconv.ParseFloat(value, 64); err == nil {
33+
r.typ = IntervalRateType
2834
r.Events = f
2935
r.OverTime = defaultOverTime
3036
return nil
@@ -45,6 +51,7 @@ func (r *Rate) Decode(value string) error {
4551
return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err)
4652
}
4753

54+
r.typ = BurstRateType
4855
r.Events = float64(e)
4956
r.OverTime = d
5057
return nil

internal/conf/rate_test.go

+28-16
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,39 @@ import (
1010
func TestRateDecode(t *testing.T) {
1111
cases := []struct {
1212
str string
13-
eps float64
1413
exp Rate
1514
err string
1615
}{
17-
{str: "1800", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}},
18-
{str: "1800.0", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}},
19-
{str: "3600/1h", eps: 1, exp: Rate{Events: 3600, OverTime: time.Hour}},
16+
{str: "1800",
17+
exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}},
18+
{str: "1800.0",
19+
exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}},
20+
{str: "3600/1h",
21+
exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}},
22+
{str: "3600/1h0m0s",
23+
exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}},
2024
{str: "100/24h",
21-
eps: 0.0011574074074074073,
22-
exp: Rate{Events: 100, OverTime: time.Hour * 24}},
23-
{str: "", eps: 1, exp: Rate{},
25+
exp: Rate{Events: 100, OverTime: time.Hour * 24, typ: BurstRateType}},
26+
{str: "", exp: Rate{},
2427
err: `rate: value does not match`},
25-
{str: "1h", eps: 1, exp: Rate{},
28+
{str: "1h", exp: Rate{},
2629
err: `rate: value does not match`},
27-
{str: "/", eps: 1, exp: Rate{},
30+
{str: "/", exp: Rate{},
2831
err: `rate: events part of rate value`},
29-
{str: "/1h", eps: 1, exp: Rate{},
32+
{str: "/1h", exp: Rate{},
3033
err: `rate: events part of rate value`},
31-
{str: "3600.0/1h", eps: 1, exp: Rate{},
34+
{str: "3600.0/1h", exp: Rate{},
3235
err: `rate: events part of rate value "3600.0/1h" failed to parse`},
33-
{str: "100/", eps: 1, exp: Rate{},
36+
{str: "100/", exp: Rate{},
3437
err: `rate: over-time part of rate value`},
35-
{str: "100/1", eps: 1, exp: Rate{},
38+
{str: "100/1", exp: Rate{},
3639
err: `rate: over-time part of rate value`},
3740

3841
// zero events
39-
{str: "0/1h", eps: 0.0, exp: Rate{Events: 0, OverTime: time.Hour}},
40-
{str: "0/24h", eps: 0.0, exp: Rate{Events: 0, OverTime: time.Hour * 24}},
42+
{str: "0/1h",
43+
exp: Rate{Events: 0, OverTime: time.Hour, typ: BurstRateType}},
44+
{str: "0/24h",
45+
exp: Rate{Events: 0, OverTime: time.Hour * 24, typ: BurstRateType}},
4146
}
4247
for idx, tc := range cases {
4348
var r Rate
@@ -51,6 +56,13 @@ func TestRateDecode(t *testing.T) {
5156
}
5257
require.NoError(t, err)
5358
require.Equal(t, tc.exp, r)
54-
require.Equal(t, tc.eps, r.EventsPerSecond())
59+
require.Equal(t, tc.exp.typ, r.GetRateType())
5560
}
61+
62+
// GetRateType() zero value
63+
require.Equal(t, IntervalRateType, (&Rate{}).GetRateType())
64+
65+
// String()
66+
require.Equal(t, "0.000000", (&Rate{}).String())
67+
require.Equal(t, "100/1h0m0s", (&Rate{Events: 100, OverTime: time.Hour}).String())
5668
}

0 commit comments

Comments
 (0)