Skip to content

Commit 55fdbc1

Browse files
authored
Feat/weighted rate limiting (#16)
1 parent 13f9119 commit 55fdbc1

File tree

7 files changed

+234
-69
lines changed

7 files changed

+234
-69
lines changed

src/limiter.ts

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,30 @@ export class Limiter implements LimiterStoreContract {
5353
*
5454
* @param key - Unique identifier for the rate limit
5555
*/
56-
consume(key: string | number): Promise<LimiterResponse> {
57-
return this.#store.consume(key)
56+
consume(key: string | number, amount?: number): Promise<LimiterResponse> {
57+
return this.#store.consume(key, amount)
5858
}
5959

6060
/**
6161
* Increments the consumed request count for the given key.
6262
* Unlike consume(), this method does not throw when the limit is reached.
6363
*
6464
* @param key - Unique identifier for the rate limit
65+
* @param amount - Number of requests to increment (default: 1)
6566
*/
66-
increment(key: string | number): Promise<LimiterResponse> {
67-
return this.#store.increment(key)
67+
increment(key: string | number, amount?: number): Promise<LimiterResponse> {
68+
return this.#store.increment(key, amount)
6869
}
6970

7071
/**
7172
* Decrements the consumed request count for the given key.
7273
* Will not decrement below zero.
7374
*
7475
* @param key - Unique identifier for the rate limit
76+
* @param amount - Number of requests to decrement (default: 1)
7577
*/
76-
decrement(key: string | number): Promise<LimiterResponse> {
77-
return this.#store.decrement(key)
78+
decrement(key: string | number, amount?: number): Promise<LimiterResponse> {
79+
return this.#store.decrement(key, amount)
7880
}
7981

8082
/**
@@ -95,7 +97,11 @@ export class Limiter implements LimiterStoreContract {
9597
* }
9698
* ```
9799
*/
98-
async attempt<T>(key: string | number, callback: () => T | Promise<T>): Promise<T | undefined> {
100+
async attempt<T>(
101+
key: string | number,
102+
callback: () => T | Promise<T>,
103+
amount?: number
104+
): Promise<T | undefined> {
99105
/**
100106
* Return early when remaining requests are less than
101107
* zero.
@@ -110,7 +116,7 @@ export class Limiter implements LimiterStoreContract {
110116
}
111117

112118
try {
113-
await this.consume(key)
119+
await this.consume(key, amount)
114120
return callback()
115121
} catch (error) {
116122
if (error instanceof E_TOO_MANY_REQUESTS === false) {
@@ -144,7 +150,8 @@ export class Limiter implements LimiterStoreContract {
144150
*/
145151
async penalize<T>(
146152
key: string | number,
147-
callback: () => T | Promise<T>
153+
callback: () => T | Promise<T>,
154+
amount?: number
148155
): Promise<[null, T] | [ThrottleException, null]> {
149156
const response = await this.get(key)
150157

@@ -169,7 +176,7 @@ export class Limiter implements LimiterStoreContract {
169176
* an error.
170177
*/
171178
if (callbackError) {
172-
const { consumed, limit } = await this.increment(key)
179+
const { consumed, limit } = await this.increment(key, amount)
173180
if (consumed >= limit && this.blockDuration) {
174181
await this.block(key, this.blockDuration)
175182
}

src/stores/bridge.ts

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,12 @@ export default abstract class RateLimiterBridge implements LimiterStoreContract
9191
* console.log(`Remaining: ${response.remaining}`)
9292
* ```
9393
*/
94-
async consume(key: string | number): Promise<LimiterResponse> {
94+
async consume(key: string | number, amount?: number): Promise<LimiterResponse> {
95+
const consumeAmount = amount !== undefined && amount > 0 ? amount : 1
96+
9597
try {
96-
const response = await this.rateLimiter.consume(key, 1)
97-
debug('request consumed for key %s', key)
98+
const response = await this.rateLimiter.consume(key, consumeAmount)
99+
debug('request consumed for key %s with amount %d', key, consumeAmount)
98100
return this.makeLimiterResponse(response)
99101
} catch (errorResponse: unknown) {
100102
debug('unable to consume request for key %s, %O', key, errorResponse)
@@ -117,8 +119,13 @@ export default abstract class RateLimiterBridge implements LimiterStoreContract
117119
* const response = await limiter.increment('user:123')
118120
* ```
119121
*/
120-
async increment(key: string | number): Promise<LimiterResponse> {
121-
const response = await this.rateLimiter.penalty(key, 1)
122+
async increment(key: string | number, amount: number = 1): Promise<LimiterResponse> {
123+
if (amount <= 0) {
124+
debug('invalid increment amount "%d" provided. Falling back to 1', amount)
125+
amount = 1
126+
}
127+
128+
const response = await this.rateLimiter.penalty(key, amount)
122129
debug('increased requests count for key %s', key)
123130

124131
return this.makeLimiterResponse(response)
@@ -135,7 +142,7 @@ export default abstract class RateLimiterBridge implements LimiterStoreContract
135142
* const response = await limiter.decrement('user:123')
136143
* ```
137144
*/
138-
async decrement(key: string | number): Promise<LimiterResponse> {
145+
async decrement(key: string | number, amount: number = 1): Promise<LimiterResponse> {
139146
const existingKey = await this.rateLimiter.get(key)
140147

141148
/**
@@ -145,17 +152,26 @@ export default abstract class RateLimiterBridge implements LimiterStoreContract
145152
return this.set(key, 0, this.duration)
146153
}
147154

155+
if (amount <= 0) {
156+
debug('invalid decrement amount "%d" provided. Falling back to 1', amount)
157+
amount = 1
158+
}
159+
148160
/**
149161
* Do not decrement beyond zero
150162
*/
151163
if (existingKey.consumedPoints <= 0) {
152164
return this.makeLimiterResponse(existingKey)
153165
}
154166

167+
if (amount > existingKey.consumedPoints) {
168+
amount = existingKey.consumedPoints
169+
}
170+
155171
/**
156172
* Decrement
157173
*/
158-
const response = await this.rateLimiter.reward(key, 1)
174+
const response = await this.rateLimiter.reward(key, amount)
159175
debug('decreased requests count for key %s', key)
160176

161177
return this.makeLimiterResponse(response)

src/types.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,18 +186,18 @@ export interface LimiterStoreContract {
186186
* when all the requests have already been consumed or if
187187
* the key is blocked.
188188
*/
189-
consume(key: string | number): Promise<LimiterResponse>
189+
consume(key: string | number, amount?: number): Promise<LimiterResponse>
190190

191191
/**
192192
* Increment the number of consumed requests for a given key.
193193
* No errors are thrown when limit has reached
194194
*/
195-
increment(key: string | number): Promise<LimiterResponse>
195+
increment(key: string | number, amount?: number): Promise<LimiterResponse>
196196

197197
/**
198198
* Decrement the number of consumed requests for a given key.
199199
*/
200-
decrement(key: string | number): Promise<LimiterResponse>
200+
decrement(key: string | number, amount?: number): Promise<LimiterResponse>
201201

202202
/**
203203
* Block a given key for the given duration. The duration must be

tests/limiter.spec.ts

Lines changed: 145 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,27 @@ test.group('Limiter', () => {
3434
*/
3535
const consumeCall = sinon.spy(store, 'consume')
3636
await limiter.consume('ip_localhost')
37-
assert.isTrue(consumeCall.calledOnceWithExactly('ip_localhost'), 'consume called')
37+
assert.isTrue(consumeCall.calledOnceWithExactly('ip_localhost', undefined), 'consume called')
3838

3939
/**
4040
* increment call
4141
*/
4242
const incrementCall = sinon.spy(store, 'increment')
4343
await limiter.increment('ip_localhost')
44-
assert.isTrue(incrementCall.calledOnceWithExactly('ip_localhost'), 'increment called')
44+
assert.isTrue(
45+
incrementCall.calledOnceWithExactly('ip_localhost', undefined),
46+
'increment called'
47+
)
4548

4649
/**
4750
* decrement call
4851
*/
4952
const decrementCall = sinon.spy(store, 'decrement')
5053
await limiter.decrement('ip_localhost')
51-
assert.isTrue(decrementCall.calledOnceWithExactly('ip_localhost'), 'decrement called')
54+
assert.isTrue(
55+
decrementCall.calledOnceWithExactly('ip_localhost', undefined),
56+
'decrement called'
57+
)
5258

5359
/**
5460
* get call
@@ -104,6 +110,142 @@ test.group('Limiter', () => {
104110
await assert.doesNotReject(() => limiter.increment('ip_localhost'))
105111
})
106112

113+
test('increment requests count with negative amount should default to 1', async ({ assert }) => {
114+
const redis = createRedis(['rlflx:ip_localhost']).connection()
115+
const store = new LimiterRedisStore(redis, {
116+
duration: '1 minute',
117+
requests: 5,
118+
})
119+
120+
const limiter = new Limiter(store)
121+
122+
await limiter.increment('ip_localhost', -5)
123+
assert.equal(await limiter.remaining('ip_localhost'), 4)
124+
})
125+
126+
test('decrement requests count with negative amount should default to 1', async ({ assert }) => {
127+
const redis = createRedis(['rlflx:ip_localhost']).connection()
128+
const store = new LimiterRedisStore(redis, {
129+
duration: '1 minute',
130+
requests: 5,
131+
})
132+
133+
const limiter = new Limiter(store)
134+
135+
await limiter.increment('ip_localhost', 5)
136+
await limiter.decrement('ip_localhost', -3)
137+
assert.equal(await limiter.remaining('ip_localhost'), 1)
138+
})
139+
140+
test('increment requests count with zero amount should default to 1', async ({ assert }) => {
141+
const redis = createRedis(['rlflx:ip_localhost']).connection()
142+
const store = new LimiterRedisStore(redis, {
143+
duration: '1 minute',
144+
requests: 5,
145+
})
146+
147+
const limiter = new Limiter(store)
148+
149+
await limiter.increment('ip_localhost', 0)
150+
assert.equal(await limiter.remaining('ip_localhost'), 4)
151+
})
152+
153+
test('decrement requests count with zero amount should default to 1', async ({ assert }) => {
154+
const redis = createRedis(['rlflx:ip_localhost']).connection()
155+
const store = new LimiterRedisStore(redis, {
156+
duration: '1 minute',
157+
requests: 5,
158+
})
159+
160+
const limiter = new Limiter(store)
161+
162+
await limiter.increment('ip_localhost', 5)
163+
await limiter.decrement('ip_localhost', 0)
164+
assert.equal(await limiter.remaining('ip_localhost'), 1)
165+
})
166+
167+
test('increment remaining requests by amount', async ({ assert }) => {
168+
const redis = createRedis(['rlflx:ip_localhost']).connection()
169+
const store = new LimiterRedisStore(redis, {
170+
duration: '1 minute',
171+
requests: 5,
172+
})
173+
174+
const limiter = new Limiter(store)
175+
176+
await limiter.increment('ip_localhost', 3)
177+
const response = await limiter.get('ip_localhost')
178+
assert.containsSubset(response, {
179+
consumed: 3,
180+
remaining: 2,
181+
limit: 5,
182+
})
183+
})
184+
185+
test('decrement consumed requests by amount', async ({ assert }) => {
186+
const redis = createRedis(['rlflx:ip_localhost']).connection()
187+
const store = new LimiterRedisStore(redis, {
188+
duration: '1 minute',
189+
requests: 5,
190+
})
191+
192+
const limiter = new Limiter(store)
193+
194+
await limiter.increment('ip_localhost', 4)
195+
await limiter.decrement('ip_localhost', 2)
196+
const response = await limiter.get('ip_localhost')
197+
assert.containsSubset(response, {
198+
consumed: 2,
199+
remaining: 3,
200+
limit: 5,
201+
})
202+
})
203+
204+
test('consume remaining requests by amount', async ({ assert }) => {
205+
const redis = createRedis(['rlflx:ip_localhost']).connection()
206+
const store = new LimiterRedisStore(redis, {
207+
duration: '1 minute',
208+
requests: 5,
209+
})
210+
211+
const limiter = new Limiter(store)
212+
213+
await limiter.consume('ip_localhost', 3)
214+
const response = await limiter.get('ip_localhost')
215+
assert.containsSubset(response, {
216+
consumed: 3,
217+
remaining: 2,
218+
limit: 5,
219+
})
220+
})
221+
222+
test('increment requests count with a custom amount', async ({ assert }) => {
223+
const redis = createRedis(['rlflx:ip_localhost']).connection()
224+
const store = new LimiterRedisStore(redis, {
225+
duration: '1 minute',
226+
requests: 10,
227+
})
228+
229+
const limiter = new Limiter(store)
230+
231+
await limiter.increment('ip_localhost', 3)
232+
assert.equal(await limiter.remaining('ip_localhost'), 7)
233+
})
234+
235+
test('decrement requests count with a custom amount', async ({ assert }) => {
236+
const redis = createRedis(['rlflx:ip_localhost']).connection()
237+
const store = new LimiterRedisStore(redis, {
238+
duration: '1 minute',
239+
requests: 10,
240+
})
241+
242+
const limiter = new Limiter(store)
243+
244+
await limiter.increment('ip_localhost', 10)
245+
await limiter.decrement('ip_localhost', 4)
246+
assert.equal(await limiter.remaining('ip_localhost'), 4)
247+
})
248+
107249
test('do not run action when all requests have been exhausted', async ({ assert }) => {
108250
const executionStack: string[] = []
109251
const redis = createRedis(['rlflx:ip_localhost']).connection()

0 commit comments

Comments
 (0)