Skip to content

Commit f61c2ea

Browse files
Implement rate limiting based on IP+UA (#211)
* Implement rate limiting * Use separate Redis instance for rate limiting * Fix flag name * Remove annotations * Set rate-limit-max to 40 * Update comments. * Update comments and benchmark * Reuse the same tooManyRequests error * Fix indent in cloudbuild.yaml * Add RATE_LIMIT_REDIS_ADDRESS to the mlab-ns template too * Do not set an error status code on rate limit yet * Address review comments * Update comment * Address review comments
1 parent d2396c8 commit f61c2ea

File tree

13 files changed

+498
-8
lines changed

13 files changed

+498
-8
lines changed

cloudbuild/app.yaml.mlab-ns.template

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,6 @@ env_variables:
3737
LOCATOR_MAXMIND: true
3838
MAXMIND_URL: gs://downloader-{{PLATFORM_PROJECT}}/Maxmind/current/GeoLite2-City.tar.gz
3939
REDIS_ADDRESS: {{REDIS_ADDRESS}}
40+
RATE_LIMIT_REDIS_ADDRESS: {{RATE_LIMIT_REDIS_ADDRESS}}
4041
PROMETHEUSX_LISTEN_ADDRESS: ':9090' # Must match one of the forwarded_ports above.
4142
PROMETHEUS_URL: 'https://prometheus-basicauth.{{PLATFORM_PROJECT}}.measurementlab.net/'

cloudbuild/app.yaml.template

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@ env_variables:
3636
LOCATOR_MAXMIND: true
3737
MAXMIND_URL: gs://downloader-{{PLATFORM_PROJECT}}/Maxmind/current/GeoLite2-City.tar.gz
3838
REDIS_ADDRESS: {{REDIS_ADDRESS}}
39+
RATE_LIMIT_REDIS_ADDRESS: {{RATE_LIMIT_REDIS_ADDRESS}}
3940
PROMETHEUSX_LISTEN_ADDRESS: ':9090' # Must match one of the forwarded_ports above.
4041
PROMETHEUS_URL: 'https://prometheus-basicauth.{{PLATFORM_PROJECT}}.measurementlab.net/'

cloudbuild/cloudbuild.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ steps:
2929
-e 's/{{PROJECT}}/$PROJECT_ID/g'
3030
-e 's/{{PLATFORM_PROJECT}}/$_PLATFORM_PROJECT/'
3131
-e 's/{{REDIS_ADDRESS}}/$_REDIS_ADDRESS/'
32+
-e 's/{{RATE_LIMIT_REDIS_ADDRESS}}/$_RATE_LIMIT_REDIS_ADDRESS/'
3233
app.yaml
3334
- gcloud --project $PROJECT_ID app deploy --promote app.yaml
3435
# After deploying the new service, deploy the openapi spec.
@@ -47,6 +48,7 @@ steps:
4748
-e 's/{{PROJECT}}/$PROJECT_ID/g'
4849
-e 's/{{PLATFORM_PROJECT}}/$_PLATFORM_PROJECT/'
4950
-e 's/{{REDIS_ADDRESS}}/$_REDIS_ADDRESS/'
51+
-e 's/{{RATE_LIMIT_REDIS_ADDRESS}}/$_RATE_LIMIT_REDIS_ADDRESS/'
5052
app.yaml
5153
- gcloud --project $PROJECT_ID app deploy --promote app.yaml
5254
# After deploying the new service, deploy the openapi spec.

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go 1.20
55
require (
66
cloud.google.com/go/compute/metadata v0.2.3
77
cloud.google.com/go/secretmanager v1.11.2
8+
github.com/alicebob/miniredis v2.5.0+incompatible
89
github.com/apex/log v1.9.0
910
github.com/cenkalti/backoff/v4 v4.1.3
1011
github.com/go-test/deep v1.0.8
@@ -33,12 +34,14 @@ require (
3334
)
3435

3536
require (
37+
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
3638
github.com/evanphx/json-patch v4.9.0+incompatible // indirect
3739
github.com/google/s2a-go v0.1.7 // indirect
3840
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
3941
github.com/imdario/mergo v0.3.5 // indirect
4042
github.com/rogpeppe/go-internal v1.11.0 // indirect
4143
github.com/spf13/pflag v1.0.5 // indirect
44+
github.com/yuin/gopher-lua v1.1.1 // indirect
4245
golang.org/x/sync v0.4.0 // indirect
4346
google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b // indirect
4447
google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuy
5555
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
5656
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
5757
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
58+
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
59+
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
60+
github.com/alicebob/miniredis v2.5.0+incompatible h1:yBHoLpsyjupjz3NL3MhKMVkR41j82Yjf3KFv7ApYzUI=
61+
github.com/alicebob/miniredis v2.5.0+incompatible/go.mod h1:8HZjEj4yU0dwhYHky+DxYx+6BMjkBbe5ONFIF1MXffk=
5862
github.com/apex/log v1.9.0 h1:FHtw/xuaM8AgmvDDTI9fiwoAL25Sq2cxojnZICUU8l0=
5963
github.com/apex/log v1.9.0/go.mod h1:m82fZlWIuiWzWP04XCTXmnX0xRkYYbCdYn8jbJeLBEA=
6064
github.com/apex/logs v1.0.0/go.mod h1:XzxuLZ5myVHDy9SAmYpamKKRNApGj54PfYLcFrXqDwo=
@@ -360,6 +364,8 @@ github.com/tj/go-spin v1.1.0/go.mod h1:Mg1mzmePZm4dva8Qz60H2lHwmJ2loum4VIrLgVnKw
360364
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
361365
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
362366
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
367+
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
368+
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
363369
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
364370
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
365371
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=

handler/handler.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ type Client struct {
5252
PrometheusClient
5353
targetTmpl *template.Template
5454
agentLimits limits.Agents
55+
ipLimiter *limits.RateLimiter
5556
}
5657

5758
// LocatorV2 defines how the Nearest handler requests machines nearest to the
@@ -84,7 +85,8 @@ func init() {
8485
}
8586

8687
// NewClient creates a new client.
87-
func NewClient(project string, private Signer, locatorV2 LocatorV2, client ClientLocator, prom PrometheusClient, lmts limits.Agents) *Client {
88+
func NewClient(project string, private Signer, locatorV2 LocatorV2, client ClientLocator,
89+
prom PrometheusClient, lmts limits.Agents, limiter *limits.RateLimiter) *Client {
8890
return &Client{
8991
Signer: private,
9092
project: project,
@@ -93,6 +95,7 @@ func NewClient(project string, private Signer, locatorV2 LocatorV2, client Clien
9395
PrometheusClient: prom,
9496
targetTmpl: template.Must(template.New("name").Parse("{{.Hostname}}{{.Ports}}")),
9597
agentLimits: lmts,
98+
ipLimiter: limiter,
9699
}
97100
}
98101

@@ -153,6 +156,38 @@ func (c *Client) Nearest(rw http.ResponseWriter, req *http.Request) {
153156
return
154157
}
155158

159+
// Check rate limit for IP and UA.
160+
if c.ipLimiter != nil {
161+
// Get the IP address from the request. X-Forwarded-For is guaranteed to
162+
// be set by AppEngine.
163+
ip := req.Header.Get("X-Forwarded-For")
164+
ips := strings.Split(ip, ",")
165+
if len(ips) > 0 {
166+
ip = strings.TrimSpace(ips[0])
167+
}
168+
if ip != "" {
169+
// An empty UA is technically possible. In this case, the key will be
170+
// "ip:" and the rate limiting will be based on the IP address only.
171+
ua := req.Header.Get("User-Agent")
172+
limited, err := c.ipLimiter.IsLimited(ip, ua)
173+
if err != nil {
174+
// Log error but don't block request (fail open).
175+
// TODO: Add tests for this path.
176+
log.Printf("Rate limiter error: %v", err)
177+
} else if limited {
178+
metrics.RequestsTotal.WithLabelValues("nearest", "rate limit",
179+
http.StatusText(result.Error.Status)).Inc()
180+
// For now, we only log the rate limit exceeded message.
181+
// TODO: Actually block the request and return an appropriate HTTP error
182+
// code and message.
183+
log.Printf("Rate limit exceeded for IP %s and UA %s", ip, ua)
184+
}
185+
} else {
186+
// This should never happen if Locate is deployed on AppEngine.
187+
log.Println("Cannot find IP address for rate limiting.")
188+
}
189+
}
190+
156191
experiment, service := getExperimentAndService(req.URL.Path)
157192

158193
// Look up client location.

handler/handler_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ func TestClient_Nearest(t *testing.T) {
212212
if tt.cl == nil {
213213
tt.cl = clientgeo.NewAppEngineLocator()
214214
}
215-
c := NewClient(tt.project, tt.signer, tt.locator, tt.cl, prom.NewAPI(nil), tt.limits)
215+
c := NewClient(tt.project, tt.signer, tt.locator, tt.cl, prom.NewAPI(nil), tt.limits, nil)
216216

217217
mux := http.NewServeMux()
218218
mux.HandleFunc("/v2/nearest/", c.Nearest)
@@ -291,7 +291,7 @@ func TestClient_Ready(t *testing.T) {
291291
}
292292
for _, tt := range tests {
293293
t.Run(tt.name, func(t *testing.T) {
294-
c := NewClient("foo", &fakeSigner{}, &fakeLocatorV2{StatusTracker: &heartbeattest.FakeStatusTracker{Err: tt.fakeErr}}, nil, nil, nil)
294+
c := NewClient("foo", &fakeSigner{}, &fakeLocatorV2{StatusTracker: &heartbeattest.FakeStatusTracker{Err: tt.fakeErr}}, nil, nil, nil, nil)
295295

296296
mux := http.NewServeMux()
297297
mux.HandleFunc("/ready/", c.Ready)
@@ -349,7 +349,7 @@ func TestClient_Registrations(t *testing.T) {
349349
}
350350

351351
t.Run(tt.name, func(t *testing.T) {
352-
c := NewClient("foo", &fakeSigner{}, &fakeLocatorV2{StatusTracker: fakeStatusTracker}, nil, nil, nil)
352+
c := NewClient("foo", &fakeSigner{}, &fakeLocatorV2{StatusTracker: fakeStatusTracker}, nil, nil, nil, nil)
353353

354354
mux := http.NewServeMux()
355355
mux.HandleFunc("/v2/siteinfo/registrations/", c.Registrations)

handler/heartbeat_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestClient_handleHeartbeats(t *testing.T) {
7070
func fakeClient(t heartbeat.StatusTracker) *Client {
7171
locatorv2 := fakeLocatorV2{StatusTracker: t}
7272
return NewClient("mlab-sandbox", &fakeSigner{}, &locatorv2,
73-
clientgeo.NewAppEngineLocator(), prom.NewAPI(nil), nil)
73+
clientgeo.NewAppEngineLocator(), prom.NewAPI(nil), nil, nil)
7474
}
7575

7676
type fakeConn struct {

handler/monitoring_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func TestClient_Monitoring(t *testing.T) {
9292
for _, tt := range tests {
9393
t.Run(tt.name, func(t *testing.T) {
9494
cl := clientgeo.NewAppEngineLocator()
95-
c := NewClient("mlab-sandbox", tt.signer, tt.locator, cl, prom.NewAPI(nil), nil)
95+
c := NewClient("mlab-sandbox", tt.signer, tt.locator, cl, prom.NewAPI(nil), nil, nil)
9696
rw := httptest.NewRecorder()
9797
req := httptest.NewRequest(http.MethodGet, "/v2/platform/monitoring/"+tt.path, nil)
9898
req = req.Clone(controller.SetClaim(req.Context(), tt.claim))

limits/ratelimiter.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package limits
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"time"
7+
8+
"github.com/gomodule/redigo/redis"
9+
)
10+
11+
// RateLimitConfig bundles the settings used to build a RateLimiter for
// IP+UA rate limiting.
type RateLimitConfig struct {
	// Interval is the length of the sliding window over which events are
	// counted.
	Interval time.Duration

	// MaxEvents is the number of events tolerated within one Interval.
	MaxEvents int

	// KeyPrefix is prepended to every Redis key the limiter creates.
	KeyPrefix string
}
20+
21+
// RateLimiter implements a distributed rate limiter using Redis sorted sets (ZSET).
22+
// It maintains a sliding window of events for each IP+UA combination, where:
23+
// - Each event is stored in a ZSET with the timestamp as score
24+
// - Old events (outside the window) are automatically removed
25+
// - Keys automatically expire after the configured interval
26+
//
27+
// The limiter considers a request to be rate-limited if the number of events
28+
// in the current window exceeds MaxEvents.
29+
type RateLimiter struct {
30+
pool *redis.Pool
31+
interval time.Duration
32+
maxEvents int
33+
keyPrefix string
34+
}
35+
36+
// NewRateLimiter creates a new rate limiter.
37+
func NewRateLimiter(pool *redis.Pool, config RateLimitConfig) *RateLimiter {
38+
return &RateLimiter{
39+
pool: pool,
40+
interval: config.Interval,
41+
maxEvents: config.MaxEvents,
42+
keyPrefix: config.KeyPrefix,
43+
}
44+
}
45+
46+
// generateKey creates a Redis key from IP and User-Agent.
47+
func (rl *RateLimiter) generateKey(ip, ua string) string {
48+
return fmt.Sprintf("%s:%s:%s", rl.keyPrefix, ip, ua)
49+
}
50+
51+
// IsLimited checks if the given IP and User-Agent combination should be rate limited.
52+
func (rl *RateLimiter) IsLimited(ip, ua string) (bool, error) {
53+
conn := rl.pool.Get()
54+
defer conn.Close()
55+
56+
now := time.Now().UnixMicro()
57+
windowStart := now - rl.interval.Microseconds()
58+
redisKey := rl.generateKey(ip, ua)
59+
60+
// Send all commands in pipeline.
61+
// 1. Remove events outside the window
62+
conn.Send("ZREMRANGEBYSCORE", redisKey, "-inf", windowStart)
63+
// 2. Add current event
64+
conn.Send("ZADD", redisKey, now, strconv.FormatInt(now, 10))
65+
// 3. Set key expiration
66+
conn.Send("EXPIRE", redisKey, int64(rl.interval.Seconds()))
67+
// 4. Get total event count
68+
conn.Send("ZCARD", redisKey)
69+
70+
// Flush pipeline
71+
if err := conn.Flush(); err != nil {
72+
return false, fmt.Errorf("failed to flush pipeline: %w", err)
73+
}
74+
75+
// Receive all replies
76+
for i := 0; i < 3; i++ {
77+
// Receive replies for ZREMRANGEBYSCORE, ZADD, and EXPIRE
78+
if _, err := conn.Receive(); err != nil {
79+
return false, fmt.Errorf("failed to receive reply %d: %w", i, err)
80+
}
81+
}
82+
83+
// Receive and process ZCARD reply
84+
count, err := redis.Int64(conn.Receive())
85+
if err != nil {
86+
return false, fmt.Errorf("failed to receive count: %w", err)
87+
}
88+
89+
return count > int64(rl.maxEvents), nil
90+
}

0 commit comments

Comments
 (0)