Skip to content

Commit 75b1ef5

Browse files
Add origin checks for UI route submissions (#14708)
Co-authored-by: Jacob Ebey <[email protected]>
1 parent c05ef93 commit 75b1ef5

File tree

10 files changed

+174
-6
lines changed

10 files changed

+174
-6
lines changed

.changeset/spotty-masks-beg.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"@react-router/dev": minor
3+
"react-router": minor
4+
---
5+
6+
Add additional layer of CSRF protection by rejecting submissions to UI routes from external origins. If you need to permit access to specific external origins, you can specify them in the `react-router.config.ts` config `allowedActionOrigins` field.

integration/vite-presets-test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ test.describe("Vite / presets", async () => {
238238
"serverBundles",
239239
"serverModuleFormat",
240240
"ssr",
241+
"allowedActionOrigins",
241242
"unstable_routeConfig",
242243
]);
243244

packages/react-router-dev/config/config.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@ export type ReactRouterConfig = {
211211
* SPA without server-rendering. Default's to `true`.
212212
*/
213213
ssr?: boolean;
214+
215+
/**
216+
* The allowed origins for actions / mutations. Does not apply to routes
217+
* without a component. micromatch glob patterns are supported.
218+
*/
219+
allowedActionOrigins?: string[];
214220
};
215221

216222
export type ResolvedReactRouterConfig = Readonly<{
@@ -277,6 +283,11 @@ export type ResolvedReactRouterConfig = Readonly<{
277283
* SPA without server-rendering. Default's to `true`.
278284
*/
279285
ssr: boolean;
286+
/**
287+
* The allowed origins for actions / mutations. Does not apply to routes
288+
* without a component. micromatch glob patterns are supported.
289+
*/
290+
allowedActionOrigins: string[] | false;
280291
/**
281292
* The resolved array of route config entries exported from `routes.ts`
282293
*/
@@ -645,6 +656,8 @@ async function resolveConfig({
645656
userAndPresetConfigs.future?.v8_viteEnvironmentApi ?? false,
646657
};
647658

659+
let allowedActionOrigins = userAndPresetConfigs.allowedActionOrigins ?? false;
660+
648661
let reactRouterConfig: ResolvedReactRouterConfig = deepFreeze({
649662
appDirectory,
650663
basename,
@@ -658,6 +671,7 @@ async function resolveConfig({
658671
serverBundles,
659672
serverModuleFormat,
660673
ssr,
674+
allowedActionOrigins,
661675
unstable_routeConfig: routeConfig,
662676
} satisfies ResolvedReactRouterConfig);
663677

packages/react-router-dev/typegen/generate.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ export function generateServerBuild(ctx: Context): VirtualFile {
4848
export const routeDiscovery: ServerBuild["routeDiscovery"];
4949
export const routes: ServerBuild["routes"];
5050
export const ssr: ServerBuild["ssr"];
51+
export const allowedActionOrigins: ServerBuild["allowedActionOrigins"];
5152
export const unstable_getCriticalCss: ServerBuild["unstable_getCriticalCss"];
5253
}
5354
`;

packages/react-router-dev/vite/plugin.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,9 @@ export const reactRouterVitePlugin: ReactRouterVitePlugin = () => {
871871
}
872872
`
873873
: ""
874-
}`;
874+
}
875+
export const allowedActionOrigins = ${JSON.stringify(ctx.reactRouterConfig.allowedActionOrigins)};
876+
`;
875877
};
876878

877879
let loadViteManifest = async (directory: string) => {
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
export function throwIfPotentialCSRFAttack(
2+
headers: Headers,
3+
allowedActionOrigins: string[] | undefined,
4+
) {
5+
let originHeader = headers.get("origin");
6+
let originDomain =
7+
typeof originHeader === "string" && originHeader !== "null"
8+
? new URL(originHeader).host
9+
: originHeader;
10+
let host = parseHostHeader(headers);
11+
12+
if (originDomain && (!host || originDomain !== host.value)) {
13+
if (!isAllowedOrigin(originDomain, allowedActionOrigins)) {
14+
if (host) {
15+
// This seems to be an CSRF attack. We should not proceed with the action.
16+
throw new Error(
17+
`${host.type} header does not match \`origin\` header from a forwarded ` +
18+
`action request. Aborting the action.`,
19+
);
20+
} else {
21+
// This is an attack. We should not proceed with the action.
22+
throw new Error(
23+
"`x-forwarded-host` or `host` headers are not provided. One of these " +
24+
"is needed to compare the `origin` header from a forwarded action " +
25+
"request. Aborting the action.",
26+
);
27+
}
28+
}
29+
}
30+
}
31+
32+
// Implementation of micromatch by Next.js https://github.com/vercel/next.js/blob/ea927b583d24f42e538001bf13370e38c91d17bf/packages/next/src/server/app-render/csrf-protection.ts#L6
33+
function matchWildcardDomain(domain: string, pattern: string) {
34+
const domainParts = domain.split(".");
35+
const patternParts = pattern.split(".");
36+
37+
if (patternParts.length < 1) {
38+
// pattern is empty and therefore invalid to match against
39+
return false;
40+
}
41+
42+
if (domainParts.length < patternParts.length) {
43+
// domain has too few segments and thus cannot match
44+
return false;
45+
}
46+
47+
// Prevent wildcards from matching entire domains (e.g. '**' or '*.com')
48+
// This ensures wildcards can only match subdomains, not the main domain
49+
if (
50+
patternParts.length === 1 &&
51+
(patternParts[0] === "*" || patternParts[0] === "**")
52+
) {
53+
return false;
54+
}
55+
56+
while (patternParts.length) {
57+
const patternPart = patternParts.pop();
58+
const domainPart = domainParts.pop();
59+
60+
switch (patternPart) {
61+
case "": {
62+
// invalid pattern. pattern segments must be non empty
63+
return false;
64+
}
65+
case "*": {
66+
// wildcard matches anything so we continue if the domain part is non-empty
67+
if (domainPart) {
68+
continue;
69+
} else {
70+
return false;
71+
}
72+
}
73+
case "**": {
74+
// if this is not the last item in the pattern the pattern is invalid
75+
if (patternParts.length > 0) {
76+
return false;
77+
}
78+
// recursive wildcard matches anything so we terminate here if the domain part is non empty
79+
return domainPart !== undefined;
80+
}
81+
case undefined:
82+
default: {
83+
if (domainPart !== patternPart) {
84+
return false;
85+
}
86+
}
87+
}
88+
}
89+
90+
// We exhausted the pattern. If we also exhausted the domain we have a match
91+
return domainParts.length === 0;
92+
}
93+
94+
function isAllowedOrigin(
95+
originDomain: string,
96+
allowedActionOrigins: string[] | undefined = [],
97+
) {
98+
return allowedActionOrigins.some(
99+
(allowedOrigin) =>
100+
allowedOrigin &&
101+
(allowedOrigin === originDomain ||
102+
matchWildcardDomain(originDomain, allowedOrigin)),
103+
);
104+
}
105+
106+
function parseHostHeader(headers: Headers) {
107+
let forwardedHostHeader = headers.get("x-forwarded-host");
108+
let forwardedHostValue = forwardedHostHeader?.split(",")[0]?.trim();
109+
let hostHeader = headers.get("host");
110+
111+
return forwardedHostValue
112+
? {
113+
type: "x-forwarded-host",
114+
value: forwardedHostValue,
115+
}
116+
: hostHeader
117+
? {
118+
type: "host",
119+
value: hostHeader,
120+
}
121+
: undefined;
122+
}

packages/react-router/lib/rsc/server.rsc.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import {
3838
} from "../router/utils";
3939
import { getDocumentHeadersImpl } from "../server-runtime/headers";
4040
import { SINGLE_FETCH_REDIRECT_STATUS } from "../dom/ssr/single-fetch";
41+
import { throwIfPotentialCSRFAttack } from "../actions";
4142
import type { RouteMatch, RouteObject } from "../context";
4243
import invariant from "../server-runtime/invariant";
4344

@@ -331,6 +332,7 @@ export type LoadServerActionFunction = (id: string) => Promise<Function>;
331332
* @category RSC
332333
* @mode data
333334
* @param opts Options
335+
* @param opts.allowedActionOrigins Origin patterns that are allowed to execute actions.
334336
* @param opts.basename The basename to use when matching the request.
335337
* @param opts.createTemporaryReferenceSet A function that returns a temporary
336338
* reference set for the request, used to track temporary references in the [RSC](https://react.dev/reference/rsc/server-components)
@@ -361,6 +363,7 @@ export type LoadServerActionFunction = (id: string) => Promise<Function>;
361363
* data for hydration.
362364
*/
363365
export async function matchRSCServerRequest({
366+
allowedActionOrigins,
364367
createTemporaryReferenceSet,
365368
basename,
366369
decodeReply,
@@ -373,6 +376,7 @@ export async function matchRSCServerRequest({
373376
routes,
374377
generateResponse,
375378
}: {
379+
allowedActionOrigins?: string[];
376380
createTemporaryReferenceSet: () => unknown;
377381
basename?: string;
378382
decodeReply?: DecodeReplyFunction;
@@ -477,6 +481,7 @@ export async function matchRSCServerRequest({
477481
onError,
478482
generateResponse,
479483
temporaryReferences,
484+
allowedActionOrigins,
480485
);
481486
// The front end uses this to know whether a 4xx/5xx status came from app code
482487
// or never reached the origin server
@@ -754,6 +759,7 @@ async function generateRenderResponse(
754759
},
755760
) => Response,
756761
temporaryReferences: unknown,
762+
allowedActionOrigins: string[] | undefined,
757763
): Promise<Response> {
758764
// If this is a RR submission, we just want the `actionData` but don't want
759765
// to call any loaders or render any components back in the response - that
@@ -799,6 +805,8 @@ async function generateRenderResponse(
799805
let formState: unknown;
800806
let skipRevalidation = false;
801807
if (request.method === "POST") {
808+
throwIfPotentialCSRFAttack(request.headers, allowedActionOrigins);
809+
802810
ctx.runningAction = true;
803811
let result = await processServerAction(
804812
request,

packages/react-router/lib/server-runtime/build.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@ import type {
1212
import type { ServerRouteManifest } from "./routes";
1313
import type { AppLoadContext } from "./data";
1414
import type { MiddlewareEnabled } from "../types/future";
15-
import type {
16-
unstable_InstrumentRequestHandlerFunction,
17-
unstable_InstrumentRouteFunction,
18-
unstable_ServerInstrumentation,
19-
} from "../router/instrumentation";
15+
import type { unstable_ServerInstrumentation } from "../router/instrumentation";
2016

2117
type OptionalCriticalCss = CriticalCss | undefined;
2218

@@ -46,6 +42,7 @@ export interface ServerBuild {
4642
mode: "lazy" | "initial";
4743
manifestPath: string;
4844
};
45+
allowedActionOrigins?: string[] | false;
4946
}
5047

5148
export interface HandleDocumentRequestFunction {

packages/react-router/lib/server-runtime/server.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import type { MiddlewareEnabled } from "../types/future";
3838
import { getManifestPath } from "../dom/ssr/fog-of-war";
3939
import type { unstable_InstrumentRequestHandlerFunction } from "../router/instrumentation";
4040
import { instrumentHandler } from "../router/instrumentation";
41+
import { throwIfPotentialCSRFAttack } from "../actions";
4142

4243
export type RequestHandler = (
4344
request: Request,
@@ -481,6 +482,14 @@ async function handleDocumentRequest(
481482
criticalCss?: CriticalCss,
482483
) {
483484
try {
485+
if (request.method === "POST") {
486+
throwIfPotentialCSRFAttack(
487+
request.headers,
488+
Array.isArray(build.allowedActionOrigins)
489+
? build.allowedActionOrigins
490+
: [],
491+
);
492+
}
484493
let result = await staticHandler.query(request, {
485494
requestContext: loadContext,
486495
generateMiddlewareResponse: build.future.v8_middleware

packages/react-router/lib/server-runtime/single-fetch.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import { sanitizeError, sanitizeErrors } from "./errors";
2323
import { ServerMode } from "./mode";
2424
import { getDocumentHeaders } from "./headers";
2525
import type { ServerBuild } from "./build";
26+
import { throwIfPotentialCSRFAttack } from "../actions";
2627

2728
// Add 304 for server side - that is not included in the client side logic
2829
// because the browser should fill those responses with the cached data
@@ -42,6 +43,13 @@ export async function singleFetchAction(
4243
handleError: (err: unknown) => void,
4344
): Promise<Response> {
4445
try {
46+
throwIfPotentialCSRFAttack(
47+
request.headers,
48+
Array.isArray(build.allowedActionOrigins)
49+
? build.allowedActionOrigins
50+
: [],
51+
);
52+
4553
let handlerRequest = new Request(handlerUrl, {
4654
method: request.method,
4755
body: request.body,

0 commit comments

Comments
 (0)