diff --git a/README.md b/README.md index 7df58ac..3f3d910 100644 --- a/README.md +++ b/README.md @@ -450,6 +450,80 @@ fastify.get('/private', { }) ``` +### Manual Rate Limit + +In some situations, the default behavior of this plugin may not be sufficient and you want to implement your own rate limit strategy. +This might be the case as well if you want to integrate it with other technologies such as [GraphQL](https://graphql.org/) or [tRPC](https://trpc.io/). +You can create your own limiter function with `fastify.createRateLimit()`. +By default, this limiter function uses the global [options](#options) you defined during plugin registration. +You can override some or all of your global options by passing them to `createRateLimit()`. +The following options can be overridden: `store`, `skipOnError`, `max`, `timeWindow`, `allowList`, `keyGenerator`, and `ban`. + +The next example demonstrates the usage of a custom rate limiter: + +```ts +import Fastify from 'fastify' + +const fastify = Fastify() + +// register with global options +await fastify.register(import('@fastify/rate-limit'), { + max: 100, + timeWindow: '1 minute' +}) + + // checkRateLimit will use the global options provided above when called +const checkRateLimit = fastify.createRateLimit(); + +fastify.get("/", async (request, reply) => { + // manually check the rate limit (using global options) + const limit = await checkRateLimit(request); + + if(!limit.isAllowed && limit.isExceeded) { + return reply.code(429).send("Limit exceeded"); + } + + return reply.send("Hello world"); +}); + +// override global max option +const checkCustomRateLimit = fastify.createRateLimit({ max: 100 }); + +fastify.get("/custom", async (request, reply) => { + // manually check the rate limit (using global options and overridden max option) + const limit = await checkCustomRateLimit(request); + + // manually handle limit exceedance + if(!limit.isAllowed && limit.isExceeded) { + return reply.code(429).send("Limit exceeded"); + } + + return reply.send("Hello world"); +}); +``` + +A custom limiter function created with `fastify.createRateLimit()` only requires a `FastifyRequest` as the first parameter: + +```ts +const checkRateLimit = fastify.createRateLimit(); +const limit = await checkRateLimit(request); +``` + +The returned `limit` is an object containing the following properties for the `request` passed to `checkRateLimit`. + +- `isAllowed`: if `true`, the request was excluded from rate limiting according to the configured `allowList`. +- `key`: the generated key as returned by the `keyGenerator` function. + +If `isAllowed` is `false` the object also contains these additional properties: + +- `max`: the configured `max` option as a number. If a `max` function was supplied as global option or to `fastify.createRateLimit()`, this property will correspond to the function's return type for the given `request`. +- `timeWindow`: the configured `timeWindow` option in milliseconds. If a function was supplied to `timeWindow`, similar to the `max` property above, this property will be equal to the function's return type. +- `remaining`: the remaining amount of requests before the limit is exceeded. +- `ttl`: the remaining time until the limit will be reset in milliseconds. +- `ttlInSeconds`: `ttl` in seconds. +- `isExceeded`: `true` if the limit was exceeded. +- `isBanned`: `true` if the request was banned according to the `ban` option. + ### Examples of Custom Store These examples show an overview of the `store` feature and you should take inspiration from it and tweak as you need: diff --git a/index.js b/index.js index 050172a..7099450 100644 --- a/index.js +++ b/index.js @@ -125,16 +125,17 @@ async function fastifyRateLimit (fastify, settings) { fastify.decorateRequest(pluginComponent.rateLimitRan, false) + if (!fastify.hasDecorator('createRateLimit')) { + fastify.decorate('createRateLimit', (options) => { + const args = createLimiterArgs(pluginComponent, globalParams, options) + return (req) => applyRateLimit(...args, req) + }) + } + if (!fastify.hasDecorator('rateLimit')) { fastify.decorate('rateLimit', (options) => { - if (typeof options === 'object') { - const newPluginComponent = Object.create(pluginComponent) - const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} }) - newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams) - return rateLimitRequestHandler(newPluginComponent, mergedRateLimitParams) - } - - return rateLimitRequestHandler(pluginComponent, globalParams) + const args = createLimiterArgs(pluginComponent, globalParams, options) + return rateLimitRequestHandler(...args) }) } @@ -186,6 +187,17 @@ function mergeParams (...params) { return result } +function createLimiterArgs (pluginComponent, globalParams, options) { + if (typeof options === 'object') { + const newPluginComponent = Object.create(pluginComponent) + const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} }) + newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams) + return [newPluginComponent, mergedRateLimitParams] + } + + return [pluginComponent, globalParams] +} + function addRouteRateHook (pluginComponent, params, routeOptions) { const hook = params.hook const hookHandler = rateLimitRequestHandler(pluginComponent, params) @@ -198,8 +210,72 @@ function addRouteRateHook (pluginComponent, params, routeOptions) { } } +async function applyRateLimit (pluginComponent, params, req) { + const { store } = pluginComponent + + // Retrieve the key from the generator (the global one or the one defined in the endpoint) + let key = await params.keyGenerator(req) + const groupId = req.routeOptions.config?.rateLimit?.groupId + + if (groupId) { + key += groupId + } + + // Don't apply any rate limiting if in the allow list + if (params.allowList) { + if (typeof params.allowList === 'function') { + if (await params.allowList(req, key)) { + return { + isAllowed: true, + key + } + } + } else if (params.allowList.indexOf(key) !== -1) { + return { + isAllowed: true, + key + } + } + } + + const max = typeof params.max === 'number' ? params.max : await params.max(req, key) + const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key) + let current = 0 + let ttl = 0 + let ttlInSeconds = 0 + + // We increment the rate limit for the current request + try { + const res = await new Promise((resolve, reject) => { + store.incr(key, (err, res) => { + err ? reject(err) : resolve(res) + }, timeWindow, max) + }) + + current = res.current + ttl = res.ttl + ttlInSeconds = Math.ceil(res.ttl / 1000) + } catch (err) { + if (!params.skipOnError) { + throw err + } + } + + return { + isAllowed: false, + key, + max, + timeWindow, + remaining: Math.max(0, max - current), + ttl, + ttlInSeconds, + isExceeded: current > max, + isBanned: params.ban !== -1 && current - max > params.ban + } +} + function rateLimitRequestHandler (pluginComponent, params) { - const { rateLimitRan, store } = pluginComponent + const { rateLimitRan } = pluginComponent return async (req, res) => { if (req[rateLimitRan]) { @@ -208,50 +284,24 @@ function rateLimitRequestHandler (pluginComponent, params) { req[rateLimitRan] = true - // Retrieve the key from the generator (the global one or the one defined in the endpoint) - let key = await params.keyGenerator(req) - const groupId = req.routeOptions.config?.rateLimit?.groupId - if (groupId) { - key += groupId - } - - // Don't apply any rate limiting if in the allow list - if (params.allowList) { - if (typeof params.allowList === 'function') { - if (await params.allowList(req, key)) { - return - } - } else if (params.allowList.indexOf(key) !== -1) { - return - } + const rateLimit = await applyRateLimit(pluginComponent, params, req) + if (rateLimit.isAllowed) { + return } - const max = typeof params.max === 'number' ? params.max : await params.max(req, key) - const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key) - let current = 0 - let ttl = 0 - let ttlInSeconds = 0 - - // We increment the rate limit for the current request - try { - const res = await new Promise((resolve, reject) => { - store.incr(key, (err, res) => { - err ? reject(err) : resolve(res) - }, timeWindow, max) - }) - - current = res.current - ttl = res.ttl - ttlInSeconds = Math.ceil(res.ttl / 1000) - } catch (err) { - if (!params.skipOnError) { - throw err - } - } + const { + key, + max, + remaining, + ttl, + ttlInSeconds, + isExceeded, + isBanned + } = rateLimit - if (current <= max) { + if (!isExceeded) { if (params.addHeadersOnExceeding[params.labels.rateLimit]) { res.header(params.labels.rateLimit, max) } - if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, max - current) } + if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, remaining) } if (params.addHeadersOnExceeding[params.labels.rateReset]) { res.header(params.labels.rateReset, ttlInSeconds) } params.onExceeding(req, key) @@ -274,7 +324,7 @@ function rateLimitRequestHandler (pluginComponent, params) { after: format(ttlInSeconds * 1000, true) } - if (params.ban !== -1 && current - max > params.ban) { + if (isBanned) { respCtx.statusCode = 403 respCtx.ban = true params.onBanReach(req, key) diff --git a/test/create-rate-limit.test.js b/test/create-rate-limit.test.js new file mode 100644 index 0000000..9069669 --- /dev/null +++ b/test/create-rate-limit.test.js @@ -0,0 +1,224 @@ +'use strict' + +const { test, mock } = require('node:test') +const Fastify = require('fastify') +const rateLimit = require('../index') + +test('With global rate limit options', async t => { + t.plan(8) + const clock = mock.timers + clock.enable(0) + const fastify = Fastify() + await fastify.register(rateLimit, { + global: false, + max: 2, + timeWindow: 1000 + }) + + const checkRateLimit = fastify.createRateLimit() + + fastify.get('/', async (req, reply) => { + const limit = await checkRateLimit(req) + return limit + }) + + let res + + res = await fastify.inject('/') + + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 1, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: false, + isBanned: false + }) + + res = await fastify.inject('/') + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 0, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: false, + isBanned: false + }) + + res = await fastify.inject('/') + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 0, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: true, + isBanned: false + }) + + clock.tick(1100) + + res = await fastify.inject('/') + + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 1, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: false, + isBanned: false + }) + + clock.reset() +}) + +test('With custom rate limit options', async t => { + t.plan(10) + const clock = mock.timers + clock.enable(0) + const fastify = Fastify() + await fastify.register(rateLimit, { + global: false, + max: 5, + timeWindow: 1000 + }) + + const checkRateLimit = fastify.createRateLimit({ + max: 2, + timeWindow: 1000, + ban: 1 + }) + + fastify.get('/', async (req, reply) => { + const limit = await checkRateLimit(req) + return limit + }) + + let res + + res = await fastify.inject('/') + + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 1, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: false, + isBanned: false + }) + + res = await fastify.inject('/') + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 0, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: false, + isBanned: false + }) + + // should be exceeded now + res = await fastify.inject('/') + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 0, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: true, + isBanned: false + }) + + // should be banned now + res = await fastify.inject('/') + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 0, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: true, + isBanned: true + }) + + clock.tick(1100) + + res = await fastify.inject('/') + + t.assert.deepStrictEqual(res.statusCode, 200) + t.assert.deepStrictEqual(res.json(), { + isAllowed: false, + key: '127.0.0.1', + max: 2, + timeWindow: 1000, + remaining: 1, + ttl: 1000, + ttlInSeconds: 1, + isExceeded: false, + isBanned: false + }) + + clock.reset() +}) + +test('With allow list', async t => { + t.plan(2) + const clock = mock.timers + clock.enable(0) + const fastify = Fastify() + await fastify.register(rateLimit, { + global: false, + max: 5, + timeWindow: 1000 + }) + + const checkRateLimit = fastify.createRateLimit({ + allowList: ['127.0.0.1'], + max: 2, + timeWindow: 1000 + }) + + fastify.get('/', async (req, reply) => { + const limit = await checkRateLimit(req) + return limit + }) + + const res = await fastify.inject('/') + + t.assert.deepStrictEqual(res.statusCode, 200) + + // expect a different return type because isAllowed is true + t.assert.deepStrictEqual(res.json(), { + isAllowed: true, + key: '127.0.0.1' + }) +}) diff --git a/types/index.d.ts b/types/index.d.ts index f30c760..11de667 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -12,6 +12,24 @@ import { declare module 'fastify' { interface FastifyInstance { + createRateLimit(options?: fastifyRateLimit.CreateRateLimitOptions): (req: FastifyRequest) => Promise< + | { + isAllowed: true + key: string + } + | { + isAllowed: false + key: string + max: number + timeWindow: number + remaining: number + ttl: number + ttlInSeconds: number + isExceeded: boolean + isBanned: boolean + } + > + rateLimit< RouteGeneric extends RouteGenericInterface = RouteGenericInterface, ContextConfig = ContextConfigDefault, @@ -89,13 +107,9 @@ declare namespace fastifyRateLimit { 'ratelimit-reset'?: boolean; } - export type RateLimitHook = - | 'onRequest' - | 'preParsing' - | 'preValidation' - | 'preHandler' - - export interface RateLimitOptions { + export interface CreateRateLimitOptions { + store?: FastifyRateLimitStoreCtor; + skipOnError?: boolean; max?: | number | ((req: FastifyRequest, key: string) => number) @@ -105,19 +119,26 @@ declare namespace fastifyRateLimit { | string | ((req: FastifyRequest, key: string) => number) | ((req: FastifyRequest, key: string) => Promise); - hook?: RateLimitHook; - cache?: number; - store?: FastifyRateLimitStoreCtor; /** - * @deprecated Use `allowList` property - */ + * @deprecated Use `allowList` property + */ whitelist?: string[] | ((req: FastifyRequest, key: string) => boolean); allowList?: string[] | ((req: FastifyRequest, key: string) => boolean | Promise); - continueExceeding?: boolean; - skipOnError?: boolean; + keyGenerator?: (req: FastifyRequest) => string | number | Promise; ban?: number; + } + + export type RateLimitHook = + | 'onRequest' + | 'preParsing' + | 'preValidation' + | 'preHandler' + + export interface RateLimitOptions extends CreateRateLimitOptions { + hook?: RateLimitHook; + cache?: number; + continueExceeding?: boolean; onBanReach?: (req: FastifyRequest, key: string) => void; - keyGenerator?: (req: FastifyRequest) => string | number | Promise; groupId?: string; errorResponseBuilder?: ( req: FastifyRequest, diff --git a/types/index.test-d.ts b/types/index.test-d.ts index 1343483..b443e91 100644 --- a/types/index.test-d.ts +++ b/types/index.test-d.ts @@ -9,6 +9,7 @@ import * as http2 from 'node:http2' import IORedis from 'ioredis' import pino from 'pino' import fastifyRateLimit, { + CreateRateLimitOptions, errorResponseBuilderContext, FastifyRateLimitOptions, FastifyRateLimitStore, @@ -217,3 +218,60 @@ appWithCustomLogger.route({ preHandler: appWithCustomLogger.rateLimit({}), handler: () => {}, }) + +const options10: CreateRateLimitOptions = { + store: CustomStore, + skipOnError: true, + max: 0, + timeWindow: 5000, + allowList: ['127.0.0.1'], + keyGenerator: (req: FastifyRequest) => req.ip, + ban: 10 +} + +appWithImplicitHttp.register(fastifyRateLimit, { global: false }) +const checkRateLimit = appWithImplicitHttp.createRateLimit(options10) +appWithImplicitHttp.route({ + method: 'GET', + url: '/', + handler: async (req, _reply) => { + const limit = await checkRateLimit(req) + expectType<{ + isAllowed: true; + key: string; + } | { + isAllowed: false; + key: string; + max: number; + timeWindow: number; + remaining: number; + ttl: number; + ttlInSeconds: number; + isExceeded: boolean; + isBanned: boolean; + }>(limit) + }, +}) + +const options11: CreateRateLimitOptions = { + max: (_req: FastifyRequest, _key: string) => 42, + timeWindow: '10s', + allowList: (_req: FastifyRequest) => true, + keyGenerator: (_req: FastifyRequest) => 42, +} + +const options12: CreateRateLimitOptions = { + max: (_req: FastifyRequest, _key: string) => Promise.resolve(42), + timeWindow: (_req: FastifyRequest, _key: string) => 5000, + allowList: (_req: FastifyRequest) => Promise.resolve(true), + keyGenerator: (_req: FastifyRequest) => Promise.resolve(42), +} + +const options13: CreateRateLimitOptions = { + timeWindow: (_req: FastifyRequest, _key: string) => Promise.resolve(5000), + keyGenerator: (_req: FastifyRequest) => Promise.resolve('key'), +} + +expectType(appWithImplicitHttp.rateLimit(options11)) +expectType(appWithImplicitHttp.rateLimit(options12)) +expectType(appWithImplicitHttp.rateLimit(options13))