From 879d2b8198ad718e684f9202bcd4c6c5d9f32a26 Mon Sep 17 00:00:00 2001 From: Tom Moor Date: Sat, 16 May 2026 19:56:21 -0400 Subject: [PATCH] fix: Allow connecting additional auth providers on custom domain (#12364) * fix: Unable to link secondary auth provider on custom domain * doc * chore: Custom -> Apex transfer token * Refactor, address security concerns * Ensure OAuth intent is single-use * Secure OAuth state actor binding * Use scrypt for OAuth actor session binding --- plugins/azure/server/auth/azure.ts | 2 + plugins/discord/server/auth/discord.ts | 2 + plugins/google/server/auth/google.ts | 3 +- plugins/oidc/server/auth/oidcRouter.ts | 3 +- plugins/slack/server/auth/slack.ts | 3 +- server/commands/accountProvisioner.ts | 2 +- server/middlewares/passport.ts | 8 +- server/routes/api/documents/documents.ts | 4 +- server/routes/api/emojis/emojis.ts | 2 +- server/routes/api/shares/shares.ts | 6 +- server/routes/api/urls/urls.ts | 2 +- server/routes/app.ts | 2 +- server/routes/index.ts | 4 +- server/routes/oauth/index.ts | 2 +- server/types.ts | 3 + server/utils/oauthState.test.ts | 101 +++++++ server/utils/oauthState.ts | 177 ++++++++++++ server/utils/passport.ts | 325 ++++++++++++++++------- 18 files changed, 541 insertions(+), 110 deletions(-) create mode 100644 server/utils/oauthState.test.ts create mode 100644 server/utils/oauthState.ts diff --git a/plugins/azure/server/auth/azure.ts b/plugins/azure/server/auth/azure.ts index 1ab4ee0f4c..6b0deee13d 100644 --- a/plugins/azure/server/auth/azure.ts +++ b/plugins/azure/server/auth/azure.ts @@ -17,6 +17,7 @@ import { getTeamFromContext, getClientFromOAuthState, getUserFromOAuthState, + startOAuthFlow, } from "@server/utils/passport"; import config from "../../plugin.json"; import env from "../env"; @@ -143,6 +144,7 @@ if (env.AZURE_CLIENT_ID && env.AZURE_CLIENT_SECRET) { passport.use(strategy); router.get( config.id, + startOAuthFlow, passport.authenticate(config.id, { prompt: "select_account" }) ); router.get(`${config.id}.callback`, passportMiddleware(config.id)); diff --git a/plugins/discord/server/auth/discord.ts b/plugins/discord/server/auth/discord.ts index 78e34db1fa..00d940243e 100644 --- a/plugins/discord/server/auth/discord.ts +++ b/plugins/discord/server/auth/discord.ts @@ -24,6 +24,7 @@ import { getClientFromOAuthState, getUserFromOAuthState, request, + startOAuthFlow, } from "@server/utils/passport"; import config from "../../plugin.json"; import env from "../env"; @@ -226,6 +227,7 @@ if (env.DISCORD_CLIENT_ID && env.DISCORD_CLIENT_SECRET) { router.get( config.id, + startOAuthFlow, passport.authenticate(config.id, { scope, }) diff --git a/plugins/google/server/auth/google.ts b/plugins/google/server/auth/google.ts index b030950373..dbf241a8fa 100644 --- a/plugins/google/server/auth/google.ts +++ b/plugins/google/server/auth/google.ts @@ -19,6 +19,7 @@ import { getTeamFromContext, getClientFromOAuthState, getUserFromOAuthState, + startOAuthFlow, } from "@server/utils/passport"; import config from "../../plugin.json"; import env from "../env"; @@ -151,7 +152,7 @@ if (env.GOOGLE_CLIENT_ID && env.GOOGLE_CLIENT_SECRET) { ) ); - router.get(config.id, async (ctx, next) => { + router.get(config.id, startOAuthFlow, async (ctx, next) => { const team = await getTeamFromContext(ctx, { includeHostQueryParam: true, }); diff --git a/plugins/oidc/server/auth/oidcRouter.ts b/plugins/oidc/server/auth/oidcRouter.ts index 89eaa00cb8..4d8a98ca11 100644 --- a/plugins/oidc/server/auth/oidcRouter.ts +++ b/plugins/oidc/server/auth/oidcRouter.ts @@ -22,6 +22,7 @@ import { getClientFromOAuthState, getUserFromOAuthState, request, + startOAuthFlow, } from "@server/utils/passport"; import config from "../../plugin.json"; import env from "../env"; @@ -227,7 +228,7 @@ export function createOIDCRouter( ) ); - router.get(config.id, passport.authenticate(config.id)); + router.get(config.id, startOAuthFlow, passport.authenticate(config.id)); router.get(`${config.id}.callback`, passportMiddleware(config.id)); router.post(`${config.id}.callback`, passportMiddleware(config.id)); } diff --git a/plugins/slack/server/auth/slack.ts b/plugins/slack/server/auth/slack.ts index 4e5b12ba9f..6906702fe2 100644 --- a/plugins/slack/server/auth/slack.ts +++ b/plugins/slack/server/auth/slack.ts @@ -25,6 +25,7 @@ import { getTeamFromContext, getUserFromOAuthState, StateStore, + startOAuthFlow, } from "@server/utils/passport"; import { parseEmail } from "@shared/utils/email"; import env from "../env"; @@ -134,7 +135,7 @@ if (env.SLACK_CLIENT_ID && env.SLACK_CLIENT_SECRET) { strategy.name = providerName; passport.use(strategy); - router.get("slack", passport.authenticate(providerName)); + router.get("slack", startOAuthFlow, passport.authenticate(providerName)); router.get("slack.callback", passportMiddleware(providerName)); router.get( diff --git a/server/commands/accountProvisioner.ts b/server/commands/accountProvisioner.ts index 448c9730a0..5992b0cc33 100644 --- a/server/commands/accountProvisioner.ts +++ b/server/commands/accountProvisioner.ts @@ -99,7 +99,7 @@ async function accountProvisioner( const actor = ctx.state.auth?.user; // If the user is already logged in and is an admin of the team then we - // allow them to connect a new authentication provider + // allow them to connect a new authentication provider. if (actor && actor.teamId === teamParams.teamId && actor.isAdmin) { const team = actor.team; const authenticationProvider = await AuthenticationProvider.findOne({ diff --git a/server/middlewares/passport.ts b/server/middlewares/passport.ts index f492701a7a..7bcb517f84 100644 --- a/server/middlewares/passport.ts +++ b/server/middlewares/passport.ts @@ -71,15 +71,17 @@ export default function createMiddleware(providerName: string) { // same domain or subdomain that they originated from (found in state). // get original host - const stateString = ctx.cookies.get("state"); + const stateString = + typeof ctx.query.state === "string" ? ctx.query.state : undefined; const state = stateString ? parseState(stateString) : undefined; + const oauthState = ctx.state.oauthState ?? state; // form a URL object with the err.redirectPath and replace the host const reqProtocol = - state?.client === Client.Desktop ? "outline" : ctx.protocol; + oauthState?.client === Client.Desktop ? "outline" : ctx.protocol; const requestHost = await getValidatedHost( - state?.host ?? ctx.hostname + oauthState?.host ?? ctx.hostname ); const url = new URL( env.isCloudHosted diff --git a/server/routes/api/documents/documents.ts b/server/routes/api/documents/documents.ts index e8f3cac070..8add886761 100644 --- a/server/routes/api/documents/documents.ts +++ b/server/routes/api/documents/documents.ts @@ -549,7 +549,7 @@ router.post( const { user } = ctx.state.auth; const apiVersion = getAPIVersion(ctx); const teamFromCtx = await getTeamFromContext(ctx, { - includeStateCookie: false, + includeOAuthState: false, }); let document: Document | null; @@ -1103,7 +1103,7 @@ router.post( if (shareId) { const teamFromCtx = await getTeamFromContext(ctx, { - includeStateCookie: false, + includeOAuthState: false, }); const result = await loadPublicShare({ id: shareId, diff --git a/server/routes/api/emojis/emojis.ts b/server/routes/api/emojis/emojis.ts index d874c573be..edbded9f73 100644 --- a/server/routes/api/emojis/emojis.ts +++ b/server/routes/api/emojis/emojis.ts @@ -85,7 +85,7 @@ router.get( if (shareId) { const teamFromCtx = await getTeamFromContext(ctx, { - includeStateCookie: false, + includeOAuthState: false, }); const { share } = await loadPublicShare({ id: shareId, diff --git a/server/routes/api/shares/shares.ts b/server/routes/api/shares/shares.ts index a5ec7abbd7..c4a3e94194 100644 --- a/server/routes/api/shares/shares.ts +++ b/server/routes/api/shares/shares.ts @@ -56,7 +56,7 @@ router.post( const { id, collectionId, documentId } = ctx.input.body; const { user } = ctx.state.auth; const teamFromCtx = await getTeamFromContext(ctx, { - includeStateCookie: false, + includeOAuthState: false, }); // only public link loads will send "id". @@ -440,7 +440,7 @@ router.get( validate(T.SharesSitemapSchema), async (ctx: APIContext) => { const { id } = ctx.input.query; - const team = await getTeamFromContext(ctx, { includeStateCookie: false }); + const team = await getTeamFromContext(ctx, { includeOAuthState: false }); const { share, sharedTree } = await loadPublicShare({ id, @@ -473,7 +473,7 @@ router.post( const { shareId, documentId, email } = ctx.input.body; const { transaction } = ctx.state; - const team = await getTeamFromContext(ctx, { includeStateCookie: false }); + const team = await getTeamFromContext(ctx, { includeOAuthState: false }); // Validate the share exists and is published const { share, document } = await loadPublicShare({ diff --git a/server/routes/api/urls/urls.ts b/server/routes/api/urls/urls.ts index df8ed9cb81..93b5ead181 100644 --- a/server/routes/api/urls/urls.ts +++ b/server/routes/api/urls/urls.ts @@ -55,7 +55,7 @@ router.post( let teamId: string | undefined = actor?.teamId; if (!teamId && !isUUID(shareId)) { const teamFromCtx = await getTeamFromContext(ctx, { - includeStateCookie: false, + includeOAuthState: false, }); teamId = teamFromCtx?.id; } diff --git a/server/routes/app.ts b/server/routes/app.ts index 510dc59cb3..0c03789dc0 100644 --- a/server/routes/app.ts +++ b/server/routes/app.ts @@ -215,7 +215,7 @@ export const renderShare = async (ctx: Context, next: Next) => { let sharedTree; try { - team = await getTeamFromContext(ctx, { includeStateCookie: false }); + team = await getTeamFromContext(ctx, { includeOAuthState: false }); const result = await loadPublicShare({ id: shareId, collectionId: collectionSlug, diff --git a/server/routes/index.ts b/server/routes/index.ts index 7bf8bf6c13..4beb87e3fb 100644 --- a/server/routes/index.ts +++ b/server/routes/index.ts @@ -124,7 +124,7 @@ router.get( const origin = env.isCloudHosted ? ctx.request.URL.origin : new URL(env.URL).origin; - const team = await getTeamFromContext(ctx, { includeStateCookie: false }); + const team = await getTeamFromContext(ctx, { includeOAuthState: false }); const mcpEnabled = team?.getPreference(TeamPreference.MCP) ?? true; ctx.body = { @@ -151,7 +151,7 @@ router.get( "/.well-known/oauth-protected-resource/mcp", ], async (ctx) => { - const team = await getTeamFromContext(ctx, { includeStateCookie: false }); + const team = await getTeamFromContext(ctx, { includeOAuthState: false }); const mcpEnabled = team?.getPreference(TeamPreference.MCP) ?? true; if (!mcpEnabled) { diff --git a/server/routes/oauth/index.ts b/server/routes/oauth/index.ts index 84225be7ce..7151ec5e45 100644 --- a/server/routes/oauth/index.ts +++ b/server/routes/oauth/index.ts @@ -194,7 +194,7 @@ router.post( } = ctx.input.body; const team = await getTeamFromContext(ctx, { - includeStateCookie: false, + includeOAuthState: false, }); if (!team) { throw NotFoundError(); diff --git a/server/types.ts b/server/types.ts index 5d419dc290..bb14d03d4e 100644 --- a/server/types.ts +++ b/server/types.ts @@ -14,6 +14,7 @@ import type { } from "@shared/types"; import type { BaseSchema } from "@server/routes/api/schema"; import type { AccountProvisionerResult } from "./commands/accountProvisioner"; +import type { OAuthIntent, OAuthState } from "./utils/oauthState"; import type { AccessRequest, ApiKey, @@ -77,6 +78,8 @@ export type AppState = { transaction: Transaction; pagination: Pagination; oauthClient?: OAuthClient; + oauthIntent?: OAuthIntent; + oauthState?: OAuthState; }; export type AppContext = ParameterizedContext; diff --git a/server/utils/oauthState.test.ts b/server/utils/oauthState.test.ts new file mode 100644 index 0000000000..3d3ef55270 --- /dev/null +++ b/server/utils/oauthState.test.ts @@ -0,0 +1,101 @@ +import { Client } from "@shared/types"; +import env from "@server/env"; +import { + hashOAuthStateNonce, + signOAuthIntent, + signOAuthState, + verifyOAuthIntent, + verifyOAuthState, +} from "./oauthState"; + +describe("oauthState", () => { + const originalSecretKey = env.SECRET_KEY; + + afterEach(() => { + env.SECRET_KEY = originalSecretKey; + }); + + it("round-trips a signed OAuth intent", () => { + const token = signOAuthIntent({ + host: "docs.example.com", + actorId: "user-id", + actorSessionHash: "session-hash", + client: Client.Web, + }); + + const payload = verifyOAuthIntent(token); + + expect(payload.host).toBe("docs.example.com"); + expect(payload.actorId).toBe("user-id"); + expect(payload.actorSessionHash).toBe("session-hash"); + expect(payload.client).toBe(Client.Web); + expect(payload.type).toBe("oauth_intent"); + expect(payload.exp).toBeGreaterThan(payload.iat); + }); + + it("round-trips a signed OAuth state", () => { + const token = signOAuthState({ + host: "team.outline.dev", + actorId: "user-id", + actorSessionHash: "session-hash", + client: Client.Desktop, + codeVerifier: "pkce-verifier", + nonceHash: hashOAuthStateNonce("csrf-nonce"), + }); + + const payload = verifyOAuthState(token); + + expect(payload.host).toBe("team.outline.dev"); + expect(payload.actorId).toBe("user-id"); + expect(payload.actorSessionHash).toBe("session-hash"); + expect(payload.client).toBe(Client.Desktop); + expect(payload.type).toBe("oauth_state"); + expect(payload.codeVerifier).toBe("pkce-verifier"); + expect(payload.nonceHash).toBe(hashOAuthStateNonce("csrf-nonce")); + }); + + it("rejects a signed OAuth state as an OAuth intent", () => { + const token = signOAuthState({ + host: "team.outline.dev", + actorId: "user-id", + client: Client.Web, + nonceHash: hashOAuthStateNonce("csrf-nonce"), + }); + + expect(() => verifyOAuthIntent(token)).toThrow("Invalid OAuth intent"); + }); + + it("rejects a signed OAuth intent as an OAuth state", () => { + const token = signOAuthIntent({ + host: "docs.example.com", + actorId: "user-id", + client: Client.Web, + }); + + expect(() => verifyOAuthState(token)).toThrow("Invalid OAuth state"); + }); + + it("rejects a tampered token", () => { + const token = signOAuthState({ + host: "team.outline.dev", + client: Client.Web, + nonceHash: hashOAuthStateNonce("csrf-nonce"), + }); + const tamperedToken = `${token}tampered`; + + expect(() => verifyOAuthState(tamperedToken)).toThrow( + "Invalid OAuth state" + ); + }); + + it("rejects tokens signed with another secret", () => { + const token = signOAuthIntent({ + host: "docs.example.com", + client: Client.Web, + }); + + env.SECRET_KEY = "1".repeat(64); + + expect(() => verifyOAuthIntent(token)).toThrow("Invalid OAuth state"); + }); +}); diff --git a/server/utils/oauthState.ts b/server/utils/oauthState.ts new file mode 100644 index 0000000000..0fb167becd --- /dev/null +++ b/server/utils/oauthState.ts @@ -0,0 +1,177 @@ +import JWT from "jsonwebtoken"; +import { Client } from "@shared/types"; +import env from "@server/env"; +import { OAuthStateMismatchError } from "@server/errors"; +import { hash } from "./crypto"; + +const Algorithm = "HS256"; +const ExpiresInSeconds = 10 * 60; +const IntentType = "oauth_intent"; +const StateType = "oauth_state"; + +interface OAuthIntentInput { + host: string; + actorId?: string; + actorSessionHash?: string; + client: Client; +} + +interface OAuthStateInput extends OAuthIntentInput { + codeVerifier?: string; + nonceHash: string; +} + +interface OAuthIntentClaims extends OAuthIntentInput { + type: typeof IntentType; +} + +interface OAuthStateClaims extends OAuthStateInput { + type: typeof StateType; +} + +export interface OAuthIntent extends OAuthIntentClaims { + iat: number; + exp: number; +} + +export interface OAuthState extends OAuthStateClaims { + iat: number; + exp: number; +} + +/** + * Hashes an OAuth CSRF nonce for storage in signed OAuth state. + * + * @param nonce the nonce stored in the browser cookie. + * @returns the sha256 hash of the nonce. + */ +export function hashOAuthStateNonce(nonce: string): string { + return hash(nonce); +} + +/** + * Creates a short-lived signed OAuth intent token. + * + * @param payload the intent values to sign. + * @returns the signed intent token. + */ +export function signOAuthIntent(payload: OAuthIntentInput): string { + return sign({ + ...payload, + type: IntentType, + }); +} + +/** + * Verifies a signed OAuth intent token. + * + * @param token the token to verify. + * @returns the verified intent payload. + * @throws {OAuthStateMismatchError} if the token is missing, expired, invalid, + * or has an unexpected payload shape. + */ +export function verifyOAuthIntent(token: string): OAuthIntent { + const payload = verify(token); + + if (!isOAuthIntent(payload)) { + throw OAuthStateMismatchError("Invalid OAuth intent"); + } + + return payload; +} + +/** + * Creates a short-lived signed OAuth state token. + * + * @param payload the OAuth state values to sign. + * @returns the signed OAuth state token. + */ +export function signOAuthState(payload: OAuthStateInput): string { + return sign({ + ...payload, + type: StateType, + }); +} + +/** + * Verifies a signed OAuth state token. + * + * @param token the token to verify. + * @returns the verified OAuth state payload. + * @throws {OAuthStateMismatchError} if the token is missing, expired, invalid, + * or has an unexpected payload shape. + */ +export function verifyOAuthState(token: string): OAuthState { + const payload = verify(token); + + if (!isOAuthState(payload)) { + throw OAuthStateMismatchError("Invalid OAuth state"); + } + + return payload; +} + +function sign(payload: OAuthIntentClaims | OAuthStateClaims): string { + return JWT.sign(payload, env.SECRET_KEY, { + algorithm: Algorithm, + expiresIn: ExpiresInSeconds, + }); +} + +function verify(token: string): JWT.JwtPayload { + try { + const payload = JWT.verify(token, env.SECRET_KEY, { + algorithms: [Algorithm], + }); + + if (typeof payload === "string") { + throw OAuthStateMismatchError("Invalid OAuth state"); + } + + return payload; + } catch (err) { + if (err instanceof Error && err.name === "TokenExpiredError") { + throw OAuthStateMismatchError("Expired OAuth state"); + } + + throw OAuthStateMismatchError("Invalid OAuth state"); + } +} + +function isOAuthIntent(payload: JWT.JwtPayload): payload is OAuthIntent { + return ( + payload.type === IntentType && + typeof payload.host === "string" && + isClient(payload.client) && + isOptionalString(payload.actorId) && + isOptionalString(payload.actorSessionHash) && + payload.nonceHash === undefined && + payload.codeVerifier === undefined && + typeof payload.iat === "number" && + typeof payload.exp === "number" + ); +} + +function isOAuthState(payload: JWT.JwtPayload): payload is OAuthState { + return ( + payload.type === StateType && + typeof payload.host === "string" && + isClient(payload.client) && + isOptionalString(payload.actorId) && + isOptionalString(payload.actorSessionHash) && + typeof payload.iat === "number" && + typeof payload.exp === "number" && + typeof payload.nonceHash === "string" && + isOptionalString(payload.codeVerifier) + ); +} + +function isClient(value: string | undefined): value is Client { + return value === Client.Desktop || value === Client.Web; +} + +function isOptionalString( + value: string | undefined +): value is string | undefined { + return value === undefined || typeof value === "string"; +} diff --git a/server/utils/passport.ts b/server/utils/passport.ts index cd1a15c874..117a1d684b 100644 --- a/server/utils/passport.ts +++ b/server/utils/passport.ts @@ -1,6 +1,6 @@ import crypto from "node:crypto"; import { addMinutes, subMinutes } from "date-fns"; -import type { Context } from "koa"; +import type { Context, Next } from "koa"; import type { StateStoreStoreCallback, StateStoreVerifyCallback, @@ -9,12 +9,91 @@ import type { Primitive } from "utility-types"; import { Client } from "@shared/types"; import { getCookieDomain, parseDomain } from "@shared/utils/domains"; import env from "@server/env"; -import { Team } from "@server/models"; +import { Team, User } from "@server/models"; +import Redis from "@server/storage/redis"; import { InternalError, OAuthStateMismatchError } from "../errors"; -import { safeEqual } from "./crypto"; +import { hash, safeEqual } from "./crypto"; import fetch from "./fetch"; import { getUserForJWT } from "./jwt"; +import { + hashOAuthStateNonce, + signOAuthIntent, + signOAuthState, + verifyOAuthIntent, + verifyOAuthState, +} from "./oauthState"; +const FLOW_QUERY_PARAM = "flow"; +const OAUTH_CSRF_COOKIE = "oauth_csrf"; +const OAUTH_INTENT_PREFIX = "oauth:intent:"; +const OAUTH_INTENT_TTL_SECONDS = 10 * 60; +const ACTOR_SESSION_HASH_KEYLEN = 64; + +/** + * Middleware for OAuth start routes that bridges cookie scopes between custom + * team domains and the apex (env.URL) where the OAuth callback always lands. + * + * The OAuth callback always lands on the apex domain, while a user's + * `accessToken` session cookie may be host-scoped to a custom team domain. To + * make the "connect a new auth provider while signed in" flow work from a + * custom domain: + * + * 1. On a custom team domain — create a short-lived signed intent containing + * the original host and actor id, then bounce to the apex with it. + * 2. On the apex — verify the signed intent and stash it on `ctx.state` so + * `StateStore.store` can fold it into the signed OAuth `state` parameter. + * + * Non-custom team subdomains skip the bounce because the start route can read + * the host-scoped session and set the OAuth CSRF cookie on the base domain for + * the apex callback. Self-hosted deployments have a single domain and pass + * through. + */ +export async function startOAuthFlow(ctx: Context, next: Next) { + if (!env.isCloudHosted) { + return next(); + } + + const apex = new URL(env.URL); + const onApex = ctx.hostname === apex.hostname; + const isCustom = parseDomain(ctx.hostname).custom; + + if (isCustom && !onApex) { + const url = new URL(ctx.originalUrl, apex); + const client = getClientFromInput(ctx); + const actor = await getOAuthActor(ctx); + const flow = signOAuthIntent({ + host: ctx.hostname, + actorId: actor?.id, + actorSessionHash: actor ? getActorSessionHash(actor) : undefined, + client, + }); + + url.searchParams.delete(FLOW_QUERY_PARAM); + url.searchParams.set(FLOW_QUERY_PARAM, flow); + await storeOAuthIntent(flow); + + return ctx.redirect(url.toString()); + } + + const flow = ctx.query[FLOW_QUERY_PARAM]; + if (onApex && typeof flow === "string" && flow) { + try { + const intent = verifyOAuthIntent(flow); + if (await consumeOAuthIntent(flow)) { + ctx.state.oauthIntent = intent; + } + } catch { + // Invalid or expired intent — proceed without an actor. + // The user can still complete the OAuth flow as a fresh sign-in. + } + } + + return next(); +} + +/** + * Passport OAuth state store backed by signed state and a CSRF nonce cookie. + */ export class StateStore { constructor(private pkce = false) {} @@ -27,8 +106,8 @@ export class StateStore { _meta?: unknown, cb?: StateStoreStoreCallback ) => { - // token is a short lived one-time pad to prevent replay attacks - const token = crypto.randomBytes(8).toString("hex"); + const context = getKoaContext(ctx); + const csrfNonce = crypto.randomBytes(16).toString("hex"); // Note parameters are based on whether PKCE is in use or not, this is parameters // of how the underlying library is architected, see: @@ -44,24 +123,35 @@ export class StateStore { // We expect host to be a team subdomain, custom domain, or apex domain // that is passed via query param from the auth provider component. - const clientInput = ctx.query.client?.toString(); - const client = clientInput === Client.Desktop ? Client.Desktop : Client.Web; - const host = ctx.query.host?.toString() || parseDomain(ctx.hostname).host; - const accessToken = ctx.cookies.get("accessToken"); - const state = buildState({ + const client = + context.state.oauthIntent?.client ?? getClientFromInput(context); + const host = + context.state.oauthIntent?.host ?? + context.query.host?.toString() ?? + parseDomain(context.hostname).host; + const actorId = + context.state.oauthIntent?.actorId ?? getAuthenticatedUserId(context); + const actorSessionHash = + context.state.oauthIntent?.actorSessionHash ?? + getAuthenticatedUserSessionHash(context); + const state = signOAuthState({ host, - token, + actorId, + actorSessionHash, client, codeVerifier, - accessToken, + nonceHash: hashOAuthStateNonce(csrfNonce), }); - ctx.cookies.set(this.key, state, { + context.cookies.set(OAUTH_CSRF_COOKIE, csrfNonce, { + httpOnly: true, + sameSite: "lax", + secure: env.isProduction, expires: addMinutes(new Date(), 10), - domain: getCookieDomain(ctx.hostname, env.isCloudHosted), + domain: getCookieDomain(context.hostname, env.isCloudHosted), }); - callback(null, token); + callback(null, state); }; verify = ( @@ -69,34 +159,35 @@ export class StateStore { providedToken: string, callback: StateStoreVerifyCallback ) => { - const state = ctx.cookies.get(this.key); - - if (!state) { - return callback( - OAuthStateMismatchError("No state was available after OAuth flow"), - false, - state - ); - } - - const { token, codeVerifier } = parseState(state); - - // Destroy the one-time pad token and ensure it matches - ctx.cookies.set(this.key, "", { + const context = getKoaContext(ctx); + const csrfNonce = context.cookies.get(OAUTH_CSRF_COOKIE); + context.cookies.set(OAUTH_CSRF_COOKIE, "", { + httpOnly: true, + sameSite: "lax", + secure: env.isProduction, expires: subMinutes(new Date(), 1), - domain: getCookieDomain(ctx.hostname, env.isCloudHosted), + domain: getCookieDomain(context.hostname, env.isCloudHosted), }); - if (!safeEqual(token, providedToken)) { + let state; + try { + state = verifyOAuthState(providedToken); + } catch (err) { + return callback(err, false, providedToken); + } + + if (!safeEqual(hashOAuthStateNonce(csrfNonce ?? ""), state.nonceHash)) { return callback( - OAuthStateMismatchError("Token in state mismatched"), + OAuthStateMismatchError("OAuth CSRF nonce mismatched"), false, - token + providedToken ); } + context.state.oauthState = state; + // @ts-expect-error Type in library is wrong - callback(null, codeVerifier ?? true, state); + callback(null, state.codeVerifier ?? true, providedToken); }; } @@ -124,34 +215,18 @@ export async function request( } } -function buildState({ - host, - token, - client, - codeVerifier, - accessToken, -}: { - host: string; - token: string; - client?: Client; - codeVerifier?: string; - accessToken?: string; -}) { - return [host, token, client, codeVerifier, accessToken].join("|"); -} - /** * Parses the state string into its components. * * @param state The state string - * @returns An object containing the parsed components + * @returns An object containing the parsed components, if valid. */ export function parseState(state: string) { - const [host, token, client, rawCodeVerifier, rawAccessToken] = - state.split("|"); - const codeVerifier = rawCodeVerifier ? rawCodeVerifier : undefined; - const accessToken = rawAccessToken ? rawAccessToken : undefined; - return { host, token, client, codeVerifier, accessToken }; + try { + return verifyOAuthState(state); + } catch { + return undefined; + } } /** @@ -162,52 +237,45 @@ export function parseState(state: string) { * @returns The client type, defaults to Client.Web */ export function getClientFromOAuthState(ctx: Context): Client { - const state = ctx.cookies.get("state"); - const client = state ? parseState(state).client : undefined; + const context = getKoaContext(ctx); + const client = context.state.oauthState?.client; return client === Client.Desktop ? Client.Desktop : Client.Web; } /** - * Returns the access token from the context if available. This is used - * to restore the session during the OAuth flow when connecting additional - * providers to an existing team. + * Returns the actor referenced by verified OAuth state, if available. This is + * used to restore the originating user during the OAuth flow when connecting + * additional providers to an existing team. * * @param ctx The Koa context - * @returns The access token if available, otherwise undefined - */ -export function getAccessTokenFromOAuthState(ctx: Context): string | undefined { - const state = ctx.cookies.get("state"); - return state ? parseState(state).accessToken : undefined; -} - -/** - * Returns the user from the context if they are authenticated. This is used - * to restore the session during the OAuth flow. - * - * @param ctx The Koa context - * @returns The user if authenticated, otherwise undefined + * @returns The actor if available, otherwise undefined */ export async function getUserFromOAuthState(ctx: Context) { - const token = getAccessTokenFromOAuthState(ctx); - if (!token) { + const context = getKoaContext(ctx); + const state = context.state.oauthState; + if (!state?.actorId || !state.actorSessionHash) { return undefined; } - try { - const { user } = await getUserForJWT(token); - return user; - } catch (_err) { + const user = await User.scope("withTeam").findByPk(state.actorId); + if (!user) { return undefined; } + + if (!safeEqual(getActorSessionHash(user), state.actorSessionHash)) { + return undefined; + } + + return user; } type TeamFromContextOptions = { /** - * Whether to consider the state cookie in the context when determining the team. - * If true, the state cookie will be parsed to determine the host and infer the team + * Whether to consider OAuth state in the context when determining the team. + * If true, OAuth state will be used to determine the host and infer the team * this should only be used in the authentication process. */ - includeStateCookie?: boolean; + includeOAuthState?: boolean; /** * Whether to consider the host query parameter in the context when determining the team. * If true, the host query parameter will be used to determine the host and infer the team @@ -216,7 +284,7 @@ type TeamFromContextOptions = { }; /** - * Infers the team from the context based on the hostname or state cookie. + * Infers the team from the context based on the hostname or OAuth state. * * @param ctx The Koa context * @param options Options for determining the team @@ -224,18 +292,20 @@ type TeamFromContextOptions = { */ export async function getTeamFromContext( ctx: Context, - options: TeamFromContextOptions = { includeStateCookie: true } + options: TeamFromContextOptions = { includeOAuthState: true } ) { + const context = getKoaContext(ctx); // "domain" is the domain the user came from when attempting auth // we use it to infer the team they intend on signing into - const state = options.includeStateCookie - ? ctx.cookies.get("state") + const includeOAuthState = options.includeOAuthState ?? true; + const state = includeOAuthState + ? (context.state.oauthState ?? context.state.oauthIntent) : undefined; const queryHost = - options.includeHostQueryParam && typeof ctx.query.host === "string" - ? ctx.query.host + options.includeHostQueryParam && typeof context.query.host === "string" + ? context.query.host : undefined; - const host = (state ? parseState(state).host : queryHost) || ctx.hostname; + const host = state?.host ?? queryHost ?? context.hostname; const domain = parseDomain(host); let team; @@ -247,8 +317,8 @@ export async function getTeamFromContext( order: [["createdAt", "DESC"]], }); } - } else if (ctx.state?.rootShare) { - team = await Team.findByPk(ctx.state.rootShare.teamId); + } else if (context.state?.rootShare) { + team = await Team.findByPk(context.state.rootShare.teamId); } else if (domain.custom) { team = await Team.findByDomain(domain.host); } else if (domain.teamSubdomain) { @@ -257,3 +327,74 @@ export async function getTeamFromContext( return team; } + +function getClientFromInput(ctx: Context): Client { + const clientInput = ctx.query.client?.toString(); + return clientInput === Client.Desktop ? Client.Desktop : Client.Web; +} + +function getAuthenticatedUser(ctx: Context): User | undefined { + return ctx.state.auth && "user" in ctx.state.auth + ? ctx.state.auth.user + : undefined; +} + +function getAuthenticatedUserId(ctx: Context): string | undefined { + return getAuthenticatedUser(ctx)?.id; +} + +function getAuthenticatedUserSessionHash(ctx: Context): string | undefined { + const user = getAuthenticatedUser(ctx); + return user ? getActorSessionHash(user) : undefined; +} + +async function getOAuthActor(ctx: Context): Promise { + const authenticatedUser = getAuthenticatedUser(ctx); + if (authenticatedUser) { + return authenticatedUser; + } + + const accessToken = ctx.cookies.get("accessToken"); + if (!accessToken) { + return undefined; + } + + try { + const { user } = await getUserForJWT(accessToken); + return user; + } catch { + return undefined; + } +} + +function getActorSessionHash(user: User): string { + return crypto + .scryptSync( + user.jwtSecret, + `oauth-actor-session:${env.SECRET_KEY}:${user.id}`, + ACTOR_SESSION_HASH_KEYLEN + ) + .toString("hex"); +} + +async function storeOAuthIntent(token: string): Promise { + await Redis.defaultClient.set( + getOAuthIntentKey(token), + "1", + "EX", + OAUTH_INTENT_TTL_SECONDS + ); +} + +async function consumeOAuthIntent(token: string): Promise { + const result = await Redis.defaultClient.getdel(getOAuthIntentKey(token)); + return result === "1"; +} + +function getOAuthIntentKey(token: string): string { + return `${OAUTH_INTENT_PREFIX}${hash(token)}`; +} + +function getKoaContext(ctx: Context): Context { + return (ctx as Context & { ctx?: Context }).ctx ?? ctx; +}