Skip to content

Commit 4e7c632

Browse files
author
bloom
committed
improve rate limites
1 parent 21ab793 commit 4e7c632

File tree

9 files changed

+92
-63
lines changed

9 files changed

+92
-63
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ require (
4646
github.com/aws/smithy-go v1.24.0 // indirect
4747
github.com/aymerick/douceur v0.2.0 // indirect
4848
github.com/clipperhouse/stringish v0.1.1 // indirect
49-
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
49+
github.com/clipperhouse/uax29/v2 v2.3.1 // indirect
5050
github.com/dlclark/regexp2 v1.11.5 // indirect
5151
github.com/fsnotify/fsnotify v1.9.0 // indirect
5252
github.com/gorilla/css v1.0.1 // indirect

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuP
5050
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
5151
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
5252
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
53-
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
54-
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
53+
github.com/clipperhouse/uax29/v2 v2.3.1 h1:RjM8gnVbFbgI67SBekIC7ihFpyXwRPYWXn9BZActHbw=
54+
github.com/clipperhouse/uax29/v2 v2.3.1/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
5555
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
5656
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
5757
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

pkg/errs/errors.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func IsInternal(err error) bool {
7272

7373
switch err.(type) {
7474
case *NotFoundError, *InvalidArgumentError,
75-
*PermissionDeniedError, *AuthenticationRequiredError, *AlreadyExistsError:
75+
*PermissionDeniedError, *AuthenticationRequiredError, *AlreadyExistsError, *TooManyRequestsError:
7676
return false
7777
default:
7878
return true
@@ -136,3 +136,17 @@ func (err *AuthenticationRequiredError) Error() string {
136136
func AuthenticationRequired(message string) *AuthenticationRequiredError {
137137
return &AuthenticationRequiredError{message: message}
138138
}
139+
140+
// TooManyRequestsError is a wrapper for an error when a client is sending too many requests
141+
type TooManyRequestsError struct {
142+
// message string
143+
}
144+
145+
func (err *TooManyRequestsError) Error() string {
146+
return "Too many requests. Please try again later."
147+
}
148+
149+
// TooManyRequests returns a new TooManyRequestsError
150+
func TooManyRequests() *TooManyRequestsError {
151+
return &TooManyRequestsError{}
152+
}

pkg/ratelimit/ratelimit.go

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@ package ratelimit
33

44
import (
55
"encoding/binary"
6+
"math/rand/v2"
67
"sync"
78
"time"
89

9-
"github.com/skerkour/stdx-go/crypto/blake3"
10+
"github.com/skerkour/stdx-go/xxh3"
1011
)
1112

1213
// Limiter tracks request counts within time buckets.
1314
type Limiter struct {
14-
mutex sync.Mutex
15-
buckets map[[32]byte]*bucket
16-
stop chan struct{}
15+
mutex sync.Mutex
16+
buckets map[uint64]*bucket
17+
stop chan struct{}
18+
hashSeed uint64
1719
}
1820

1921
type bucket struct {
@@ -23,32 +25,35 @@ type bucket struct {
2325

2426
// New creates a new rate limiter with automatic cleanup of expired buckets.
2527
func New() *Limiter {
28+
2629
limiter := &Limiter{
27-
mutex: sync.Mutex{},
28-
buckets: make(map[[32]byte]*bucket),
29-
stop: make(chan struct{}),
30+
mutex: sync.Mutex{},
31+
buckets: make(map[uint64]*bucket),
32+
stop: make(chan struct{}),
33+
hashSeed: rand.Uint64(),
3034
}
3135
go limiter.cleanupLoop()
3236
return limiter
3337
}
3438

35-
// RateLimit checks if an action by an actor is allowed within the rate limit.
39+
// IsAllowed checks if an action by an actor is allowed within the rate limit.
3640
// It returns true if the action is allowed, false if rate limited.
3741
//
3842
// Parameters:
43+
// - namespace: optional namespace for the action check (e.g. tenant ID). Can be nil.
3944
// - action: identifies the type of action being rate limited (e.g., "login", "api-call")
4045
// - actor: identifies who is performing the action (e.g., user ID, IP address)
4146
// - timeBucket: the duration of each rate limit window
4247
// - allowed: maximum number of actions allowed per time bucket
43-
func (limiter *Limiter) RateLimit(action string, actor []byte, timeBucket time.Duration, allowed uint64) bool {
48+
func (limiter *Limiter) IsAllowed(action string, namespace []byte, actor []byte, timeBucket time.Duration, allowed uint64) bool {
4449
now := time.Now()
4550
bucketStart := now.Truncate(timeBucket)
46-
key := makeKey(action, actor, uint64(bucketStart.UnixNano()), uint64(timeBucket.Nanoseconds()))
51+
key := limiter.makeKey(action, namespace, actor, uint64(bucketStart.UnixNano()), uint64(timeBucket.Nanoseconds()))
4752

4853
limiter.mutex.Lock()
4954
defer limiter.mutex.Unlock()
5055

51-
b, exists := limiter.buckets[key]
56+
existingBucket, exists := limiter.buckets[key]
5257
if !exists {
5358
limiter.buckets[key] = &bucket{
5459
count: 1,
@@ -57,20 +62,20 @@ func (limiter *Limiter) RateLimit(action string, actor []byte, timeBucket time.D
5762
return true
5863
}
5964

60-
if b.count >= allowed {
65+
if existingBucket.count >= allowed {
6166
return false
6267
}
6368

64-
b.count++
69+
existingBucket.count++
6570
return true
6671
}
6772

6873
// Count returns the current count for an action/actor in the current time bucket.
6974
// Useful for showing users how many requests they have remaining.
70-
func (limiter *Limiter) Count(action string, actor []byte, timeBucket time.Duration) uint64 {
75+
func (limiter *Limiter) Count(action string, namespace []byte, actor []byte, timeBucket time.Duration) uint64 {
7176
now := time.Now()
7277
bucketStart := now.Truncate(timeBucket)
73-
key := makeKey(action, actor, uint64(bucketStart.UnixNano()), uint64(timeBucket.Nanoseconds()))
78+
key := limiter.makeKey(action, namespace, actor, uint64(bucketStart.UnixNano()), uint64(timeBucket.Nanoseconds()))
7479

7580
limiter.mutex.Lock()
7681
defer limiter.mutex.Unlock()
@@ -82,8 +87,8 @@ func (limiter *Limiter) Count(action string, actor []byte, timeBucket time.Durat
8287
}
8388

8489
// Remaining returns how many requests are remaining for an action/actor.
85-
func (limiter *Limiter) Remaining(action string, actor []byte, timeBucket time.Duration, allowed uint64) uint64 {
86-
count := limiter.Count(action, actor, timeBucket)
90+
func (limiter *Limiter) Remaining(action string, namespace []byte, actor []byte, timeBucket time.Duration, allowed uint64) uint64 {
91+
count := limiter.Count(action, namespace, actor, timeBucket)
8792
if count >= allowed {
8893
return 0
8994
}
@@ -123,15 +128,14 @@ func (limiter *Limiter) cleanup() {
123128
}
124129
}
125130

126-
func makeKey(action string, actor []byte, bucketStartNanos uint64, timeBucketNanos uint64) [32]byte {
127-
var hash [32]byte
128-
129-
hasher := blake3.New(32, nil)
131+
// makeKey returns a stable key by hashing the inputs. It currently uses xxh3.
132+
func (limiter *Limiter) makeKey(action string, namespace []byte, actor []byte, bucketStartNanos uint64, timeBucketNanos uint64) uint64 {
133+
hasher := xxh3.NewSeed(limiter.hashSeed)
130134
hasher.Write([]byte(action))
135+
hasher.Write(namespace)
131136
hasher.Write(actor)
132137
binary.Write(hasher, binary.LittleEndian, bucketStartNanos)
133138
binary.Write(hasher, binary.LittleEndian, timeBucketNanos)
134139

135-
hasher.Sum(hash[:0])
136-
return hash
140+
return hasher.Sum64()
137141
}

pkg/ratelimit/ratelimit_test.go

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ func TestRateLimit_Basic(t *testing.T) {
1717

1818
// First 3 should be allowed
1919
for i := 0; i < 3; i++ {
20-
if !l.RateLimit(action, actor, bucket, allowed) {
20+
if !l.IsAllowed(action, nil, actor, bucket, allowed) {
2121
t.Errorf("Request %d should have been allowed", i+1)
2222
}
2323
}
2424

2525
// 4th should be rate limited
26-
if l.RateLimit(action, actor, bucket, allowed) {
26+
if l.IsAllowed(action, nil, actor, bucket, allowed) {
2727
t.Error("Request 4 should have been rate limited")
2828
}
2929

3030
// 5th should also be rate limited
31-
if l.RateLimit(action, actor, bucket, allowed) {
31+
if l.IsAllowed(action, nil, actor, bucket, allowed) {
3232
t.Error("Request 5 should have been rate limited")
3333
}
3434
}
@@ -37,23 +37,24 @@ func TestRateLimit_DifferentActors(t *testing.T) {
3737
l := New()
3838
defer l.Stop()
3939

40+
namespace := []byte("namespace1")
4041
actor1 := []byte("user1")
4142
actor2 := []byte("user2")
4243
action := "test-action"
4344
bucket := time.Minute
4445
allowed := uint64(2)
4546

4647
// Actor 1 uses both requests
47-
l.RateLimit(action, actor1, bucket, allowed)
48-
l.RateLimit(action, actor1, bucket, allowed)
48+
l.IsAllowed(action, namespace, actor1, bucket, allowed)
49+
l.IsAllowed(action, namespace, actor1, bucket, allowed)
4950

5051
// Actor 1 should be limited
51-
if l.RateLimit(action, actor1, bucket, allowed) {
52+
if l.IsAllowed(action, namespace, actor1, bucket, allowed) {
5253
t.Error("Actor1 should be rate limited")
5354
}
5455

5556
// Actor 2 should still be allowed
56-
if !l.RateLimit(action, actor2, bucket, allowed) {
57+
if !l.IsAllowed(action, namespace, actor2, bucket, allowed) {
5758
t.Error("Actor2 should be allowed")
5859
}
5960
}
@@ -63,21 +64,22 @@ func TestRateLimit_DifferentActions(t *testing.T) {
6364
defer l.Stop()
6465

6566
actor := []byte("user123")
67+
namespace := []byte("namespace")
6668
action1 := "action1"
6769
action2 := "action2"
6870
bucket := time.Minute
6971
allowed := uint64(1)
7072

7173
// Use up action1 limit
72-
l.RateLimit(action1, actor, bucket, allowed)
74+
l.IsAllowed(action1, namespace, actor, bucket, allowed)
7375

7476
// Action1 should be limited
75-
if l.RateLimit(action1, actor, bucket, allowed) {
77+
if l.IsAllowed(action1, namespace, actor, bucket, allowed) {
7678
t.Error("Action1 should be rate limited")
7779
}
7880

7981
// Action2 should still be allowed
80-
if !l.RateLimit(action2, actor, bucket, allowed) {
82+
if !l.IsAllowed(action2, namespace, actor, bucket, allowed) {
8183
t.Error("Action2 should be allowed")
8284
}
8385
}
@@ -86,24 +88,25 @@ func TestRateLimit_BucketReset(t *testing.T) {
8688
l := New()
8789
defer l.Stop()
8890

91+
namespace := []byte("namespace")
8992
actor := []byte("user123")
9093
action := "test-action"
9194
bucket := 100 * time.Millisecond
9295
allowed := uint64(2)
9396

9497
// Use up the limit
95-
l.RateLimit(action, actor, bucket, allowed)
96-
l.RateLimit(action, actor, bucket, allowed)
98+
l.IsAllowed(action, namespace, actor, bucket, allowed)
99+
l.IsAllowed(action, namespace, actor, bucket, allowed)
97100

98-
if l.RateLimit(action, actor, bucket, allowed) {
101+
if l.IsAllowed(action, namespace, actor, bucket, allowed) {
99102
t.Error("Should be rate limited")
100103
}
101104

102105
// Wait for next bucket
103106
time.Sleep(150 * time.Millisecond)
104107

105108
// Should be allowed again
106-
if !l.RateLimit(action, actor, bucket, allowed) {
109+
if !l.IsAllowed(action, namespace, actor, bucket, allowed) {
107110
t.Error("Should be allowed after bucket reset")
108111
}
109112
}
@@ -112,18 +115,19 @@ func TestCount(t *testing.T) {
112115
l := New()
113116
defer l.Stop()
114117

118+
namespace := []byte("namespace")
115119
actor := []byte("user123")
116120
action := "test-action"
117121
bucket := time.Minute
118122

119-
if c := l.Count(action, actor, bucket); c != 0 {
123+
if c := l.Count(action, namespace, actor, bucket); c != 0 {
120124
t.Errorf("Expected count 0, got %d", c)
121125
}
122126

123-
l.RateLimit(action, actor, bucket, 10)
124-
l.RateLimit(action, actor, bucket, 10)
127+
l.IsAllowed(action, namespace, actor, bucket, 10)
128+
l.IsAllowed(action, namespace, actor, bucket, 10)
125129

126-
if c := l.Count(action, actor, bucket); c != 2 {
130+
if c := l.Count(action, namespace, actor, bucket); c != 2 {
127131
t.Errorf("Expected count 2, got %d", c)
128132
}
129133
}
@@ -132,19 +136,20 @@ func TestRemaining(t *testing.T) {
132136
l := New()
133137
defer l.Stop()
134138

139+
namespace := []byte("namespace")
135140
actor := []byte("user123")
136141
action := "test-action"
137142
bucket := time.Minute
138143
allowed := uint64(5)
139144

140-
if r := l.Remaining(action, actor, bucket, allowed); r != 5 {
145+
if r := l.Remaining(action, namespace, actor, bucket, allowed); r != 5 {
141146
t.Errorf("Expected 5 remaining, got %d", r)
142147
}
143148

144-
l.RateLimit(action, actor, bucket, allowed)
145-
l.RateLimit(action, actor, bucket, allowed)
149+
l.IsAllowed(action, namespace, actor, bucket, allowed)
150+
l.IsAllowed(action, namespace, actor, bucket, allowed)
146151

147-
if r := l.Remaining(action, actor, bucket, allowed); r != 3 {
152+
if r := l.Remaining(action, namespace, actor, bucket, allowed); r != 3 {
148153
t.Errorf("Expected 3 remaining, got %d", r)
149154
}
150155
}
@@ -153,6 +158,7 @@ func TestRateLimit_Concurrent(t *testing.T) {
153158
l := New()
154159
defer l.Stop()
155160

161+
namespace := []byte("namespace")
156162
actor := []byte("user123")
157163
action := "test-action"
158164
bucket := time.Minute
@@ -162,12 +168,12 @@ func TestRateLimit_Concurrent(t *testing.T) {
162168
allowedCount := 0
163169
var mu sync.Mutex
164170

165-
// Run 200 concurrent requests
166-
for i := 0; i < 200; i++ {
171+
// Run 300 concurrent requests
172+
for i := 0; i < 300; i++ {
167173
wg.Add(1)
168174
go func() {
169175
defer wg.Done()
170-
if l.RateLimit(action, actor, bucket, allowed) {
176+
if l.IsAllowed(action, namespace, actor, bucket, allowed) {
171177
mu.Lock()
172178
allowedCount++
173179
mu.Unlock()
@@ -192,11 +198,12 @@ func TestRateLimit_BinaryActor(t *testing.T) {
192198
action := "test-action"
193199
bucket := time.Minute
194200
allowed := uint64(1)
201+
namespace := []byte("namespace")
195202

196-
if !l.RateLimit(action, actor, bucket, allowed) {
203+
if !l.IsAllowed(action, namespace, actor, bucket, allowed) {
197204
t.Error("First request should be allowed")
198205
}
199-
if l.RateLimit(action, actor, bucket, allowed) {
206+
if l.IsAllowed(action, namespace, actor, bucket, allowed) {
200207
t.Error("Second request should be rate limited")
201208
}
202209
}

pkg/server/apiutil/apiutil.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ const (
3434
ErrorCodeInternal apiErrorCode = "INTERNAL"
3535
ErrorCodePermissionDenied apiErrorCode = "PERMISSION_DENIED"
3636
ErrorCodeAuthenticationRequired apiErrorCode = "AUTHENTICATION_REQUIRED"
37+
ErrorCodeTooManyRequests apiErrorCode = "TOO_MANY_REQUESTS"
3738
)
3839

3940
// var jsonEncodingOptions = json.JoinOptions(json.FormatNilMapAsNull(true))
@@ -135,6 +136,9 @@ func SendError(ctx context.Context, w http.ResponseWriter, err error) {
135136
case *errs.AuthenticationRequiredError:
136137
code = ErrorCodeAuthenticationRequired
137138
statusCode = http.StatusUnauthorized
139+
case *errs.TooManyRequestsError:
140+
code = ErrorCodeTooManyRequests
141+
statusCode = http.StatusTooManyRequests
138142
default:
139143
code = ErrorCodeInternal
140144
statusCode = http.StatusInternalServerError

pkg/services/site/service/login.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ func (service *SiteService) Login(ctx context.Context, input site.LoginInput) (r
3434
return
3535
}
3636

37-
if !service.rateLimiter.RateLimit("SiteService.Login", httpCtx.Client.IP.AsSlice(), time.Hour, 20) {
38-
err = errs.InvalidArgument("Too many requests. Please try again later.")
37+
if !service.rateLimiter.IsAllowed("SiteService.Login", website.ID.Bytes(), httpCtx.Client.IP.AsSlice(), time.Hour, 30) {
38+
err = errs.TooManyRequests()
3939
return
4040
}
4141

0 commit comments

Comments
 (0)