Skip to content

Commit 26d4dfc

Browse files
author
bloom
committed
server: add rate limits
1 parent 4a42d64 commit 26d4dfc

File tree

6 files changed

+356
-0
lines changed

6 files changed

+356
-0
lines changed

pkg/ratelimit/ratelimit.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Package ratelimit provides a simple fixed-window rate limiter.
2+
package ratelimit
3+
4+
import (
5+
"encoding/binary"
6+
"sync"
7+
"time"
8+
9+
"github.com/skerkour/stdx-go/crypto/blake3"
10+
)
11+
12+
// Limiter tracks request counts within time buckets.
13+
type Limiter struct {
14+
mutex sync.Mutex
15+
buckets map[[32]byte]*bucket
16+
stop chan struct{}
17+
}
18+
19+
type bucket struct {
20+
count uint64
21+
expires time.Time
22+
}
23+
24+
// New creates a new rate limiter with automatic cleanup of expired buckets.
25+
func New() *Limiter {
26+
l := &Limiter{
27+
mutex: sync.Mutex{},
28+
buckets: make(map[[32]byte]*bucket),
29+
stop: make(chan struct{}),
30+
}
31+
go l.cleanupLoop()
32+
return l
33+
}
34+
35+
// RateLimit checks if an action by an actor is allowed within the rate limit.
36+
// It returns true if the action is allowed, false if rate limited.
37+
//
38+
// Parameters:
39+
// - action: identifies the type of action being rate limited (e.g., "login", "api-call")
40+
// - actor: identifies who is performing the action (e.g., user ID, IP address)
41+
// - timeBucket: the duration of each rate limit window
42+
// - allowed: maximum number of actions allowed per time bucket
43+
func (l *Limiter) RateLimit(action string, actor []byte, timeBucket time.Duration, allowed uint64) bool {
44+
now := time.Now()
45+
bucketStart := now.Truncate(timeBucket)
46+
key := makeKey(action, actor, bucketStart)
47+
48+
l.mutex.Lock()
49+
defer l.mutex.Unlock()
50+
51+
b, exists := l.buckets[key]
52+
if !exists {
53+
l.buckets[key] = &bucket{
54+
count: 1,
55+
expires: bucketStart.Add(timeBucket * 2), // Keep for one extra period for safety
56+
}
57+
return true
58+
}
59+
60+
if b.count >= allowed {
61+
return false
62+
}
63+
64+
b.count++
65+
return true
66+
}
67+
68+
// Count returns the current count for an action/actor in the current time bucket.
69+
// Useful for showing users how many requests they have remaining.
70+
func (l *Limiter) Count(action string, actor []byte, timeBucket time.Duration) uint64 {
71+
now := time.Now()
72+
bucketStart := now.Truncate(timeBucket)
73+
key := makeKey(action, actor, bucketStart)
74+
75+
l.mutex.Lock()
76+
defer l.mutex.Unlock()
77+
78+
if b, exists := l.buckets[key]; exists {
79+
return b.count
80+
}
81+
return 0
82+
}
83+
84+
// Remaining returns how many requests are remaining for an action/actor.
85+
func (l *Limiter) Remaining(action string, actor []byte, timeBucket time.Duration, allowed uint64) uint64 {
86+
count := l.Count(action, actor, timeBucket)
87+
if count >= allowed {
88+
return 0
89+
}
90+
return allowed - count
91+
}
92+
93+
// Stop stops the background cleanup goroutine.
94+
// Call this when the limiter is no longer needed.
95+
func (l *Limiter) Stop() {
96+
close(l.stop)
97+
}
98+
99+
func (l *Limiter) cleanupLoop() {
100+
ticker := time.NewTicker(time.Minute)
101+
defer ticker.Stop()
102+
103+
for {
104+
select {
105+
case <-ticker.C:
106+
l.cleanup()
107+
case <-l.stop:
108+
return
109+
}
110+
}
111+
}
112+
113+
func (l *Limiter) cleanup() {
114+
now := time.Now()
115+
116+
l.mutex.Lock()
117+
defer l.mutex.Unlock()
118+
119+
for key, b := range l.buckets {
120+
if now.After(b.expires) {
121+
delete(l.buckets, key)
122+
}
123+
}
124+
}
125+
126+
func makeKey(action string, actor []byte, bucketStart time.Time) [32]byte {
127+
var hash [32]byte
128+
hasher := blake3.New(32, nil)
129+
hasher.Write([]byte(action))
130+
hasher.Write(actor)
131+
binary.Write(hasher, binary.LittleEndian, bucketStart.UnixNano())
132+
133+
hasher.Sum(hash[:0])
134+
return hash
135+
}

pkg/ratelimit/ratelimit_test.go

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
package ratelimit
2+
3+
import (
4+
"sync"
5+
"testing"
6+
"time"
7+
)
8+
9+
func TestRateLimit_Basic(t *testing.T) {
10+
l := New()
11+
defer l.Stop()
12+
13+
actor := []byte("user123")
14+
action := "test-action"
15+
bucket := time.Minute
16+
allowed := uint64(3)
17+
18+
// First 3 should be allowed
19+
for i := 0; i < 3; i++ {
20+
if !l.RateLimit(action, actor, bucket, allowed) {
21+
t.Errorf("Request %d should have been allowed", i+1)
22+
}
23+
}
24+
25+
// 4th should be rate limited
26+
if l.RateLimit(action, actor, bucket, allowed) {
27+
t.Error("Request 4 should have been rate limited")
28+
}
29+
30+
// 5th should also be rate limited
31+
if l.RateLimit(action, actor, bucket, allowed) {
32+
t.Error("Request 5 should have been rate limited")
33+
}
34+
}
35+
36+
func TestRateLimit_DifferentActors(t *testing.T) {
37+
l := New()
38+
defer l.Stop()
39+
40+
actor1 := []byte("user1")
41+
actor2 := []byte("user2")
42+
action := "test-action"
43+
bucket := time.Minute
44+
allowed := uint64(2)
45+
46+
// Actor 1 uses both requests
47+
l.RateLimit(action, actor1, bucket, allowed)
48+
l.RateLimit(action, actor1, bucket, allowed)
49+
50+
// Actor 1 should be limited
51+
if l.RateLimit(action, actor1, bucket, allowed) {
52+
t.Error("Actor1 should be rate limited")
53+
}
54+
55+
// Actor 2 should still be allowed
56+
if !l.RateLimit(action, actor2, bucket, allowed) {
57+
t.Error("Actor2 should be allowed")
58+
}
59+
}
60+
61+
func TestRateLimit_DifferentActions(t *testing.T) {
62+
l := New()
63+
defer l.Stop()
64+
65+
actor := []byte("user123")
66+
action1 := "action1"
67+
action2 := "action2"
68+
bucket := time.Minute
69+
allowed := uint64(1)
70+
71+
// Use up action1 limit
72+
l.RateLimit(action1, actor, bucket, allowed)
73+
74+
// Action1 should be limited
75+
if l.RateLimit(action1, actor, bucket, allowed) {
76+
t.Error("Action1 should be rate limited")
77+
}
78+
79+
// Action2 should still be allowed
80+
if !l.RateLimit(action2, actor, bucket, allowed) {
81+
t.Error("Action2 should be allowed")
82+
}
83+
}
84+
85+
func TestRateLimit_BucketReset(t *testing.T) {
86+
l := New()
87+
defer l.Stop()
88+
89+
actor := []byte("user123")
90+
action := "test-action"
91+
bucket := 100 * time.Millisecond
92+
allowed := uint64(2)
93+
94+
// Use up the limit
95+
l.RateLimit(action, actor, bucket, allowed)
96+
l.RateLimit(action, actor, bucket, allowed)
97+
98+
if l.RateLimit(action, actor, bucket, allowed) {
99+
t.Error("Should be rate limited")
100+
}
101+
102+
// Wait for next bucket
103+
time.Sleep(150 * time.Millisecond)
104+
105+
// Should be allowed again
106+
if !l.RateLimit(action, actor, bucket, allowed) {
107+
t.Error("Should be allowed after bucket reset")
108+
}
109+
}
110+
111+
func TestCount(t *testing.T) {
112+
l := New()
113+
defer l.Stop()
114+
115+
actor := []byte("user123")
116+
action := "test-action"
117+
bucket := time.Minute
118+
119+
if c := l.Count(action, actor, bucket); c != 0 {
120+
t.Errorf("Expected count 0, got %d", c)
121+
}
122+
123+
l.RateLimit(action, actor, bucket, 10)
124+
l.RateLimit(action, actor, bucket, 10)
125+
126+
if c := l.Count(action, actor, bucket); c != 2 {
127+
t.Errorf("Expected count 2, got %d", c)
128+
}
129+
}
130+
131+
func TestRemaining(t *testing.T) {
132+
l := New()
133+
defer l.Stop()
134+
135+
actor := []byte("user123")
136+
action := "test-action"
137+
bucket := time.Minute
138+
allowed := uint64(5)
139+
140+
if r := l.Remaining(action, actor, bucket, allowed); r != 5 {
141+
t.Errorf("Expected 5 remaining, got %d", r)
142+
}
143+
144+
l.RateLimit(action, actor, bucket, allowed)
145+
l.RateLimit(action, actor, bucket, allowed)
146+
147+
if r := l.Remaining(action, actor, bucket, allowed); r != 3 {
148+
t.Errorf("Expected 3 remaining, got %d", r)
149+
}
150+
}
151+
152+
func TestRateLimit_Concurrent(t *testing.T) {
153+
l := New()
154+
defer l.Stop()
155+
156+
actor := []byte("user123")
157+
action := "test-action"
158+
bucket := time.Minute
159+
allowed := uint64(100)
160+
161+
var wg sync.WaitGroup
162+
allowedCount := 0
163+
var mu sync.Mutex
164+
165+
// Run 200 concurrent requests
166+
for i := 0; i < 200; i++ {
167+
wg.Add(1)
168+
go func() {
169+
defer wg.Done()
170+
if l.RateLimit(action, actor, bucket, allowed) {
171+
mu.Lock()
172+
allowedCount++
173+
mu.Unlock()
174+
}
175+
}()
176+
}
177+
178+
wg.Wait()
179+
180+
// Exactly 100 should have been allowed
181+
if allowedCount != 100 {
182+
t.Errorf("Expected 100 allowed, got %d", allowedCount)
183+
}
184+
}
185+
186+
func TestRateLimit_BinaryActor(t *testing.T) {
187+
l := New()
188+
defer l.Stop()
189+
190+
// Test with binary data containing special characters
191+
actor := []byte{0x00, 0x01, 0xFF, ':', '\n'}
192+
action := "test-action"
193+
bucket := time.Minute
194+
allowed := uint64(1)
195+
196+
if !l.RateLimit(action, actor, bucket, allowed) {
197+
t.Error("First request should be allowed")
198+
}
199+
if l.RateLimit(action, actor, bucket, allowed) {
200+
t.Error("Second request should be rate limited")
201+
}
202+
}

pkg/services/site/service/service.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/skerkour/stdx-go/queue"
1717
"markdown.ninja/cmd/mdninja-server/config"
1818
"markdown.ninja/pkg/mailer"
19+
"markdown.ninja/pkg/ratelimit"
1920
"markdown.ninja/pkg/services/contacts"
2021
"markdown.ninja/pkg/services/content"
2122
"markdown.ninja/pkg/services/emails"
@@ -61,6 +62,8 @@ type SiteService struct {
6162
cacheZstdCompressor *zstd.Encoder
6263
cacheZstdDecompressor *zstd.Decoder
6364

65+
rateLimiter *ratelimit.Limiter
66+
6467
themes map[string]parsedTheme
6568
}
6669

@@ -177,6 +180,8 @@ func NewSiteService(conf config.Config, db db.DB, queue queue.Queue, mailer mail
177180
feedsCache: feedsCache,
178181
sitemapsCache: sitemapsCache,
179182

183+
rateLimiter: ratelimit.New(),
184+
180185
themes: themes,
181186
}
182187
return

pkg/services/site/service/subscribe.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package service
33
import (
44
"context"
55
"strings"
6+
"time"
67

78
"github.com/skerkour/stdx-go/countries"
89
"github.com/skerkour/stdx-go/crypto"
@@ -29,6 +30,11 @@ func (service *SiteService) Subscribe(ctx context.Context, input site.SubscribeI
2930
unverifiedContactAlreadyExists := false
3031
logger := slogx.FromCtx(ctx)
3132

33+
if !service.rateLimiter.RateLimit("SiteService.Subscribe", httpCtx.Client.IP.AsSlice(), time.Hour, 10) {
34+
err = errs.InvalidArgument("Too many requests. Please try again later.")
35+
return
36+
}
37+
3238
err = service.kernel.ValidateEmail(ctx, email, true)
3339
if err != nil {
3440
return

pkg/services/store/service/place_order.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ func (service *StoreService) PlaceOrder(ctx context.Context, input store.PlaceOr
3232
return
3333
}
3434

35+
if !service.rateLimiter.RateLimit("StoreService.PlaceOrder", httpCtx.Client.IP.AsSlice(), time.Hour, 10) {
36+
err = errs.InvalidArgument("Too many requests. Please try again later.")
37+
return
38+
}
39+
3540
customer := service.contactsService.CurrentContact(ctx)
3641
if customer == nil && input.Email == nil {
3742
err = errs.InvalidArgument("email is required")

0 commit comments

Comments
 (0)