Skip to content

Commit e52a7e0

Browse files
SVilgelmdbriemann
andauthored
Add rate limiter to client (#715)
* add rate limiter to client * make rate limiter work for retries * make rate limiter return error instead of blocking * fix test * use RateLimiter interface instead of x/time/rate --------- Co-authored-by: David Linus Briemann <[email protected]>
1 parent 41199c3 commit e52a7e0

File tree

5 files changed

+73
-1
lines changed

5 files changed

+73
-1
lines changed

client.go

+16
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ type Client struct {
152152
errorHooks []ErrorHook
153153
invalidHooks []ErrorHook
154154
panicHooks []ErrorHook
155+
rateLimiter RateLimiter
155156
}
156157

157158
// User type is to hold an username and password information
@@ -920,6 +921,13 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client {
920921
return c
921922
}
922923

924+
// SetRateLimiter sets an optional `RateLimiter`. If set the rate limiter will control
925+
// all requests made with this client.
926+
func (c *Client) SetRateLimiter(rl RateLimiter) *Client {
927+
c.rateLimiter = rl
928+
return c
929+
}
930+
923931
// SetTransport method sets custom `*http.Transport` or any `http.RoundTripper`
924932
// compatible interface implementation in the resty client.
925933
//
@@ -1141,6 +1149,14 @@ func (c *Client) execute(req *Request) (*Response, error) {
11411149
}
11421150
}
11431151

1152+
// If there is a rate limiter set for this client, the Execute call
1153+
// will return an error if the rate limit is exceeded.
1154+
if req.client.rateLimiter != nil {
1155+
if !req.client.rateLimiter.Allow() {
1156+
return nil, wrapNoRetryErr(ErrRateLimitExceeded)
1157+
}
1158+
}
1159+
11441160
// resty middlewares
11451161
for _, f := range c.beforeRequest {
11461162
if err = f(c, req); err != nil {

go.mod

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@ module github.com/go-resty/resty/v2
22

33
go 1.16
44

5-
require golang.org/x/net v0.15.0
5+
require (
6+
golang.org/x/net v0.15.0
7+
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11
8+
)

go.sum

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
3333
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
3434
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
3535
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
36+
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
37+
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
3638
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
3739
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
3840
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

request_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
"strings"
2020
"testing"
2121
"time"
22+
23+
"golang.org/x/time/rate"
2224
)
2325

2426
type AuthSuccess struct {
@@ -66,6 +68,44 @@ func TestGetGH524(t *testing.T) {
6668
assertEqual(t, resp.Request.Header.Get("Content-Type"), "") // unable to reproduce reported issue
6769
}
6870

71+
func TestRateLimiter(t *testing.T) {
72+
ts := createGetServer(t)
73+
defer ts.Close()
74+
75+
// Test a burst with a valid capacity and then a consecutive request that must fail.
76+
77+
// Allow a rate of 1 every 100 ms but also allow bursts of 10 requests.
78+
client := dc().SetRateLimiter(rate.NewLimiter(rate.Every(100*time.Millisecond), 10))
79+
80+
// Execute a burst of 10 requests.
81+
for i := 0; i < 10; i++ {
82+
resp, err := client.R().
83+
SetQueryParam("request_no", strconv.Itoa(i)).Get(ts.URL + "/")
84+
assertError(t, err)
85+
assertEqual(t, http.StatusOK, resp.StatusCode())
86+
}
87+
// Next request issued directly should fail because burst of 10 has been consumed.
88+
{
89+
_, err := client.R().
90+
SetQueryParam("request_no", strconv.Itoa(11)).Get(ts.URL + "/")
91+
assertErrorIs(t, ErrRateLimitExceeded, err)
92+
}
93+
94+
// Test continues request at a valid rate
95+
96+
// Allow a rate of 1 every ms with no burst.
97+
client = dc().SetRateLimiter(rate.NewLimiter(rate.Every(1*time.Millisecond), 1))
98+
99+
// Sending requests every ms+tiny delta must succeed.
100+
for i := 0; i < 100; i++ {
101+
resp, err := client.R().
102+
SetQueryParam("request_no", strconv.Itoa(i)).Get(ts.URL + "/")
103+
assertError(t, err)
104+
assertEqual(t, http.StatusOK, resp.StatusCode())
105+
time.Sleep(1*time.Millisecond + 100*time.Microsecond)
106+
}
107+
}
108+
69109
func TestIllegalRetryCount(t *testing.T) {
70110
ts := createGetServer(t)
71111
defer ts.Close()

util.go

+11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package resty
66

77
import (
88
"bytes"
9+
"errors"
910
"fmt"
1011
"io"
1112
"log"
@@ -64,6 +65,16 @@ func (l *logger) output(format string, v ...interface{}) {
6465
l.l.Printf(format, v...)
6566
}
6667

68+
//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
69+
// Rate Limiter interface
70+
//_______________________________________________________________________
71+
72+
type RateLimiter interface {
73+
Allow() bool
74+
}
75+
76+
var ErrRateLimitExceeded = errors.New("rate limit exceeded")
77+
6778
//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
6879
// Package Helper methods
6980
//_______________________________________________________________________

0 commit comments

Comments
 (0)