diff --git a/server/collaboration/AuthenticationExtension.ts b/server/collaboration/AuthenticationExtension.ts index 9ccf51d30e..d4dac54c13 100644 --- a/server/collaboration/AuthenticationExtension.ts +++ b/server/collaboration/AuthenticationExtension.ts @@ -19,7 +19,7 @@ export default class AuthenticationExtension implements Extension { throw AuthenticationError("Authentication required"); } - const user = await getUserForJWT(token, ["session", "collaboration"]); + const { user } = await getUserForJWT(token, ["session", "collaboration"]); const document = await Document.findByPk(documentId, { userId: user.id, }); diff --git a/server/middlewares/authentication.ts b/server/middlewares/authentication.ts index 0da8bb1963..7367152be2 100644 --- a/server/middlewares/authentication.ts +++ b/server/middlewares/authentication.ts @@ -37,7 +37,10 @@ type AuthInput = { export default function auth(options: AuthenticationOptions = {}) { return async function authMiddleware(ctx: AppContext, next: Next) { try { - const { type, token, user } = await validateAuthentication(ctx, options); + const { type, token, user, service } = await validateAuthentication( + ctx, + options + ); await Promise.all([ user.updateActiveAt(ctx), @@ -48,6 +51,7 @@ export default function auth(options: AuthenticationOptions = {}) { user, token, type, + service, }; if (tracer) { @@ -143,7 +147,12 @@ export function parseAuthentication(ctx: AppContext): AuthInput { async function validateAuthentication( ctx: AppContext, options: AuthenticationOptions -): Promise<{ user: User; token: string; type: AuthenticationType }> { +): Promise<{ + user: User; + token: string; + type: AuthenticationType; + service?: string; +}> { const { token, transport } = parseAuthentication(ctx); if (!token) { @@ -152,6 +161,7 @@ async function validateAuthentication( let user: User | null; let type: AuthenticationType; + let service: string | undefined; if (OAuthAuthentication.match(token)) { if (transport !== "header") { @@ -241,7 +251,9 @@ async function validateAuthentication( await apiKey.updateActiveAt(); } else { type = AuthenticationType.APP; - user = await getUserForJWT(token); + const result = await getUserForJWT(token); + user = result.user; + service = result.service; } if (user.isSuspended) { @@ -270,5 +282,6 @@ async function validateAuthentication( user, type, token, + service, }; } diff --git a/server/models/User.ts b/server/models/User.ts index f0c75b7279..cdd1016faf 100644 --- a/server/models/User.ts +++ b/server/models/User.ts @@ -579,14 +579,16 @@ class User extends ParanoidModel< * in the client browser cookies to remain logged in. * * @param expiresAt The time the token will expire at + * @param service The authentication service used to generate the token, if applicable * @returns The session token */ - getJwtToken = (expiresAt?: Date) => + getJwtToken = (expiresAt?: Date, service?: string) => JWT.sign( { id: this.id, expiresAt: expiresAt ? expiresAt.toISOString() : undefined, type: "session", + service, }, this.jwtSecret ); @@ -612,15 +614,17 @@ class User extends ParanoidModel< * between subdomains or domains. It has a short expiry and can only be used * once. * + * @param The authentication service used to generate the token, if applicable * @returns The transfer token */ - getTransferToken = () => + getTransferToken = (service?: string) => JWT.sign( { id: this.id, createdAt: new Date().toISOString(), expiresAt: addMinutes(new Date(), 1).toISOString(), type: "transfer", + service, }, this.jwtSecret ); @@ -629,6 +633,7 @@ class User extends ParanoidModel< * Returns a temporary token that is only used for logging in from an email * It can only be used to sign in once and has a medium length expiry * + * @param ctx The request context, used to get the IP address of the request * @returns The email signin token */ getEmailSigninToken = (ctx: Context) => diff --git a/server/routes/api/auth/auth.ts b/server/routes/api/auth/auth.ts index fd9d6d21e1..b13d47a8c6 100644 --- a/server/routes/api/auth/auth.ts +++ b/server/routes/api/auth/auth.ts @@ -114,8 +114,11 @@ router.post("auth.config", async (ctx: APIContext) => { }; }); +/** Authentication services that don't require SSO validation. */ +const NON_SSO_SERVICES = ["email", "passkeys"]; + router.post("auth.info", auth(), async (ctx: APIContext) => { - const { user } = ctx.state.auth; + const { user, service } = ctx.state.auth; const sessions = getSessionsInCookie(ctx); const signedInTeamIds = Object.keys(sessions); @@ -133,8 +136,15 @@ router.post("auth.info", auth(), async (ctx: APIContext) => { ]); // If the user did not _just_ sign in then we need to check if they continue - // to have access to the workspace they are signed into. - if (user.lastSignedInAt && user.lastSignedInAt < subHours(new Date(), 1)) { + // to have access to the workspace they are signed into. This only applies + // to SSO sessions - email and passkey logins don't have associated + // UserAuthentication records that need validation. + const isOAuthSession = !service || !NON_SSO_SERVICES.includes(service); + if ( + isOAuthSession && + user.lastSignedInAt && + user.lastSignedInAt < subHours(new Date(), 1) + ) { await new ValidateSSOAccessTask().schedule({ userId: user.id }); } diff --git a/server/routes/auth/index.ts b/server/routes/auth/index.ts index 01d850ddbc..ccec003e5f 100644 --- a/server/routes/auth/index.ts +++ b/server/routes/auth/index.ts @@ -31,8 +31,8 @@ void (async () => { })(); router.get("/redirect", authMiddleware(), async (ctx: APIContext) => { - const { user } = ctx.state.auth; - const jwtToken = user.getJwtToken(); + const { user, service } = ctx.state.auth; + const jwtToken = user.getJwtToken(undefined, service); if (jwtToken === ctx.state.auth.token) { throw AuthenticationError("Cannot extend token"); diff --git a/server/services/websockets.ts b/server/services/websockets.ts index bfffa812c6..0d6c6c6650 100644 --- a/server/services/websockets.ts +++ b/server/services/websockets.ts @@ -240,7 +240,7 @@ async function authenticate(socket: SocketWithAuth) { throw AuthenticationError("No access token"); } - const user = await getUserForJWT(accessToken); + const { user } = await getUserForJWT(accessToken); socket.client.user = user; return user; } diff --git a/server/types.ts b/server/types.ts index f95820c31a..0afa9a398d 100644 --- a/server/types.ts +++ b/server/types.ts @@ -52,9 +52,14 @@ export type AuthenticationResult = AccountProvisionerResult & { }; export type Authentication = { + /** The user associated with this session. */ user: User; + /** The token used for authenticating API requests, WebSocket connections, etc. */ token: string; + /** The type of authentication used to create this session (e.g., "api", "app", "oauth"). */ type?: AuthenticationType; + /** The authentication service used to create this session (e.g., "email", "passkeys", "google"). */ + service?: string; }; export type Pagination = { diff --git a/server/utils/authentication.ts b/server/utils/authentication.ts index c25765bcab..2e69f92262 100644 --- a/server/utils/authentication.ts +++ b/server/utils/authentication.ts @@ -126,15 +126,15 @@ export async function signIn( // stuck on the SSO screen. if (client === Client.Desktop) { ctx.redirect( - `${team.url}/desktop-redirect?token=${user.getTransferToken()}` + `${team.url}/desktop-redirect?token=${user.getTransferToken(service)}` ); } else { ctx.redirect( - `${team.url}/auth/redirect?token=${user.getTransferToken()}` + `${team.url}/auth/redirect?token=${user.getTransferToken(service)}` ); } } else { - ctx.cookies.set("accessToken", user.getJwtToken(expires), { + ctx.cookies.set("accessToken", user.getJwtToken(expires, service), { sameSite: "lax", expires, }); diff --git a/server/utils/jwt.ts b/server/utils/jwt.ts index 41819850a1..c43151eb56 100644 --- a/server/utils/jwt.ts +++ b/server/utils/jwt.ts @@ -24,10 +24,19 @@ export function getJWTPayload(token: string) { } } +/** + * Retrieves the user associated with a JWT token, validating the token's type and expiration. + * + * @param token The JWT token to validate and extract the user from. + * @param allowedTypes An array of allowed token types (default: ["session", "transfer"]). The token's type must be included in this array to be considered valid. + * @returns An object containing the user associated with the token and an optional service string if included in the token's payload. + * @throws AuthenticationError if the token is missing, invalid, expired, or if the token's type is not allowed. + * @throws UserSuspendedError if the user associated with the token is suspended. + */ export async function getUserForJWT( token: string, allowedTypes = ["session", "transfer"] -): Promise { +): Promise<{ user: User; service?: string }> { const payload = getJWTPayload(token); if (!allowedTypes.includes(payload.type)) { @@ -81,7 +90,10 @@ export async function getUserForJWT( throw AuthenticationError("Invalid token"); } - return user; + return { + user, + service: payload.service as string | undefined, + }; } export async function getUserForEmailSigninToken( diff --git a/server/utils/passport.ts b/server/utils/passport.ts index 0243be604d..807674741e 100644 --- a/server/utils/passport.ts +++ b/server/utils/passport.ts @@ -193,7 +193,8 @@ export async function getUserFromOAuthState(ctx: Context) { } try { - return await getUserForJWT(token); + const { user } = await getUserForJWT(token); + return user; } catch (_err) { return undefined; }