Skip to content

Commit 2274630

Browse files
authored
Implement sliding window rate limiter (#178)
* Implement sliding window rate limiter
1 parent 8f6dc5b commit 2274630

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

pkg/rate/limiter.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package rate
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"time"
8+
)
9+
10+
type Limiter struct {
11+
limit int
12+
period time.Duration
13+
14+
mu sync.Mutex
15+
start, end, size int
16+
requests []time.Time
17+
}
18+
19+
func NewLimiter(limit int, period time.Duration) *Limiter {
20+
return &Limiter{
21+
limit: limit,
22+
period: period,
23+
start: 0,
24+
end: 0,
25+
requests: make([]time.Time, limit),
26+
}
27+
}
28+
29+
func (l *Limiter) Wait(ctx context.Context) error {
30+
return l.WaitN(ctx, 1)
31+
}
32+
33+
func (l *Limiter) WaitN(ctx context.Context, n int) error {
34+
l.mu.Lock()
35+
defer l.mu.Unlock()
36+
37+
if n <= 0 {
38+
return nil
39+
}
40+
41+
if n > l.limit {
42+
return fmt.Errorf("rate: Wait (n=%d) exceed limiter %d", n, l.limit)
43+
}
44+
45+
// Get the oldest request in queue for waiting.
46+
var (
47+
shouldWait bool
48+
oldest time.Time
49+
)
50+
if l.requestsSize()+n > l.limit {
51+
shouldWait = true
52+
oldest, _ = l.requestAt(l.requestsSize() + n - l.limit - 1)
53+
}
54+
55+
// Wait if rate limit is reached.
56+
if shouldWait {
57+
waitDuration := l.period - time.Since(oldest)
58+
if waitDuration > 0 {
59+
timer := time.NewTimer(waitDuration)
60+
defer timer.Stop()
61+
62+
select {
63+
case <-timer.C:
64+
// We can proceed.
65+
case <-ctx.Done():
66+
// Context was canceled before we could proceed.
67+
return ctx.Err()
68+
}
69+
}
70+
}
71+
72+
// Add new requests to queue.
73+
for range n {
74+
l.addRequest(time.Now())
75+
}
76+
77+
return nil
78+
}
79+
80+
func (l *Limiter) requestsSize() int {
81+
return l.size
82+
}
83+
84+
func (l *Limiter) requestAt(i int) (time.Time, bool) {
85+
if i >= l.size {
86+
return time.Now(), false
87+
}
88+
89+
return l.requests[(l.start+i)%l.limit], true
90+
}
91+
92+
func (l *Limiter) addRequest(t time.Time) {
93+
if l.size == l.limit {
94+
l.start++
95+
if l.start >= l.limit {
96+
l.start = 0
97+
}
98+
} else {
99+
l.size++
100+
}
101+
102+
l.requests[l.end] = t
103+
104+
l.end++
105+
if l.end >= l.limit {
106+
l.end = 0
107+
}
108+
}

0 commit comments

Comments
 (0)