Add CSRF middleware (#10051)

ref OUT-Q325-03
This commit is contained in:
Tom Moor
2025-08-31 12:35:35 +02:00
committed by GitHub
parent f614f3dd3f
commit 0a9bd39aac
22 changed files with 335 additions and 65 deletions
+19
View File
@@ -0,0 +1,19 @@
import { CSRF } from "@shared/constants";
import { useCsrfToken } from "~/hooks/useCsrfToken";
/**
* Form component that automatically includes a CSRF token as a hidden input field.
*/
export const Form = ({
children,
...props
}: React.FormHTMLAttributes<HTMLFormElement>) => {
const token = useCsrfToken();
return (
<form {...props}>
{token && <input type="hidden" name={CSRF.fieldName} value={token} />}
{children}
</form>
);
};
+30
View File
@@ -0,0 +1,30 @@
import { CSRF } from "@shared/constants";
import { useState, useEffect } from "react";
import { getCookie } from "tiny-cookie";
/**
* React hook for accessing CSRF tokens in components
*
* @returns The CSRF token string or null if not found
*/
export function useCsrfToken() {
const [token, setToken] = useState<string | null>(null);
useEffect(() => {
const updateToken = () => {
const currentToken = getCookie(CSRF.cookieName);
setToken(currentToken);
};
// Initial load
updateToken();
// Listen for cookie changes (when navigating or refreshing)
const interval = setInterval(updateToken, 1000);
return () => clearInterval(interval);
}, []);
return token;
}
+3 -2
View File
@@ -26,6 +26,7 @@ import { Background } from "./components/Background";
import { Centered } from "./components/Centered";
import { ConnectHeader } from "./components/ConnectHeader";
import { TeamSwitcher } from "./components/TeamSwitcher";
import { Form } from "~/components/primitives/Form";
export default function OAuthAuthorize() {
const team = useCurrentTeam({ rejectOnEmpty: false });
@@ -203,7 +204,7 @@ function Authorize() {
</li>
))}
</ul>
<form
<Form
method="POST"
action="/oauth/authorize"
style={{ width: "100%" }}
@@ -236,7 +237,7 @@ function Authorize() {
{t("Authorize")}
</Button>
</Flex>
</form>
</Form>
</Centered>
</Background>
);
@@ -12,15 +12,17 @@ import { detectLanguage } from "~/utils/language";
import { BackButton } from "./BackButton";
import { Background } from "./Background";
import { Centered } from "./Centered";
import { Form } from "~/components/primitives/Form";
const WorkspaceSetup = ({ onBack }: { onBack?: () => void }) => {
const { t } = useTranslation();
return (
<Background>
<BackButton onBack={onBack} />
<ChangeLanguage locale={detectLanguage()} />
<Centered
as="form"
as={Form}
action="/api/installation.create"
method="POST"
gap={12}
+24 -1
View File
@@ -2,7 +2,7 @@ import retry from "fetch-retry";
import trim from "lodash/trim";
import queryString from "query-string";
import EDITOR_VERSION from "@shared/editor/version";
import { JSONObject } from "@shared/types";
import { JSONObject, Scope } from "@shared/types";
import stores from "~/stores";
import Logger from "./Logger";
import download from "./download";
@@ -20,6 +20,9 @@ import {
UnprocessableEntityError,
UpdateRequiredError,
} from "./errors";
import { getCookie } from "tiny-cookie";
import { CSRF } from "@shared/constants";
import AuthenticationHelper from "@shared/helpers/AuthenticationHelper";
type Options = {
baseUrl?: string;
@@ -105,6 +108,20 @@ class ApiClient {
...options?.headers,
};
// Add CSRF token to headers for mutating requests
const isModifyingRequest = ["POST", "PUT", "PATCH", "DELETE"].includes(
method
);
const canAccessWithReadOnly = AuthenticationHelper.canAccess(path, [
Scope.Read,
]);
if (isModifyingRequest && !canAccessWithReadOnly) {
const csrfToken = getCookie(CSRF.cookieName);
if (csrfToken) {
headerOptions[CSRF.headerName] = csrfToken;
}
}
// for multipart forms or other non JSON requests fetch
// populates the Content-Type without needing to explicitly
// set it.
@@ -213,6 +230,12 @@ class ApiClient {
});
}
if (error.error === "csrf_error") {
throw new AuthorizationError(
"CSRF token invalid, please try reloading."
);
}
throw new AuthorizationError(error.message);
}
+6
View File
@@ -32,6 +32,12 @@ export function AuthorizationError(message = "Authorization error") {
});
}
export function CSRFError(message = "Authorization error") {
return httpErrors(403, message, {
id: "csrf_error",
});
}
export function RateLimitExceededError(
message = "Rate limit exceeded for this operation"
) {
+107
View File
@@ -0,0 +1,107 @@
import type { Next } from "koa";
import { Scope } from "@shared/types";
import env from "@server/env";
import AuthenticationHelper from "@shared/helpers/AuthenticationHelper";
import { AppContext } from "@server/types";
import {
generateRawToken,
bundleToken,
unbundleToken,
} from "@server/utils/csrf";
import { getCookieDomain } from "@shared/utils/domains";
import { CSRF } from "@shared/constants";
import { CSRFError } from "@server/errors";
/**
* Middleware that generates and attaches CSRF tokens for safe methods
*/
export function attachCSRFToken() {
return async function attachCSRFTokenMiddleware(ctx: AppContext, next: Next) {
// Only attach tokens for safe methods that don't mutate state
if (["GET", "HEAD", "OPTIONS"].includes(ctx.method)) {
const raw = generateRawToken(16);
const bundled = bundleToken(raw, env.SECRET_KEY);
// Set cookie that JavaScript can read (not HttpOnly)
ctx.cookies.set(CSRF.cookieName, bundled, {
httpOnly: false,
sameSite: "lax",
domain: getCookieDomain(ctx.request.hostname, env.isCloudHosted),
});
}
await next();
};
}
/**
* Middleware that verifies CSRF tokens for mutating requests
*/
export function verifyCSRFToken() {
/**
* Determines if a request requires CSRF protection
*/
const shouldProtectRequest = (ctx: AppContext): boolean => {
// Skip if not a potentially mutating method
if (["GET", "HEAD", "OPTIONS"].includes(ctx.method)) {
return false;
}
// If not using cookie-based auth, skip CSRF protection
if (!ctx.cookies.get("accessToken")) {
return false;
}
// For API routes, use AuthenticationHelper to determine if the operation is read-only
if (ctx.originalUrl.startsWith("/api/")) {
const canAccessWithReadOnly = AuthenticationHelper.canAccess(ctx.path, [
Scope.Read,
]);
// If it can be accessed with read-only scope, it doesn't need CSRF protection
if (canAccessWithReadOnly) {
return false;
}
}
// Protect all other mutating requests
return true;
};
return async function verifyCSRFTokenMiddleware(ctx: AppContext, next: Next) {
if (!shouldProtectRequest(ctx)) {
await next();
return;
}
// Get token from cookie
const cookieVal = ctx.cookies.get(CSRF.cookieName);
if (!cookieVal) {
throw CSRFError("CSRF token missing from cookie");
}
// Get token from header or form field depending on type
// Access the already-parsed body from koa-body middleware
const inputVal =
ctx.get(CSRF.headerName) || ctx.request.body?.[CSRF.fieldName];
if (!inputVal) {
throw CSRFError("CSRF token missing from request");
}
// Verify both tokens are valid HMAC-signed tokens
const { valid: cookieValid } = unbundleToken(cookieVal, env.SECRET_KEY);
const { valid: inputValid } = unbundleToken(inputVal, env.SECRET_KEY);
if (!cookieValid || !inputValid) {
throw CSRFError("CSRF token invalid or malformed");
}
// Verify tokens match (double-submit check)
if (cookieVal !== inputVal) {
throw CSRFError("CSRF token mismatch");
}
await next();
};
}
+1 -1
View File
@@ -21,7 +21,7 @@ import User from "./User";
import ParanoidModel from "./base/ParanoidModel";
import { SkipChangeset } from "./decorators/Changeset";
import Fix from "./decorators/Fix";
import AuthenticationHelper from "./helpers/AuthenticationHelper";
import AuthenticationHelper from "@shared/helpers/AuthenticationHelper";
import Length from "./validators/Length";
@Table({ tableName: "apiKeys", modelName: "apiKey" })
@@ -1,28 +1,10 @@
/* oxlint-disable @typescript-eslint/no-var-requires */
import find from "lodash/find";
import { Scope } from "@shared/types";
import env from "@server/env";
import Team from "@server/models/Team";
import { Hook, PluginManager } from "@server/utils/PluginManager";
export default class AuthenticationHelper {
/**
* The mapping of method names to their scopes, anything not listed here
* defaults to `Scope.Write`.
*
* - `documents.create` -> `Scope.Create`
* - `documents.list` -> `Scope.Read`
* - `documents.info` -> `Scope.Read`
*/
private static methodToScope = {
create: Scope.Create,
list: Scope.Read,
info: Scope.Read,
search: Scope.Read,
documents: Scope.Read,
export: Scope.Read,
};
/**
* Returns the enabled authentication provider configurations for the current
* installation.
@@ -70,45 +52,4 @@ export default class AuthenticationHelper {
);
});
}
/**
* Returns whether the given path can be accessed with any of the scopes. We
* support scopes in the formats of:
*
* - `/api/namespace.method`
* - `namespace:scope`
* - `scope`
*
* @param path The path to check
* @param scopes The scopes to check
* @returns True if the path can be accessed
*/
public static canAccess = (path: string, scopes: string[]) => {
// strip any query string, this is never used as part of scope matching
path = path.split("?")[0];
const resource = path.split("/").pop() ?? "";
const [namespace, method] = resource.split(".");
return scopes.some((scope) => {
const [scopeNamespace, scopeMethod] = scope.match(/[:\.]/g)
? scope.replace("/api/", "").split(/[:\.]/g)
: ["*", scope];
const isRouteScope = scope.startsWith("/api/");
if (isRouteScope) {
return (
(namespace === scopeNamespace || scopeNamespace === "*") &&
(method === scopeMethod || scopeMethod === "*")
);
}
return (
(namespace === scopeNamespace || scopeNamespace === "*") &&
(scopeMethod === Scope.Write ||
this.methodToScope[method as keyof typeof this.methodToScope] ===
scopeMethod)
);
});
};
}
+1 -1
View File
@@ -20,7 +20,7 @@ import User from "@server/models/User";
import ParanoidModel from "@server/models/base/ParanoidModel";
import { SkipChangeset } from "@server/models/decorators/Changeset";
import Fix from "@server/models/decorators/Fix";
import AuthenticationHelper from "@server/models/helpers/AuthenticationHelper";
import AuthenticationHelper from "@shared/helpers/AuthenticationHelper";
import { hash } from "@server/utils/crypto";
import OAuthClient from "./OAuthClient";
@@ -135,6 +135,7 @@ router.post(
});
const presignedPost = await FileStorage.getPresignedPost(
ctx,
key,
acl,
maxUploadSize,
+2
View File
@@ -6,6 +6,7 @@ import env from "@server/env";
import { NotFoundError } from "@server/errors";
import coalesceBody from "@server/middlewares/coaleseBody";
import requestTracer from "@server/middlewares/requestTracer";
import { verifyCSRFToken } from "@server/middlewares/csrf";
import { AppState, AppContext } from "@server/types";
import { Hook, PluginManager } from "@server/utils/PluginManager";
import apiKeys from "./apiKeys";
@@ -67,6 +68,7 @@ api.use(requestTracer());
api.use(apiResponse());
api.use(apiErrorHandler());
api.use(editor());
api.use(verifyCSRFToken());
// Register plugin API routes before others to allow for overrides
PluginManager.getHooks(Hook.API).forEach((hook) =>
+2
View File
@@ -9,6 +9,7 @@ import coalesceBody from "@server/middlewares/coaleseBody";
import { Collection, Team, View } from "@server/models";
import AuthenticationHelper from "@server/models/helpers/AuthenticationHelper";
import { AppState, AppContext, APIContext } from "@server/types";
import { verifyCSRFToken } from "@server/middlewares/csrf";
const app = new Koa<AppState, AppContext>();
const router = new Router();
@@ -77,6 +78,7 @@ router.get("/redirect", authMiddleware(), async (ctx: APIContext) => {
app.use(bodyParser());
app.use(coalesceBody());
app.use(verifyCSRFToken());
app.use(router.routes());
export default app;
+2
View File
@@ -16,6 +16,7 @@ import { RateLimiterStrategy } from "@server/utils/RateLimiter";
import { OAuthInterface } from "@server/utils/oauth/OAuthInterface";
import oauthErrorHandler from "./middlewares/oauthErrorHandler";
import * as T from "./schema";
import { verifyCSRFToken } from "@server/middlewares/csrf";
const app = new Koa();
const router = new Router();
@@ -127,6 +128,7 @@ router.post(
app.use(requestTracer());
app.use(oauthErrorHandler());
app.use(bodyParser());
app.use(verifyCSRFToken());
app.use(router.routes());
export default app;
+2
View File
@@ -13,6 +13,7 @@ import env from "@server/env";
import Logger from "@server/logging/Logger";
import Metrics from "@server/logging/Metrics";
import csp from "@server/middlewares/csp";
import { attachCSRFToken } from "@server/middlewares/csrf";
import ShutdownHelper, { ShutdownOrder } from "@server/utils/ShutdownHelper";
import { initI18n } from "@server/utils/i18n";
import routes from "../routes";
@@ -45,6 +46,7 @@ export default function init(app: Koa = new Koa(), server?: Server) {
}
app.use(compress());
app.use(attachCSRFToken());
// Monitor server connections
if (server) {
+3
View File
@@ -7,6 +7,7 @@ import { isBase64Url, isInternalUrl } from "@shared/utils/urls";
import env from "@server/env";
import Logger from "@server/logging/Logger";
import fetch, { chromeUserAgent, RequestInit } from "@server/utils/fetch";
import { AppContext } from "@server/types";
export default abstract class BaseStorage {
/** The default number of seconds until a signed URL expires. */
@@ -15,6 +16,7 @@ export default abstract class BaseStorage {
/**
* Returns a presigned post for uploading files to the storage provider.
*
* @param ctx The request context
* @param key The path to store the file at
* @param acl The ACL to use
* @param maxUploadSize The maximum upload size in bytes
@@ -22,6 +24,7 @@ export default abstract class BaseStorage {
* @returns The presigned post object to use on the client (TODO: Abstract away from S3)
*/
public abstract getPresignedPost(
ctx: AppContext,
key: string,
acl: string,
maxUploadSize: number,
+4
View File
@@ -11,9 +11,12 @@ import env from "@server/env";
import { InternalError, ValidationError } from "@server/errors";
import Logger from "@server/logging/Logger";
import BaseStorage from "./BaseStorage";
import { CSRF } from "@shared/constants";
import { AppContext } from "@server/types";
export default class LocalStorage extends BaseStorage {
public async getPresignedPost(
ctx: AppContext,
key: string,
acl: string,
maxUploadSize: number,
@@ -26,6 +29,7 @@ export default class LocalStorage extends BaseStorage {
acl,
maxUploadSize: String(maxUploadSize),
contentType,
[CSRF.fieldName]: ctx.cookies.get(CSRF.cookieName) || "",
},
});
}
+2
View File
@@ -20,6 +20,7 @@ import tmp from "tmp";
import env from "@server/env";
import Logger from "@server/logging/Logger";
import BaseStorage from "./BaseStorage";
import { AppContext } from "@server/types";
export default class S3Storage extends BaseStorage {
constructor() {
@@ -34,6 +35,7 @@ export default class S3Storage extends BaseStorage {
}
public async getPresignedPost(
_ctx: AppContext,
key: string,
acl: string,
maxUploadSize: number,
+55
View File
@@ -0,0 +1,55 @@
import { randomBytes, createHmac } from "crypto";
import { safeEqual } from "./crypto";
/**
* Generates cryptographically secure random bytes
*
* @param size The number of bytes to generate
* @returns A buffer containing random bytes
*/
export const generateRawToken = (size: number): Buffer => randomBytes(size);
/**
* Creates an HMAC-SHA256 signature for a token
*
* @param token The token to sign
* @param secret The secret key for signing
* @returns The HMAC signature as a hex string
*/
export const signToken = (token: Buffer, secret: string): string =>
createHmac("sha256", secret).update(token).digest("hex");
/**
* Bundles a token with its HMAC signature
*
* @param token The raw token
* @param secret The secret key for signing
* @returns A string containing the token and signature separated by a dot
*/
export const bundleToken = (token: Buffer, secret: string): string => {
const sig = signToken(token, secret);
return `${token.toString("hex")}.${sig}`;
};
/**
* Unbundles and verifies a token with its HMAC signature
*
* @param bundled The bundled token string
* @param secret The secret key for verification
* @returns An object indicating validity and the raw token if valid
*/
export const unbundleToken = (
bundled: string,
secret: string
): { valid: boolean; raw?: Buffer } => {
const [hex, sig] = bundled.split(".");
if (!hex || !sig) {
return { valid: false };
}
const token = Buffer.from(hex, "hex");
const expected = signToken(token, secret);
const valid = safeEqual(sig, expected);
return { valid, raw: valid ? token : undefined };
};
+6
View File
@@ -15,6 +15,12 @@ export const Pagination = {
sidebarLimit: 10,
};
export const CSRF = {
cookieName: "csrfToken",
headerName: "x-csrf-token",
fieldName: "_csrf",
};
export const TeamPreferenceDefaults: TeamPreferences = {
[TeamPreference.SeamlessEdit]: true,
[TeamPreference.ViewersCanExport]: true,
+62
View File
@@ -0,0 +1,62 @@
import { Scope } from "../types";
export default class AuthenticationHelper {
/**
* The mapping of method names to their scopes, anything not listed here
* defaults to `Scope.Write`.
*
* - `documents.create` -> `Scope.Create`
* - `documents.list` -> `Scope.Read`
* - `documents.info` -> `Scope.Read`
*/
private static methodToScope = {
create: Scope.Create,
config: Scope.Read,
list: Scope.Read,
info: Scope.Read,
search: Scope.Read,
documents: Scope.Read,
export: Scope.Read,
};
/**
* Returns whether the given path can be accessed with any of the scopes. We
* support scopes in the formats of:
*
* - `/api/namespace.method`
* - `namespace:scope`
* - `scope`
*
* @param path The path to check
* @param scopes The scopes to check
* @returns True if the path can be accessed
*/
public static canAccess = (path: string, scopes: string[]) => {
// strip any query string, this is never used as part of scope matching
path = path.split("?")[0];
const resource = path.split("/").pop() ?? "";
const [namespace, method] = resource.split(".");
return scopes.some((scope) => {
const [scopeNamespace, scopeMethod] = scope.match(/[:\.]/g)
? scope.replace("/api/", "").split(/[:\.]/g)
: ["*", scope];
const isRouteScope = scope.startsWith("/api/");
if (isRouteScope) {
return (
(namespace === scopeNamespace || scopeNamespace === "*") &&
(method === scopeMethod || scopeMethod === "*")
);
}
return (
(namespace === scopeNamespace || scopeNamespace === "*") &&
(scopeMethod === Scope.Write ||
this.methodToScope[method as keyof typeof this.methodToScope] ===
scopeMethod)
);
});
};
}