Skip to content

Commit 2e5c20c

Browse files
Enable rate limiting (#214)
* Actually block when IP+UA is rate limited. * Add tests for the rate limiter code paths
1 parent 80558ef commit 2e5c20c

File tree

2 files changed

+108
-7
lines changed

2 files changed

+108
-7
lines changed

handler/handler.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ type Signer interface {
4343
Sign(cl jwt.Claims) (string, error)
4444
}
4545

46+
type Limiter interface {
47+
IsLimited(ip, ua string) (bool, error)
48+
}
49+
4650
// Client contains state needed for xyz.
4751
type Client struct {
4852
Signer
@@ -52,7 +56,7 @@ type Client struct {
5256
PrometheusClient
5357
targetTmpl *template.Template
5458
agentLimits limits.Agents
55-
ipLimiter *limits.RateLimiter
59+
ipLimiter Limiter
5660
}
5761

5862
// LocatorV2 defines how the Nearest handler requests machines nearest to the
@@ -86,7 +90,7 @@ func init() {
8690

8791
// NewClient creates a new client.
8892
func NewClient(project string, private Signer, locatorV2 LocatorV2, client ClientLocator,
89-
prom PrometheusClient, lmts limits.Agents, limiter *limits.RateLimiter) *Client {
93+
prom PrometheusClient, lmts limits.Agents, limiter Limiter) *Client {
9094
return &Client{
9195
Signer: private,
9296
project: project,
@@ -175,12 +179,13 @@ func (c *Client) Nearest(rw http.ResponseWriter, req *http.Request) {
175179
// TODO: Add tests for this path.
176180
log.Printf("Rate limiter error: %v", err)
177181
} else if limited {
182+
// Log IP and UA and block the request.
183+
result.Error = v2.NewError("client", tooManyRequests, http.StatusTooManyRequests)
178184
metrics.RequestsTotal.WithLabelValues("nearest", "rate limit",
179-
http.StatusText(http.StatusTooManyRequests)).Inc()
180-
// For now, we only log the rate limit exceeded message.
181-
// TODO: Actually block the request and return an appropriate HTTP error
182-
// code and message.
185+
http.StatusText(result.Error.Status)).Inc()
183186
log.Printf("Rate limit exceeded for IP %s and UA %s", ip, ua)
187+
writeResult(rw, result.Error.Status, &result)
188+
return
184189
}
185190
} else {
186191
// This should never happen if Locate is deployed on AppEngine.

handler/handler_test.go

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,18 @@ func (l *fakeAppEngineLocator) Locate(req *http.Request) (*clientgeo.Location, e
7171
return l.loc, l.err
7272
}
7373

74+
type fakeRateLimiter struct {
75+
limited bool
76+
err error
77+
}
78+
79+
func (r *fakeRateLimiter) IsLimited(ip, ua string) (bool, error) {
80+
if r.err != nil {
81+
return false, r.err
82+
}
83+
return r.limited, nil
84+
}
85+
7486
func TestClient_Nearest(t *testing.T) {
7587
tests := []struct {
7688
name string
@@ -81,6 +93,7 @@ func TestClient_Nearest(t *testing.T) {
8193
project string
8294
latlon string
8395
limits limits.Agents
96+
ipLimiter Limiter
8497
header http.Header
8598
wantLatLon string
8699
wantKey string
@@ -206,13 +219,96 @@ func TestClient_Nearest(t *testing.T) {
206219
wantKey: "ws://:3001/ndt_protocol",
207220
wantStatus: http.StatusOK,
208221
},
222+
{
223+
name: "error-rate-limit-exceeded",
224+
path: "ndt/ndt5",
225+
signer: &fakeSigner{},
226+
locator: &fakeLocatorV2{
227+
targets: []v2.Target{{Machine: "mlab1-lga0t.measurement-lab.org"}},
228+
},
229+
header: http.Header{
230+
"X-Forwarded-For": []string{"192.0.2.1"},
231+
"User-Agent": []string{"test-client"},
232+
},
233+
ipLimiter: &fakeRateLimiter{
234+
limited: true,
235+
},
236+
wantStatus: http.StatusTooManyRequests,
237+
},
238+
{
239+
name: "success-rate-limit-not-exceeded",
240+
path: "ndt/ndt5",
241+
signer: &fakeSigner{},
242+
locator: &fakeLocatorV2{
243+
targets: []v2.Target{{Machine: "mlab1-lga0t.measurement-lab.org"}},
244+
urls: []url.URL{
245+
{Scheme: "ws", Host: ":3001", Path: "/ndt_protocol"},
246+
{Scheme: "wss", Host: ":3010", Path: "ndt_protocol"},
247+
},
248+
},
249+
header: http.Header{
250+
"X-AppEngine-CityLatLong": []string{"40.3,-70.4"},
251+
"X-Forwarded-For": []string{"192.168.1.1"},
252+
"User-Agent": []string{"test-client"},
253+
},
254+
ipLimiter: &fakeRateLimiter{
255+
limited: false,
256+
},
257+
wantLatLon: "40.3,-70.4",
258+
wantKey: "ws://:3001/ndt_protocol",
259+
wantStatus: http.StatusOK,
260+
},
261+
{
262+
name: "success-rate-limiter-error",
263+
path: "ndt/ndt5",
264+
signer: &fakeSigner{},
265+
locator: &fakeLocatorV2{
266+
targets: []v2.Target{{Machine: "mlab1-lga0t.measurement-lab.org"}},
267+
urls: []url.URL{
268+
{Scheme: "ws", Host: ":3001", Path: "/ndt_protocol"},
269+
{Scheme: "wss", Host: ":3010", Path: "/ndt_protocol"},
270+
},
271+
},
272+
header: http.Header{
273+
"X-AppEngine-CityLatLong": []string{"40.3,-70.4"},
274+
"X-Forwarded-For": []string{"192.168.1.1"},
275+
"User-Agent": []string{"test-client"},
276+
},
277+
ipLimiter: &fakeRateLimiter{
278+
err: errors.New("redis error"),
279+
},
280+
wantLatLon: "40.3,-70.4",
281+
wantKey: "ws://:3001/ndt_protocol",
282+
wantStatus: http.StatusOK, // Should fail open
283+
},
284+
{
285+
name: "success-missing-forwarded-for",
286+
path: "ndt/ndt5",
287+
signer: &fakeSigner{},
288+
locator: &fakeLocatorV2{
289+
targets: []v2.Target{{Machine: "mlab1-lga0t.measurement-lab.org"}},
290+
urls: []url.URL{
291+
{Scheme: "ws", Host: ":3001", Path: "/ndt_protocol"},
292+
{Scheme: "wss", Host: ":3010", Path: "/ndt_protocol"},
293+
},
294+
},
295+
header: http.Header{
296+
"X-AppEngine-CityLatLong": []string{"40.3,-70.4"},
297+
// No X-Forwarded-For
298+
"User-Agent": []string{"test-client"},
299+
},
300+
ipLimiter: &fakeRateLimiter{limited: false},
301+
wantLatLon: "40.3,-70.4",
302+
wantKey: "ws://:3001/ndt_protocol",
303+
wantStatus: http.StatusOK,
304+
},
209305
}
210306
for _, tt := range tests {
211307
t.Run(tt.name, func(t *testing.T) {
212308
if tt.cl == nil {
213309
tt.cl = clientgeo.NewAppEngineLocator()
214310
}
215-
c := NewClient(tt.project, tt.signer, tt.locator, tt.cl, prom.NewAPI(nil), tt.limits, nil)
311+
c := NewClient(tt.project, tt.signer, tt.locator, tt.cl, prom.NewAPI(nil), tt.limits, tt.ipLimiter)
216312

217313
mux := http.NewServeMux()
218314
mux.HandleFunc("/v2/nearest/", c.Nearest)

0 commit comments

Comments
 (0)