Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to manually check rate limit (#346) #392

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,80 @@ fastify.get('/private', {
})
```

### Manual Rate Limit

In some situations, the default behavior of this plugin may not be sufficient and you want to implement your own rate limit strategy.
This might be the case as well if you want to integrate it with other technologies such as [GraphQL](https://graphql.org/) or [tRPC](https://trpc.io/).
You can create your own limiter function with `fastify.createRateLimit()`.
By default, this limiter function uses the global [options](#options) you defined during plugin registration.
You can override some or all of your global options by passing them to `createRateLimit()`.
The following options can be overridden: `store`, `skipOnError`, `max`, `timeWindow`, `allowList`, `keyGenerator`, and `ban`.

The next example demonstrates the usage of a custom rate limiter:

```ts
import Fastify from 'fastify'

const fastify = Fastify()

// register with global options
await fastify.register(import('@fastify/rate-limit'), {
max: 100,
timeWindow: '1 minute'
})

// checkRateLimit will use the global options provided above when called
const checkRateLimit = fastify.createRateLimit();

fastify.get("/", async (request, reply) => {
// manually check the rate limit (using global options)
const limit = await checkRateLimit(request);

if(!limit.isAllowed && limit.isExceeded) {
return reply.code(429).send("Limit exceeded");
}

return reply.send("Hello world");
});

// override global max option
const checkCustomRateLimit = fastify.createRateLimit({ max: 100 });

fastify.get("/custom", async (request, reply) => {
// manually check the rate limit (using global options and overridden max option)
const limit = await checkCustomRateLimit(request);

// manually handle limit exceedance
if(!limit.isAllowed && limit.isExceeded) {
return reply.code(429).send("Limit exceeded");
}

return reply.send("Hello world");
});
```

A custom limiter function created with `fastify.createRateLimit()` only requires a `FastifyRequest` as the first parameter:

```ts
const checkRateLimit = fastify.createRateLimit();
const limit = await checkRateLimit(request);
```

The returned `limit` is an object containing the following properties for the `request` passed to `checkRateLimit`.

- `isAllowed`: if `true`, the request was excluded from rate limiting according to the configured `allowList`.
- `key`: the generated key as returned by the `keyGenerator` function.

If `isAllowed` is `false` the object also contains these additional properties:

- `max`: the configured `max` option as a number. If a `max` function was supplied as global option or to `fastify.createRateLimit()`, this property will correspond to the function's return type for the given `request`.
- `timeWindow`: the configured `timeWindow` option in milliseconds. If a function was supplied to `timeWindow`, similar to the `max` property above, this property will be equal to the function's return type.
- `remaining`: the remaining amount of requests before the limit is exceeded.
- `ttl`: the remaining time until the limit will be reset in milliseconds.
- `ttlInSeconds`: `ttl` in seconds.
- `isExceeded`: `true` if the limit was exceeded.
- `isBanned`: `true` if the request was banned according to the `ban` option.

### Examples of Custom Store

These examples show an overview of the `store` feature and you should take inspiration from it and tweak as you need:
Expand Down
150 changes: 100 additions & 50 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,17 @@ async function fastifyRateLimit (fastify, settings) {

fastify.decorateRequest(pluginComponent.rateLimitRan, false)

if (!fastify.hasDecorator('createRateLimit')) {
fastify.decorate('createRateLimit', (options) => {
const args = createLimiterArgs(pluginComponent, globalParams, options)
return (req) => applyRateLimit(...args, req)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of decostruct args we can:

Suggested change
return (req) => applyRateLimit(...args, req)
return (req) => applyRateLimit.apply(this, args.concat(req))

(and even better, avoid the concat

})
}

if (!fastify.hasDecorator('rateLimit')) {
fastify.decorate('rateLimit', (options) => {
if (typeof options === 'object') {
const newPluginComponent = Object.create(pluginComponent)
const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} })
newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams)
return rateLimitRequestHandler(newPluginComponent, mergedRateLimitParams)
}

return rateLimitRequestHandler(pluginComponent, globalParams)
const args = createLimiterArgs(pluginComponent, globalParams, options)
return rateLimitRequestHandler(...args)
})
}

Expand Down Expand Up @@ -186,6 +187,17 @@ function mergeParams (...params) {
return result
}

function createLimiterArgs (pluginComponent, globalParams, options) {
if (typeof options === 'object') {
const newPluginComponent = Object.create(pluginComponent)
const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} })
newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams)
return [newPluginComponent, mergedRateLimitParams]
}

return [pluginComponent, globalParams]
}

function addRouteRateHook (pluginComponent, params, routeOptions) {
const hook = params.hook
const hookHandler = rateLimitRequestHandler(pluginComponent, params)
Expand All @@ -198,8 +210,72 @@ function addRouteRateHook (pluginComponent, params, routeOptions) {
}
}

async function applyRateLimit (pluginComponent, params, req) {
const { store } = pluginComponent

// Retrieve the key from the generator (the global one or the one defined in the endpoint)
let key = await params.keyGenerator(req)
const groupId = req.routeOptions.config?.rateLimit?.groupId

if (groupId) {
key += groupId
}

// Don't apply any rate limiting if in the allow list
if (params.allowList) {
if (typeof params.allowList === 'function') {
if (await params.allowList(req, key)) {
return {
isAllowed: true,
key
}
}
} else if (params.allowList.indexOf(key) !== -1) {
return {
isAllowed: true,
key
}
}
}

const max = typeof params.max === 'number' ? params.max : await params.max(req, key)
const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key)
let current = 0
let ttl = 0
let ttlInSeconds = 0

// We increment the rate limit for the current request
try {
const res = await new Promise((resolve, reject) => {
store.incr(key, (err, res) => {
err ? reject(err) : resolve(res)
}, timeWindow, max)
})

current = res.current
ttl = res.ttl
ttlInSeconds = Math.ceil(res.ttl / 1000)
} catch (err) {
if (!params.skipOnError) {
throw err
}
}

return {
isAllowed: false,
key,
max,
timeWindow,
remaining: Math.max(0, max - current),
ttl,
ttlInSeconds,
isExceeded: current > max,
isBanned: params.ban !== -1 && current - max > params.ban
}
}

function rateLimitRequestHandler (pluginComponent, params) {
const { rateLimitRan, store } = pluginComponent
const { rateLimitRan } = pluginComponent

return async (req, res) => {
if (req[rateLimitRan]) {
Expand All @@ -208,50 +284,24 @@ function rateLimitRequestHandler (pluginComponent, params) {

req[rateLimitRan] = true

// Retrieve the key from the generator (the global one or the one defined in the endpoint)
let key = await params.keyGenerator(req)
const groupId = req.routeOptions.config?.rateLimit?.groupId
if (groupId) {
key += groupId
}

// Don't apply any rate limiting if in the allow list
if (params.allowList) {
if (typeof params.allowList === 'function') {
if (await params.allowList(req, key)) {
return
}
} else if (params.allowList.indexOf(key) !== -1) {
return
}
const rateLimit = await applyRateLimit(pluginComponent, params, req)
if (rateLimit.isAllowed) {
return
}

const max = typeof params.max === 'number' ? params.max : await params.max(req, key)
const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key)
let current = 0
let ttl = 0
let ttlInSeconds = 0

// We increment the rate limit for the current request
try {
const res = await new Promise((resolve, reject) => {
store.incr(key, (err, res) => {
err ? reject(err) : resolve(res)
}, timeWindow, max)
})

current = res.current
ttl = res.ttl
ttlInSeconds = Math.ceil(res.ttl / 1000)
} catch (err) {
if (!params.skipOnError) {
throw err
}
}
const {
key,
max,
remaining,
ttl,
ttlInSeconds,
isExceeded,
isBanned
} = rateLimit

if (current <= max) {
if (!isExceeded) {
if (params.addHeadersOnExceeding[params.labels.rateLimit]) { res.header(params.labels.rateLimit, max) }
if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, max - current) }
if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, remaining) }
if (params.addHeadersOnExceeding[params.labels.rateReset]) { res.header(params.labels.rateReset, ttlInSeconds) }

params.onExceeding(req, key)
Expand All @@ -274,7 +324,7 @@ function rateLimitRequestHandler (pluginComponent, params) {
after: format(ttlInSeconds * 1000, true)
}

if (params.ban !== -1 && current - max > params.ban) {
if (isBanned) {
respCtx.statusCode = 403
respCtx.ban = true
params.onBanReach(req, key)
Expand Down
Loading