diff --git a/plugins/email/server/auth/email.ts b/plugins/email/server/auth/email.ts index e329b180d8..2e7242962f 100644 --- a/plugins/email/server/auth/email.ts +++ b/plugins/email/server/auth/email.ts @@ -34,7 +34,7 @@ router.post( team = await Team.scope("withAuthenticationProviders").findOne(); } else if (domain.custom) { team = await Team.scope("withAuthenticationProviders").findOne({ - where: { domain: domain.host }, + where: { domain: domain.host.toLowerCase() }, }); } else if (domain.teamSubdomain) { team = await Team.scope("withAuthenticationProviders").findOne({ diff --git a/server/middlewares/passport.ts b/server/middlewares/passport.ts index 6f1b18be07..f492701a7a 100644 --- a/server/middlewares/passport.ts +++ b/server/middlewares/passport.ts @@ -2,13 +2,51 @@ import passport from "@outlinewiki/koa-passport"; import type { Context } from "koa"; import { InternalOAuthError } from "passport-oauth2"; import { Client } from "@shared/types"; +import { parseDomain } from "@shared/utils/domains"; import env from "@server/env"; -import { AuthenticationError, OAuthStateMismatchError } from "@server/errors"; +import { AuthenticationError } from "@server/errors"; import Logger from "@server/logging/Logger"; +import { Team } from "@server/models"; import type { AuthenticationResult } from "@server/types"; import { signIn } from "@server/utils/authentication"; import { parseState } from "@server/utils/passport"; +/** + * Validates that a host from the OAuth state is a trusted domain. For + * cloud-hosted deployments, ensures the host is either a known subdomain of + * the base domain or a registered custom domain. + * + * @param host the host to validate. + * @returns the host if trusted, otherwise falls back to the base domain from env.URL. + */ +async function getValidatedHost(host: string): Promise { + const fallback = new URL(env.URL).host; + + if (!env.isCloudHosted) { + return host; + } + + if (!host) { + return fallback; + } + + const domain = parseDomain(host); + + // Subdomains of the base domain are trusted + if (!domain.custom) { + return domain.host; + } + + // Custom domains must be registered to a team + const team = await Team.findByDomain(domain.host); + if (team) { + return domain.host; + } + + // Unrecognized host, fall back to the base app URL + return fallback; +} + export default function createMiddleware(providerName: string) { return function passportMiddleware(ctx: Context) { return passport.authorize( @@ -40,11 +78,9 @@ export default function createMiddleware(providerName: string) { const reqProtocol = state?.client === Client.Desktop ? "outline" : ctx.protocol; - // `state.host` cannot be trusted if the error is a state mismatch, use `ctx.hostname` - const requestHost = - err instanceof OAuthStateMismatchError - ? ctx.hostname - : (state?.host ?? ctx.hostname); + const requestHost = await getValidatedHost( + state?.host ?? ctx.hostname + ); const url = new URL( env.isCloudHosted ? `${reqProtocol}://${requestHost}${redirectPath}` diff --git a/server/models/Team.test.ts b/server/models/Team.test.ts index 9ece16cbe2..cbc4c1e1f0 100644 --- a/server/models/Team.test.ts +++ b/server/models/Team.test.ts @@ -1,7 +1,54 @@ import { randomUUID } from "node:crypto"; -import { buildTeam, buildCollection, buildAttachment } from "@server/test/factories"; +import { Team } from "@server/models"; +import { + buildTeam, + buildCollection, + buildAttachment, +} from "@server/test/factories"; describe("Team", () => { + describe("findByDomain", () => { + it("should find a team by its domain", async () => { + const domain = `${randomUUID()}.example.com`; + const team = await buildTeam({ domain }); + const result = await Team.findByDomain(domain); + expect(result?.id).toEqual(team.id); + }); + + it("should normalize domain to lowercase", async () => { + const id = randomUUID(); + const team = await buildTeam({ domain: `${id}.example.com` }); + const result = await Team.findByDomain(`${id}.Example.COM`); + expect(result?.id).toEqual(team.id); + }); + + it("should strip protocol from input", async () => { + const domain = `${randomUUID()}.example.com`; + const team = await buildTeam({ domain }); + const result = await Team.findByDomain(`https://${domain}`); + expect(result?.id).toEqual(team.id); + }); + + it("should strip port from input", async () => { + const domain = `${randomUUID()}.example.com`; + const team = await buildTeam({ domain }); + const result = await Team.findByDomain(`${domain}:3000`); + expect(result?.id).toEqual(team.id); + }); + + it("should strip path from input", async () => { + const domain = `${randomUUID()}.example.com`; + const team = await buildTeam({ domain }); + const result = await Team.findByDomain(`${domain}/some/path`); + expect(result?.id).toEqual(team.id); + }); + + it("should return null for unregistered domain", async () => { + const result = await Team.findByDomain("unknown.example.com"); + expect(result).toBeNull(); + }); + }); + describe("collectionIds", () => { it("should return non-private collection ids", async () => { const team = await buildTeam(); diff --git a/server/models/Team.ts b/server/models/Team.ts index 6fa258cab5..8d4d2f32f9 100644 --- a/server/models/Team.ts +++ b/server/models/Team.ts @@ -2,7 +2,7 @@ import crypto from "node:crypto"; import { URL } from "node:url"; import { subMinutes } from "date-fns"; import type { InferAttributes, InferCreationAttributes } from "sequelize"; -import { type SaveOptions } from "sequelize"; +import { type FindOptions, type SaveOptions } from "sequelize"; import { Op } from "sequelize"; import { Column, @@ -541,6 +541,26 @@ class Team extends ParanoidModel< } }; + /** + * Find a team by its custom domain. The input is normalized by stripping + * protocol, port, path, and lowercasing to match the stored format. + * + * @param domain the domain to search for. + * @param options additional find options to pass to the query. + * @returns the team with the given domain, or null if not found. + */ + static async findByDomain(domain: string, options?: FindOptions) { + const normalized = domain + .replace(/(https?:)?\/\//, "") + .split(/[/:?]/)[0] + .toLowerCase(); + + return this.findOne({ + ...options, + where: { ...options?.where, domain: normalized }, + }); + } + /** * Find a team by its current or previous subdomain. * diff --git a/server/routes/api/auth/auth.ts b/server/routes/api/auth/auth.ts index 47bdd45d7d..368e534829 100644 --- a/server/routes/api/auth/auth.ts +++ b/server/routes/api/auth/auth.ts @@ -55,7 +55,7 @@ router.post("auth.config", async (ctx: APIContext) => { if (domain.custom) { const team = await Team.scope("withAuthenticationProviders").findOne({ where: { - domain: ctx.request.hostname, + domain: ctx.request.hostname.toLowerCase(), }, }); diff --git a/server/routes/api/urls/urls.ts b/server/routes/api/urls/urls.ts index 72b3368eb5..cb8631ef77 100644 --- a/server/routes/api/urls/urls.ts +++ b/server/routes/api/urls/urls.ts @@ -207,11 +207,7 @@ router.post( const { hostname } = ctx.input.body; const [team, share] = await Promise.all([ - Team.findOne({ - where: { - domain: hostname, - }, - }), + Team.findByDomain(hostname), Share.findOne({ where: { domain: hostname, diff --git a/server/utils/passport.ts b/server/utils/passport.ts index 7d926ddaac..77a886155b 100644 --- a/server/utils/passport.ts +++ b/server/utils/passport.ts @@ -240,7 +240,7 @@ export async function getTeamFromContext( let team; if (!env.isCloudHosted) { if (env.ENVIRONMENT === "test") { - team = await Team.findOne({ where: { domain: env.URL } }); + team = await Team.findByDomain(env.URL); } else { team = await Team.findOne({ order: [["createdAt", "DESC"]], @@ -249,7 +249,7 @@ export async function getTeamFromContext( } else if (ctx.state?.rootShare) { team = await Team.findByPk(ctx.state.rootShare.teamId); } else if (domain.custom) { - team = await Team.findOne({ where: { domain: domain.host } }); + team = await Team.findByDomain(domain.host); } else if (domain.teamSubdomain) { team = await Team.findBySubdomain(domain.teamSubdomain); } diff --git a/shared/utils/domains.test.ts b/shared/utils/domains.test.ts index 93e5e0073c..1a9df4cf50 100644 --- a/shared/utils/domains.test.ts +++ b/shared/utils/domains.test.ts @@ -97,6 +97,24 @@ describe("#parseDomain", () => { }); }); + it("should strip userinfo before the hostname", () => { + expect(parseDomain("user:pass@example.com")).toMatchObject({ + teamSubdomain: "", + host: "example.com", + custom: false, + }); + expect(parseDomain("myteam.example.com@evil.com")).toMatchObject({ + teamSubdomain: "", + host: "evil.com", + custom: true, + }); + expect(parseDomain("https://myteam.example.com@evil.com")).toMatchObject({ + teamSubdomain: "", + host: "evil.com", + custom: true, + }); + }); + it("should recognize include private domains like blogspot.com as custom", () => { expect(parseDomain("foo.blogspot.com")).toMatchObject({ teamSubdomain: "", diff --git a/shared/utils/domains.ts b/shared/utils/domains.ts index 6312918b35..126505b7d0 100644 --- a/shared/utils/domains.ts +++ b/shared/utils/domains.ts @@ -18,10 +18,17 @@ export function slugifyDomain(domain: string) { return domain.split(".").slice(0, -1).join("-"); } -// strips protocol and whitespace from input -// then strips the path and query string +// strips protocol, userinfo, port, path, query, and whitespace from input +// to extract a clean hostname function normalizeUrl(url: string) { - return trim(url.replace(/(https?:)?\/\//, "")).split(/[/:?]/)[0]; + const stripped = trim(url.replace(/(https?:)?\/\//, "")); + // Extract authority (everything before the first slash) + const authority = stripped.split("/")[0]; + // Strip userinfo if present (e.g. "user:pass@host" → "host") + const atIndex = authority.lastIndexOf("@"); + const hostWithPort = + atIndex !== -1 ? authority.substring(atIndex + 1) : authority; + return hostWithPort.split(/[:?]/)[0]; } // The base domain is where root cookies are set in hosted mode