Skip to content

Commit a988dc3

Browse files
feat: add CSRF middleware (#3018)
Co-authored-by: Marvin Hagemeister <marvin@deno.com>
1 parent 3ad3e6e commit a988dc3

File tree

3 files changed

+215
-0
lines changed

3 files changed

+215
-0
lines changed

src/middlewares/csrf.ts

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import type { Context } from "../context.ts";
2+
import { HttpError } from "../error.ts";
3+
import type { Middleware } from "./mod.ts";
4+
5+
/** Options for {@linkcode csrf}. **/
6+
// deno-lint-ignore no-explicit-any
7+
export interface CsrfOptions<State = any> {
8+
/**
9+
* origin - Specifies the allowed origin(s) for requests.
10+
* - string: A single allowed origin.
11+
* - string[]: static allowed origins.
12+
* - function: A function to determine if an origin is allowed.
13+
*/
14+
origin?:
15+
| string
16+
| string[]
17+
| ((origin: string, context: Context<State>) => boolean);
18+
}
19+
20+
/**
21+
* CSRF Protection Middleware for Fresh.
22+
*
23+
* @param options Options for the CSRF protection middleware.
24+
* @returns The middleware handler function.
25+
*
26+
* @example Basic usage (with defaults)
27+
* ```ts
28+
* const app = new App<State>()
29+
*
30+
* app.use(csrf())
31+
* ```
32+
*
33+
* @example Specifying static origins
34+
* ```ts
35+
* app.use(csrf({ origin: 'https://myapp.example.com' }))
36+
*
37+
* // string[]
38+
* app.use(
39+
* csrf({
40+
* origin: ['https://myapp.example.com', 'http://development.myapp.example.com'],
41+
* })
42+
* )
43+
* ```
44+
*
45+
* @example Specifying more complex origins
46+
* ```ts
47+
* app.use(
48+
* '*',
49+
* csrf({
50+
* origin: (origin) => ['https://myapp.example.com', 'http://development.myapp.example.com'].includes(origin),
51+
* })
52+
* )
53+
* ```
54+
*/
55+
export function csrf<State>(
56+
options?: CsrfOptions,
57+
): Middleware<State> {
58+
const isAllowedOrigin = (
59+
origin: string | null,
60+
ctx: Context<State>,
61+
) => {
62+
if (origin === null) {
63+
return false;
64+
}
65+
66+
const optsOrigin = options?.origin;
67+
68+
if (!optsOrigin) {
69+
return origin === ctx.url.origin;
70+
}
71+
if (typeof optsOrigin === "string") {
72+
return origin === optsOrigin;
73+
}
74+
if (typeof optsOrigin === "function") {
75+
return optsOrigin(origin, ctx);
76+
}
77+
return Array.isArray(optsOrigin) && optsOrigin.includes(origin);
78+
};
79+
80+
return async (ctx) => {
81+
const { method, headers } = ctx.req;
82+
83+
// Safe methods
84+
if (method === "GET" || method === "HEAD" || method === "OPTIONS") {
85+
return await ctx.next();
86+
}
87+
88+
const secFetchSite = headers.get("Sec-Fetch-Site");
89+
const origin = headers.get("origin");
90+
91+
if (secFetchSite !== null) {
92+
if (
93+
secFetchSite === "same-origin" || secFetchSite === "none" ||
94+
isAllowedOrigin(origin, ctx)
95+
) {
96+
return await ctx.next();
97+
}
98+
99+
throw new HttpError(403);
100+
}
101+
102+
// Neither `Sec-Fetch-Site` or `Origin` is set
103+
if (origin === null) {
104+
return await ctx.next();
105+
}
106+
107+
if (isAllowedOrigin(origin, ctx)) {
108+
return await ctx.next();
109+
}
110+
111+
throw new HttpError(403);
112+
};
113+
}

src/middlewares/csrf_test.ts

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import { App } from "../app.ts";
2+
import type { Method } from "../router.ts";
3+
import { csrf, type CsrfOptions } from "./csrf.ts";
4+
import { expect } from "@std/expect";
5+
6+
function testHandler<T>(
7+
options?: CsrfOptions<T>,
8+
): (request: Request) => Promise<Response> {
9+
return new App<T>()
10+
.use(csrf(options))
11+
.all("/", () => new Response("hello"))
12+
.handler();
13+
}
14+
15+
function createTests(trusted: CsrfOptions["origin"]) {
16+
return async function runTest(
17+
method: Method,
18+
expected: 200 | 403,
19+
test: {
20+
secFetchSite?: "cross-site" | "same-origin" | "same-site" | "none";
21+
origin?: string;
22+
} = {},
23+
) {
24+
const handler = testHandler({ origin: trusted });
25+
26+
const headers = new Headers();
27+
if (test.secFetchSite) headers.append("Sec-Fetch-Site", test.secFetchSite);
28+
if (test.origin) {
29+
headers.append("Origin", test.origin);
30+
}
31+
32+
const res = await handler(
33+
new Request("https://example.com", { method, headers }),
34+
);
35+
await res.body?.cancel();
36+
expect(res.status).toEqual(expected);
37+
};
38+
}
39+
40+
Deno.test("CSRF - allow GET/HEAD/OPTIONS", async () => {
41+
const runTest = createTests("https://example.com");
42+
43+
await runTest("GET", 200);
44+
await runTest("HEAD", 200);
45+
await runTest("OPTIONS", 200);
46+
});
47+
48+
Deno.test("CSRF - Sec-Fetch-Site", async () => {
49+
const runTest = createTests("https://example.com");
50+
51+
await runTest("POST", 200, { secFetchSite: "same-origin" });
52+
await runTest("POST", 200, { secFetchSite: "none" });
53+
await runTest("POST", 403, { secFetchSite: "cross-site" });
54+
await runTest("POST", 403, { secFetchSite: "same-site" });
55+
56+
await runTest("POST", 200);
57+
await runTest("POST", 200, { origin: "https://example.com" });
58+
await runTest("POST", 403, { origin: "https://attacker.example.com" });
59+
await runTest("POST", 403, { origin: "null" });
60+
61+
await runTest("GET", 200, { secFetchSite: "cross-site" });
62+
await runTest("HEAD", 200, { secFetchSite: "cross-site" });
63+
await runTest("OPTIONS", 200, { secFetchSite: "cross-site" });
64+
await runTest("PUT", 403, { secFetchSite: "cross-site" });
65+
});
66+
67+
Deno.test("CSRF - cross origin", async () => {
68+
const runTest = createTests("https://trusted.example.com");
69+
70+
await runTest("POST", 200, { origin: "https://trusted.example.com" });
71+
await runTest("POST", 200, {
72+
origin: "https://trusted.example.com",
73+
secFetchSite: "cross-site",
74+
});
75+
76+
await runTest("POST", 403, { origin: "https://attacker.example.com" });
77+
await runTest("POST", 403, {
78+
origin: "https://attacker.example.com",
79+
secFetchSite: "cross-site",
80+
});
81+
});
82+
83+
Deno.test("CSRF - array origin", async () => {
84+
const runTest = createTests(
85+
["https://example.com", "https://trusted.example.com"],
86+
);
87+
88+
await runTest("POST", 200, { origin: "https://trusted.example.com" });
89+
await runTest("POST", 200, { origin: "https://example.com" });
90+
await runTest("POST", 403, { origin: "https://foo.example.com" });
91+
});
92+
93+
Deno.test("CSRF - function origin", async () => {
94+
const runTest = createTests(
95+
(origin) => origin === "https://example.com",
96+
);
97+
98+
await runTest("POST", 200, { origin: "https://example.com" });
99+
await runTest("POST", 403, { origin: "https://trusted.example.com" });
100+
await runTest("POST", 403, { origin: "https://foo.example.com" });
101+
});

src/mod.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export {
1111
export type { LayoutConfig, Lazy, MaybeLazy, RouteConfig } from "./types.ts";
1212
export type { Middleware, MiddlewareFn } from "./middlewares/mod.ts";
1313
export { staticFiles } from "./middlewares/static_files.ts";
14+
export { csrf, type CsrfOptions } from "./middlewares/csrf.ts";
1415
export { cors, type CORSOptions } from "./middlewares/cors.ts";
1516
export type { FreshConfig, ResolvedFreshConfig } from "./config.ts";
1617
export type { Context, FreshContext, Island } from "./context.ts";

0 commit comments

Comments
 (0)