Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 55 additions & 51 deletions src/services/plugin.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import Elysia from 'elysia'
import Elysia from "elysia";

import { defaultOptions } from '../constants/defaultOptions'
import { DefaultContext } from './defaultContext'
import { defaultOptions } from "../constants/defaultOptions";
import { DefaultContext } from "./defaultContext";

import { logger } from './logger'
import { logger } from "./logger";

import type { Options } from '../@types/Options'
import type { Options } from "../@types/Options";

export const plugin = function rateLimitPlugin(userOptions?: Partial<Options>) {
const options: Options = {
...defaultOptions,
...userOptions,
context: userOptions?.context ?? new DefaultContext(),
}
};

options.context.init(options)
options.context.init(options);

// NOTE:
// do not make plugin to return async
// otherwise request will be triggered twice
return function registerRateLimitPlugin(app: Elysia) {
const plugin = new Elysia({
name: 'elysia-rate-limit',
name: "elysia-rate-limit",
seed: options.max,
})
});

plugin.onBeforeHandle(
{ as: options.scoping },
Expand All @@ -42,8 +42,10 @@ export const plugin = function rateLimitPlugin(userOptions?: Partial<Options>) {
qi,
...rest
}) {
let clientKey: string | undefined
const enhancedRequest = Object.defineProperty(request, 'cookie', {value: cookie});
let clientKey: string | undefined;
const enhancedRequest = Object.defineProperty(request, "cookie", {
value: cookie,
});

/**
* if a skip option has two parameters,
Expand All @@ -56,7 +58,7 @@ export const plugin = function rateLimitPlugin(userOptions?: Partial<Options>) {
enhancedRequest,
options.injectServer?.() ?? app.server,
rest
)
);

// if decided to skip, then do nothing and let the app continue
if ((await options.skip(enhancedRequest, clientKey)) === false) {
Expand All @@ -67,89 +69,89 @@ export const plugin = function rateLimitPlugin(userOptions?: Partial<Options>) {
*/
if (options.skip.length < 2)
clientKey = await options.generator(
completeRequest,
enhancedRequest,
options.injectServer?.() ?? app.server,
rest
)
);

const { count, nextReset } = await options.context.increment(
// biome-ignore lint/style/noNonNullAssertion: <explanation>
clientKey!
)
);

const payload = {
limit: options.max,
current: count,
remaining: Math.max(options.max - count, 0),
nextReset,
}
};

// set standard headers
const reset = Math.max(
0,
Math.ceil((nextReset.getTime() - Date.now()) / 1000)
)
);

const builtHeaders: Record<string, string> = {
'RateLimit-Limit': String(options.max),
'RateLimit-Remaining': String(payload.remaining),
'RateLimit-Reset': String(reset),
}
"RateLimit-Limit": String(options.max),
"RateLimit-Remaining": String(payload.remaining),
"RateLimit-Reset": String(reset),
};

// reject if limit were reached
if (payload.current >= payload.limit + 1) {
logger(
'plugin',
'rate limit exceeded for clientKey: %s (resetting in %d seconds)',
"plugin",
"rate limit exceeded for clientKey: %s (resetting in %d seconds)",
clientKey,
reset
)
);

builtHeaders['Retry-After'] = String(
builtHeaders["Retry-After"] = String(
Math.ceil(options.duration / 1000)
)
);

if (options.errorResponse instanceof Error)
throw options.errorResponse
throw options.errorResponse;
if (options.errorResponse instanceof Response) {
// duplicate the response to avoid mutation
const clonedResponse = options.errorResponse.clone()
const clonedResponse = options.errorResponse.clone();

// append headers
if (options.headers)
for (const [key, value] of Object.entries(builtHeaders))
clonedResponse.headers.set(key, value)
clonedResponse.headers.set(key, value);

return clonedResponse
return clonedResponse;
}

// append headers
if (options.headers)
for (const [key, value] of Object.entries(builtHeaders))
set.headers[key] = value
set.headers[key] = value;

// set default status code
set.status = 429
set.status = 429;

return options.errorResponse
return options.errorResponse;
}

// append headers
if (options.headers)
for (const [key, value] of Object.entries(builtHeaders))
set.headers[key] = value
set.headers[key] = value;

logger(
'plugin',
'clientKey %s passed through with %d/%d request used (resetting in %d seconds)',
"plugin",
"clientKey %s passed through with %d/%d request used (resetting in %d seconds)",
clientKey,
options.max - payload.remaining,
options.max,
reset
)
);
}
}
)
);

plugin.onError(
{ as: options.scoping },
Expand All @@ -170,28 +172,30 @@ export const plugin = function rateLimitPlugin(userOptions?: Partial<Options>) {
...rest
}) {
if (!options.countFailedRequest) {
const enhancedRequest = Object.defineProperty(request, 'cookie', {value: cookie});
const enhancedRequest = Object.defineProperty(request, "cookie", {
value: cookie,
});
const clientKey = await options.generator(
enhancedRequest,
options.injectServer?.() ?? app.server,
rest
)
);

logger(
'plugin',
'request failed for clientKey: %s, refunding',
"plugin",
"request failed for clientKey: %s, refunding",
clientKey
)
await options.context.decrement(clientKey)
);
await options.context.decrement(clientKey);
}
}
)
);

plugin.onStop(async function onStopRateLimitHandler() {
logger('plugin', 'kill signal received')
await options.context.kill()
})
logger("plugin", "kill signal received");
await options.context.kill();
});

return app.use(plugin)
}
}
return app.use(plugin);
};
};
Loading