fix: Use verified JWT for rate limiting (#12114)

* fix: Use verified JWT for rate limiting

* PR feedback

* Prefer guards
This commit is contained in:
Tom Moor
2026-04-20 06:19:39 -04:00
committed by GitHub
parent 06d5969099
commit 1b91a295e1
5 changed files with 253 additions and 146 deletions
+11
View File
@@ -191,6 +191,17 @@ class Logger {
* @returns The sanitized data
*/
private sanitize = <T>(input: T, level = 0): T => {
// Errors have non-enumerable message/stack which are dropped by spreads
// and JSON serialization, so convert them to a plain object up-front.
if (input instanceof Error) {
// oxlint-disable-next-line @typescript-eslint/no-explicit-any
return {
name: input.name,
message: input.message,
stack: input.stack,
} as any as T;
}
// Short circuit if we're not in production to enable easier debugging
if (!env.isProduction) {
return input;
+154 -132
View File
@@ -1,7 +1,7 @@
import JWT from "jsonwebtoken";
import type { Context } from "koa";
import env from "@server/env";
import { ApiKey } from "@server/models";
import * as jwtUtils from "@server/utils/jwt";
import RateLimiter from "@server/utils/RateLimiter";
import { defaultRateLimiter, rateLimiter } from "./rateLimiter";
@@ -9,25 +9,22 @@ describe("rateLimiter middleware", () => {
const originalRateLimiterEnabled = env.RATE_LIMITER_ENABLED;
beforeEach(() => {
// Enable rate limiter for tests
env.RATE_LIMITER_ENABLED = true;
// Clear the rate limiter map before each test
RateLimiter.rateLimiterMap.clear();
});
afterEach(() => {
// Restore original value
env.RATE_LIMITER_ENABLED = originalRateLimiterEnabled;
jest.restoreAllMocks();
});
it("should register and enforce custom rate limiter with matching paths (no mountPath)", async () => {
const customConfig = { duration: 60, requests: 5 };
// Simulate the rateLimiter middleware registration
const registerMiddleware = rateLimiter(customConfig);
const mockCtx = {
path: "/documents.export",
mountPath: undefined, // No mount path
mountPath: undefined,
ip: "127.0.0.1",
set: jest.fn(),
request: {},
@@ -35,14 +32,10 @@ describe("rateLimiter middleware", () => {
await registerMiddleware(mockCtx, jest.fn());
// Check if the rate limiter was registered
const registeredPath = "/documents.export";
expect(RateLimiter.hasRateLimiter(registeredPath)).toBe(true);
// Simulate the defaultRateLimiter middleware lookup
const limiter = RateLimiter.getRateLimiter(mockCtx.path);
// Verify that the custom rate limiter is found
expect(limiter).not.toBe(RateLimiter.defaultRateLimiter);
expect(limiter.points).toBe(5);
});
@@ -50,24 +43,8 @@ describe("rateLimiter middleware", () => {
it("should register and enforce custom rate limiter with matching paths (with mountPath)", async () => {
const customConfig = { duration: 60, requests: 5 };
// Simulate the rateLimiter middleware registration with mountPath
const registerMiddleware = rateLimiter(customConfig);
const mockCtxRegister = {
path: "/documents.export",
mountPath: "/api", // This is set when router is mounted
ip: "127.0.0.1",
set: jest.fn(),
request: {},
} as unknown as Context;
await registerMiddleware(mockCtxRegister, jest.fn());
// The rateLimiter middleware constructs fullPath = mountPath + path
const registrationPath = "/api/documents.export";
expect(RateLimiter.hasRateLimiter(registrationPath)).toBe(true);
// Now check what defaultRateLimiter will use (after fix, should use fullPath)
const mockCtxEnforce = {
path: "/documents.export",
mountPath: "/api",
ip: "127.0.0.1",
@@ -75,103 +52,31 @@ describe("rateLimiter middleware", () => {
request: {},
} as unknown as Context;
// Construct fullPath the same way as the fixed defaultRateLimiter should
const fullPath = `${mockCtxEnforce.mountPath ?? ""}${mockCtxEnforce.path}`;
expect(fullPath).toBe("/api/documents.export");
await registerMiddleware(mockCtxRegister, jest.fn());
// After the fix, hasRateLimiter should find the custom rate limiter
expect(RateLimiter.hasRateLimiter(fullPath)).toBe(true);
const registrationPath = "/api/documents.export";
expect(RateLimiter.hasRateLimiter(registrationPath)).toBe(true);
// And the custom rate limiter should be used
const limiter = RateLimiter.getRateLimiter(fullPath);
const limiter = RateLimiter.getRateLimiter(registrationPath);
expect(limiter).not.toBe(RateLimiter.defaultRateLimiter);
expect(limiter.points).toBe(5);
});
it("should use default rate limiter when no custom rate limiter is registered", async () => {
const mockCtx = {
path: "/some/random/path",
mountPath: undefined,
ip: "127.0.0.1",
set: jest.fn(),
request: {},
} as unknown as Context;
const fullPath = `${mockCtx.mountPath ?? ""}${mockCtx.path}`;
// No custom rate limiter registered
const fullPath = "/some/random/path";
expect(RateLimiter.hasRateLimiter(fullPath)).toBe(false);
// Should use default rate limiter
const limiter = RateLimiter.getRateLimiter(fullPath);
expect(limiter).toBe(RateLimiter.defaultRateLimiter);
});
it("should construct correct consume key with fullPath when custom rate limiter exists", async () => {
const customConfig = { duration: 60, requests: 5 };
// Register with mountPath
const registerMiddleware = rateLimiter(customConfig);
const mockCtxRegister = {
path: "/documents.export",
mountPath: "/api",
ip: "127.0.0.1",
set: jest.fn(),
request: {},
} as unknown as Context;
await registerMiddleware(mockCtxRegister, jest.fn());
// Check what key defaultRateLimiter will use (after fix)
const mockCtxEnforce = {
path: "/documents.export",
mountPath: "/api",
ip: "127.0.0.1",
set: jest.fn(),
request: {},
} as unknown as Context;
const fullPath = `${mockCtxEnforce.mountPath ?? ""}${mockCtxEnforce.path}`;
// After fix, the key should include the full path
const key = RateLimiter.hasRateLimiter(fullPath)
? `${fullPath}:${mockCtxEnforce.ip}`
: `${mockCtxEnforce.ip}`;
// Expected key format: "/api/documents.export:127.0.0.1"
expect(key).toBe("/api/documents.export:127.0.0.1");
});
describe("user-based rate limiting", () => {
it("should use user ID from JWT when authenticated", async () => {
const userId = "test-user-id-123";
const token = JWT.sign({ id: userId, type: "session" }, "secret");
describe("cache-keyed rate limiting", () => {
it("falls back to IP when no token is present", async () => {
const middleware = defaultRateLimiter();
const consumeSpy = jest.spyOn(RateLimiter.defaultRateLimiter, "consume");
const ip = "192.168.1.1";
const mockCtx = {
path: "/some/path",
mountPath: undefined,
ip,
set: jest.fn(),
request: {
get: () => `Bearer ${token}`,
},
cookies: {
get: () => undefined,
},
} as unknown as Context;
await middleware(mockCtx, jest.fn());
expect(consumeSpy).toHaveBeenCalledWith(`user:${userId}:${ip}`);
consumeSpy.mockRestore();
});
it("should fall back to IP when no token is provided", async () => {
const middleware = defaultRateLimiter();
const consumeSpy = jest.spyOn(RateLimiter.defaultRateLimiter, "consume");
const consumeSpy = jest
.spyOn(RateLimiter.defaultRateLimiter, "consume")
.mockResolvedValue({} as never);
const cacheSpy = jest.spyOn(RateLimiter, "getCachedUserIdForToken");
const mockCtx = {
path: "/some/path",
@@ -183,63 +88,180 @@ describe("rateLimiter middleware", () => {
body: {},
query: {},
},
cookies: {
get: () => undefined,
},
cookies: { get: () => undefined },
} as unknown as Context;
await middleware(mockCtx, jest.fn());
expect(cacheSpy).not.toHaveBeenCalled();
expect(consumeSpy).toHaveBeenCalledWith("192.168.1.1");
consumeSpy.mockRestore();
});
it("should fall back to IP for API key tokens", async () => {
it("short-circuits to IP for API key tokens without hitting Redis or JWT verify", async () => {
const apiKeyToken = `${ApiKey.prefix}${"a".repeat(38)}`;
const middleware = defaultRateLimiter();
const consumeSpy = jest.spyOn(RateLimiter.defaultRateLimiter, "consume");
const consumeSpy = jest
.spyOn(RateLimiter.defaultRateLimiter, "consume")
.mockResolvedValue({} as never);
const cacheReadSpy = jest.spyOn(RateLimiter, "getCachedUserIdForToken");
const verifySpy = jest.spyOn(jwtUtils, "getUserForJWT");
const mockCtx = {
path: "/some/path",
mountPath: undefined,
ip: "192.168.1.1",
set: jest.fn(),
request: {
get: () => `Bearer ${apiKeyToken}`,
},
cookies: {
get: () => undefined,
},
request: { get: () => `Bearer ${apiKeyToken}` },
cookies: { get: () => undefined },
} as unknown as Context;
await middleware(mockCtx, jest.fn());
expect(cacheReadSpy).not.toHaveBeenCalled();
expect(verifySpy).not.toHaveBeenCalled();
expect(consumeSpy).toHaveBeenCalledWith("192.168.1.1");
consumeSpy.mockRestore();
});
it("should fall back to IP when JWT is malformed", async () => {
it("falls back to IP when token fails verification (forged or expired)", async () => {
const middleware = defaultRateLimiter();
const consumeSpy = jest.spyOn(RateLimiter.defaultRateLimiter, "consume");
const consumeSpy = jest
.spyOn(RateLimiter.defaultRateLimiter, "consume")
.mockResolvedValue({} as never);
jest
.spyOn(RateLimiter, "getCachedUserIdForToken")
.mockResolvedValue(null);
const cacheWriteSpy = jest
.spyOn(RateLimiter, "cacheUserForToken")
.mockResolvedValue();
jest
.spyOn(jwtUtils, "getUserForJWT")
.mockRejectedValue(new Error("invalid token"));
const mockCtx = {
path: "/some/path",
mountPath: undefined,
ip: "192.168.1.1",
set: jest.fn(),
request: {
get: () => "Bearer invalid-token",
},
cookies: {
get: () => undefined,
},
request: { get: () => "Bearer forged-or-unknown-token" },
cookies: { get: () => undefined },
} as unknown as Context;
await middleware(mockCtx, jest.fn());
expect(consumeSpy).toHaveBeenCalledWith("192.168.1.1");
consumeSpy.mockRestore();
expect(cacheWriteSpy).not.toHaveBeenCalled();
});
it("verifies and caches the user on cache miss, then keys by user", async () => {
const middleware = defaultRateLimiter();
const consumeSpy = jest
.spyOn(RateLimiter.defaultRateLimiter, "consume")
.mockResolvedValue({} as never);
jest
.spyOn(RateLimiter, "getCachedUserIdForToken")
.mockResolvedValue(null);
const cacheWriteSpy = jest
.spyOn(RateLimiter, "cacheUserForToken")
.mockResolvedValue();
jest.spyOn(jwtUtils, "getUserForJWT").mockResolvedValue({
user: { id: "user-abc" },
} as never);
const mockCtx = {
path: "/some/path",
mountPath: undefined,
ip: "192.168.1.1",
set: jest.fn(),
request: { get: () => "Bearer valid-token" },
cookies: { get: () => undefined },
} as unknown as Context;
await middleware(mockCtx, jest.fn());
expect(cacheWriteSpy).toHaveBeenCalledWith("valid-token", "user-abc");
expect(consumeSpy).toHaveBeenCalledWith("user-abc");
});
it("keys on user id when token is in cache without re-verifying", async () => {
const middleware = defaultRateLimiter();
const consumeSpy = jest
.spyOn(RateLimiter.defaultRateLimiter, "consume")
.mockResolvedValue({} as never);
jest
.spyOn(RateLimiter, "getCachedUserIdForToken")
.mockResolvedValue("user-abc");
const verifySpy = jest.spyOn(jwtUtils, "getUserForJWT");
const mockCtx = {
path: "/some/path",
mountPath: undefined,
ip: "192.168.1.1",
set: jest.fn(),
request: { get: () => "Bearer verified-token" },
cookies: { get: () => undefined },
} as unknown as Context;
await middleware(mockCtx, jest.fn());
expect(verifySpy).not.toHaveBeenCalled();
expect(consumeSpy).toHaveBeenCalledWith("user-abc");
});
it("falls back to IP when the cache lookup throws", async () => {
const middleware = defaultRateLimiter();
const consumeSpy = jest
.spyOn(RateLimiter.defaultRateLimiter, "consume")
.mockResolvedValue({} as never);
jest
.spyOn(RateLimiter, "getCachedUserIdForToken")
.mockRejectedValue(new Error("redis down"));
const mockCtx = {
path: "/some/path",
mountPath: undefined,
ip: "192.168.1.1",
set: jest.fn(),
request: { get: () => "Bearer some-token" },
cookies: { get: () => undefined },
} as unknown as Context;
await middleware(mockCtx, jest.fn());
expect(consumeSpy).toHaveBeenCalledWith("192.168.1.1");
});
it("prefixes the key with fullPath when a custom limiter is registered", async () => {
const registerMiddleware = rateLimiter({ duration: 60, requests: 5 });
const registerCtx = {
path: "/documents.export",
mountPath: "/api",
ip: "127.0.0.1",
set: jest.fn(),
request: {},
} as unknown as Context;
await registerMiddleware(registerCtx, jest.fn());
const customLimiter = RateLimiter.getRateLimiter("/api/documents.export");
const consumeSpy = jest
.spyOn(customLimiter, "consume")
.mockResolvedValue({} as never);
jest
.spyOn(RateLimiter, "getCachedUserIdForToken")
.mockResolvedValue("user-abc");
const middleware = defaultRateLimiter();
const mockCtx = {
path: "/documents.export",
mountPath: "/api",
ip: "127.0.0.1",
set: jest.fn(),
request: { get: () => "Bearer verified-token" },
cookies: { get: () => undefined },
} as unknown as Context;
await middleware(mockCtx, jest.fn());
expect(consumeSpy).toHaveBeenCalledWith("/api/documents.export:user-abc");
});
});
});
+21 -13
View File
@@ -7,33 +7,41 @@ import Metrics from "@server/logging/Metrics";
import { ApiKey, OAuthAuthentication } from "@server/models";
import Redis from "@server/storage/redis";
import type { AppContext } from "@server/types";
import { getJWTPayload } from "@server/utils/jwt";
import { getUserForJWT } from "@server/utils/jwt";
import RateLimiter from "@server/utils/RateLimiter";
import { parseAuthentication } from "./authentication";
/**
* Returns a unique identifier for rate limiting based on the request context.
* Combines the user ID from the JWT payload with the client's IP address for
* authenticated requests, otherwise falls back to the client's IP address alone.
* Keys on the user id (so users behind a shared NAT don't share a bucket) when
* a token can be associated with a user, otherwise falls back to the client's
* IP address.
*
* @param ctx The application context.
* @returns A string identifier for rate limiting.
*/
function getRateLimiterIdentifier(ctx: AppContext): string {
async function getRateLimiterIdentifier(ctx: AppContext): Promise<string> {
try {
const { token } = parseAuthentication(ctx);
if (token && !ApiKey.match(token) && !OAuthAuthentication.match(token)) {
// Note: JWT is not validated here which would require a DB request,
// just decoded to extract the user ID for separating rate limits by user
// on shared networks.
const payload = getJWTPayload(token);
if (payload.id) {
return `user:${payload.id}:${ctx.ip}`;
}
if (!token) {
return ctx.ip;
}
if (ApiKey.match(token) || OAuthAuthentication.match(token)) {
return ctx.ip;
}
let userId = await RateLimiter.getCachedUserIdForToken(token);
if (!userId) {
const { user } = await getUserForJWT(token);
userId = user.id;
void RateLimiter.cacheUserForToken(token, userId);
}
return userId;
} catch {
// Fall through to IP-based rate limiting
}
return ctx.ip;
}
@@ -51,7 +59,7 @@ export function defaultRateLimiter() {
}
const fullPath = `${ctx.mountPath ?? ""}${ctx.path}`;
const identifier = getRateLimiterIdentifier(ctx);
const identifier = await getRateLimiterIdentifier(ctx);
const key = RateLimiter.hasRateLimiter(fullPath)
? `${fullPath}:${identifier}`
+4 -1
View File
@@ -20,6 +20,7 @@ import {
import ValidateSSOAccessTask from "@server/queues/tasks/ValidateSSOAccessTask";
import type { APIContext } from "@server/types";
import { getSessionsInCookie } from "@server/utils/authentication";
import RateLimiter from "@server/utils/RateLimiter";
import type * as T from "./schema";
const router = new Router();
@@ -187,7 +188,7 @@ router.post(
transaction(),
async (ctx: APIContext<T.AuthDeleteReq>) => {
const { auth, transaction } = ctx.state;
const { user } = auth;
const { user, token } = auth;
await user.rotateJwtSecret({ transaction });
await Event.createFromContext(ctx, {
@@ -198,6 +199,8 @@ router.post(
},
});
void RateLimiter.clearCachedToken(token);
ctx.cookies.set("accessToken", "", {
sameSite: "lax",
expires: subMinutes(new Date(), 1),
+63
View File
@@ -1,6 +1,8 @@
import { createHash } from "crypto";
import type { IRateLimiterStoreOptions } from "rate-limiter-flexible";
import { RateLimiterRedis, RateLimiterMemory } from "rate-limiter-flexible";
import env from "@server/env";
import Logger from "@server/logging/Logger";
import Redis from "@server/storage/redis";
export default class RateLimiter {
@@ -10,6 +12,10 @@ export default class RateLimiter {
static readonly RATE_LIMITER_REDIS_KEY_PREFIX = "rl";
static readonly TOKEN_CACHE_KEY_PREFIX = "rl:tok:";
static readonly TOKEN_CACHE_TTL_SECONDS = 3600;
static readonly rateLimiterMap = new Map<string, RateLimiterRedis>();
static readonly insuranceRateLimiter = new RateLimiterMemory({
@@ -44,6 +50,63 @@ export default class RateLimiter {
static hasRateLimiter(path: string): boolean {
return this.rateLimiterMap.has(path);
}
/**
* Caches the user id associated with a verified authentication token so that
* subsequent requests can be keyed by user without re-validating the token.
* Errors are swallowed — a failed cache write just means the next request
* falls back to IP-based keying.
*
* @param token The authentication token that was just verified.
* @param userId The id of the user the token belongs to.
*/
static async cacheUserForToken(token: string, userId: string): Promise<void> {
try {
await Redis.defaultClient.set(
this.tokenCacheKey(token),
userId,
"EX",
this.TOKEN_CACHE_TTL_SECONDS
);
} catch (err) {
Logger.warn("Failed to cache user for rate limiter token", err);
}
}
/**
* Looks up the cached user id for a previously verified token. Returns null
* on cache miss or Redis error.
*
* @param token The authentication token presented on the current request.
* @returns The associated user id, or null if unknown.
*/
static async getCachedUserIdForToken(token: string): Promise<string | null> {
try {
return await Redis.defaultClient.get(this.tokenCacheKey(token));
} catch (err) {
Logger.warn("Failed to read cached user for rate limiter token", err);
return null;
}
}
/**
* Removes the cached user id for a token, for example on logout so that a
* revoked token immediately stops keying rate limits per user.
*
* @param token The authentication token being invalidated.
*/
static async clearCachedToken(token: string): Promise<void> {
try {
await Redis.defaultClient.del(this.tokenCacheKey(token));
} catch (err) {
Logger.warn("Failed to clear cached rate limiter token", err);
}
}
private static tokenCacheKey(token: string): string {
const hash = createHash("sha256").update(token).digest("hex");
return `${this.TOKEN_CACHE_KEY_PREFIX}${hash}`;
}
}
/**