diff --git a/app/components/InputSelect.tsx b/app/components/InputSelect.tsx index c97b235f5d..be7bad7ee1 100644 --- a/app/components/InputSelect.tsx +++ b/app/components/InputSelect.tsx @@ -50,7 +50,7 @@ export type Item = { export type Option = Item | Separator; -type Props = { +type Props = Omit, "onChange"> & { /* Options to display in the select menu. */ options: Option[]; /* Current chosen value. */ diff --git a/app/components/OAuthClient/InputClientType.tsx b/app/components/OAuthClient/InputClientType.tsx new file mode 100644 index 0000000000..06b81bdcec --- /dev/null +++ b/app/components/OAuthClient/InputClientType.tsx @@ -0,0 +1,37 @@ +import * as React from "react"; +import { useTranslation } from "react-i18next"; +import { InputSelect } from "../InputSelect"; + +/** + * An input that allows a choice of OAuth client type. + */ +export const InputClientType = React.forwardRef( + ( + props: Omit, "options" | "label">, + ref: React.Ref + ) => { + const { t } = useTranslation(); + return ( + + ); + } +); diff --git a/app/components/OAuthClient/OAuthClientForm.tsx b/app/components/OAuthClient/OAuthClientForm.tsx index 38a461614b..e33b1a8834 100644 --- a/app/components/OAuthClient/OAuthClientForm.tsx +++ b/app/components/OAuthClient/OAuthClientForm.tsx @@ -11,6 +11,7 @@ import Input, { LabelText } from "~/components/Input"; import isCloudHosted from "~/utils/isCloudHosted"; import Switch from "../Switch"; import EventBoundary from "@shared/components/EventBoundary"; +import { InputClientType } from "./InputClientType"; export interface FormData { name: string; @@ -20,6 +21,7 @@ export interface FormData { avatarUrl: string; redirectUris: string[]; published: boolean; + clientType: "confidential" | "public"; } export const OAuthClientForm = observer(function OAuthClientForm_({ @@ -47,6 +49,7 @@ export const OAuthClientForm = observer(function OAuthClientForm_({ avatarUrl: oauthClient?.avatarUrl ?? "", redirectUris: oauthClient?.redirectUris ?? [], published: oauthClient?.published ?? false, + clientType: oauthClient?.clientType ?? "confidential", }, }); @@ -79,6 +82,17 @@ export const OAuthClientForm = observer(function OAuthClientForm_({ )} /> + ( + + )} + /> + + ( + + )} + /> + + - - - - - - - - - } - /> - - + readOnly + > + + + + + + + } + /> + + + )} ; export const OAuthClientsCreateSchema = BaseSchema.extend({ body: z.object({ + /** OAuth client type */ + clientType: z + .enum(OAuthClientValidation.clientTypes) + .default("confidential"), + /** OAuth client name */ name: z.string(), @@ -63,6 +68,9 @@ export const OAuthClientsUpdateSchema = BaseSchema.extend({ body: z.object({ id: z.string().uuid(), + /** OAuth client type */ + clientType: z.enum(OAuthClientValidation.clientTypes).optional(), + /** OAuth client name */ name: z.string().optional(), diff --git a/server/routes/oauth/index.test.ts b/server/routes/oauth/index.test.ts index a997226c14..d8d9427f64 100644 --- a/server/routes/oauth/index.test.ts +++ b/server/routes/oauth/index.test.ts @@ -1,7 +1,11 @@ import { Scope } from "@shared/types"; import { OAuthAuthentication } from "@server/models"; -import { buildOAuthAuthentication, buildUser } from "@server/test/factories"; -import { getTestServer } from "@server/test/support"; +import { + buildOAuthAuthentication, + buildOAuthClient, + buildUser, +} from "@server/test/factories"; +import { getTestServer, toFormData } from "@server/test/support"; const server = getTestServer(); @@ -45,3 +49,186 @@ describe("#oauth.revoke", () => { expect(res.status).toEqual(200); }); }); + +describe("#oauth.token", () => { + describe("refresh_token grant", () => { + it("should successfully refresh token for confidential client with client_secret", async () => { + const user = await buildUser(); + const client = await buildOAuthClient({ + teamId: user.teamId, + clientType: "confidential", + }); + const auth = await buildOAuthAuthentication({ + user, + scope: [Scope.Read], + oauthClientId: client.id, + }); + const refreshToken = auth.refreshToken; + + // Reload with oauthClient included + await auth.reload({ include: ["oauthClient"] }); + + const res = await server.post("/oauth/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: toFormData({ + grant_type: "refresh_token", + refresh_token: refreshToken, + client_id: auth.oauthClient.clientId, + client_secret: auth.oauthClient.clientSecret, + }), + }); + + expect(res.status).toEqual(200); + const body = await res.json(); + expect(body.access_token).toBeTruthy(); + expect(body.refresh_token).toBeTruthy(); + expect(body.token_type).toBe("Bearer"); + expect(body.expires_in).toBeGreaterThan(0); + }); + + it("should successfully refresh token for public client without client_secret", async () => { + const user = await buildUser(); + const client = await buildOAuthClient({ + teamId: user.teamId, + clientType: "public", + }); + const auth = await buildOAuthAuthentication({ + user, + scope: [Scope.Read], + oauthClientId: client.id, + }); + const refreshToken = auth.refreshToken; + + // Reload with oauthClient included + await auth.reload({ include: ["oauthClient"] }); + + const res = await server.post("/oauth/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: toFormData({ + grant_type: "refresh_token", + refresh_token: refreshToken, + client_id: auth.oauthClient.clientId, + }), + }); + + expect(res.status).toEqual(200); + const body = await res.json(); + expect(body.access_token).toBeTruthy(); + expect(body.refresh_token).toBeTruthy(); + expect(body.token_type).toBe("Bearer"); + expect(body.expires_in).toBeGreaterThan(0); + }); + + it("should successfully refresh token for public client with client_secret", async () => { + const user = await buildUser(); + const client = await buildOAuthClient({ + teamId: user.teamId, + clientType: "public", + }); + const auth = await buildOAuthAuthentication({ + user, + scope: [Scope.Read], + oauthClientId: client.id, + }); + const refreshToken = auth.refreshToken; + + // Reload with oauthClient included + await auth.reload({ include: ["oauthClient"] }); + + const res = await server.post("/oauth/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: toFormData({ + grant_type: "refresh_token", + refresh_token: refreshToken, + client_id: auth.oauthClient.clientId, + client_secret: auth.oauthClient.clientSecret, + }), + }); + + expect(res.status).toEqual(200); + const body = await res.json(); + expect(body.access_token).toBeTruthy(); + expect(body.refresh_token).toBeTruthy(); + expect(body.token_type).toBe("Bearer"); + expect(body.expires_in).toBeGreaterThan(0); + }); + + it("should error when refresh_token is missing for refresh_token grant", async () => { + const res = await server.post("/oauth/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: toFormData({ + grant_type: "refresh_token", + client_id: "test-client-id", + }), + }); + + expect(res.status).toEqual(400); + const body = await res.json(); + expect(body.error).toBeDefined(); + expect(body.error_description).toContain( + "Missing refresh_token for refresh_token grant type" + ); + }); + + it("should error when client_id is invalid", async () => { + const res = await server.post("/oauth/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: toFormData({ + grant_type: "refresh_token", + refresh_token: "invalid-refresh-token", + client_id: "test-client-id", + }), + }); + + expect(res.status).toEqual(400); + const body = await res.json(); + expect(body.error).toBeDefined(); + expect(body.error_description).toContain("Invalid client_id"); + }); + + it("should error when confidential client tries to refresh without client_secret", async () => { + const user = await buildUser(); + const client = await buildOAuthClient({ + teamId: user.teamId, + clientType: "confidential", + }); + const auth = await buildOAuthAuthentication({ + user, + scope: [Scope.Read], + oauthClientId: client.id, + }); + const refreshToken = auth.refreshToken; + + // Reload with oauthClient included + await auth.reload({ include: ["oauthClient"] }); + + const res = await server.post("/oauth/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: toFormData({ + grant_type: "refresh_token", + refresh_token: refreshToken, + client_id: auth.oauthClient.clientId, + }), + }); + + expect(res.status).toEqual(400); + const body = await res.json(); + expect(body.error).toBeDefined(); + expect(body.error_description).toContain( + "Missing client_secret for confidential client" + ); + }); + }); +}); diff --git a/server/routes/oauth/index.ts b/server/routes/oauth/index.ts index 14c1eee6e2..3e60385cce 100644 --- a/server/routes/oauth/index.ts +++ b/server/routes/oauth/index.ts @@ -22,6 +22,13 @@ const app = new Koa(); const router = new Router(); const oauth = new OAuth2Server({ model: OAuthInterface, + requireClientAuthentication: { + // Allow public clients (those without a client secret) to refresh without a client secret. + refresh_token: false, + }, + // Always revoke the used refresh token and issue a new one, see: + // https://www.rfc-editor.org/rfc/rfc6819#section-5.2.2.3 + alwaysIssueNewRefreshToken: true, }); router.post( @@ -70,8 +77,34 @@ router.post( router.post( "/token", + validate(T.TokenSchema), rateLimiter(RateLimiterStrategy.OneHundredPerHour), - async (ctx) => { + async (ctx: APIContext) => { + const grantType = ctx.input.body.grant_type; + const refreshToken = ctx.input.body.refresh_token; + const clientId = ctx.input.body.client_id; + const clientSecret = ctx.input.body.client_secret; + + // Because we disabled client authentication for refresh_token grant type at the library + // initialization, we need to manually enforce it here for confidential clients. + if (grantType === "refresh_token" && !clientSecret) { + if (!refreshToken) { + throw ValidationError( + "Missing refresh_token for refresh_token grant type" + ); + } + if (!clientId) { + throw ValidationError("Missing client_id for refresh_token grant type"); + } + const client = await OAuthClient.findByClientId(clientId); + if (!client) { + throw ValidationError("Invalid client_id"); + } + if (client.clientType === "confidential") { + throw ValidationError("Missing client_secret for confidential client"); + } + } + // Note: These objects are mutated by the OAuth2Server library const request = new OAuth2Server.Request(ctx.request); const response = new OAuth2Server.Response(ctx.response); diff --git a/server/routes/oauth/middlewares/oauthErrorHandler.ts b/server/routes/oauth/middlewares/oauthErrorHandler.ts index 820afe08a2..61092a89b5 100644 --- a/server/routes/oauth/middlewares/oauthErrorHandler.ts +++ b/server/routes/oauth/middlewares/oauthErrorHandler.ts @@ -31,9 +31,17 @@ export default function oauthErrorHandler() { return; } - ctx.status = err.code || 500; + ctx.status = err.status || err.statusCode || err.code || 500; + // Map common HTTP status codes to OAuth error types + let errorType = "server_error"; + if (ctx.status === 400) { + errorType = "invalid_request"; + } else if (ctx.status === 401) { + errorType = "invalid_client"; + } + ctx.body = { - error: err.name, + error: errorType, error_description: err.message, }; } diff --git a/server/routes/oauth/schema.ts b/server/routes/oauth/schema.ts index 3e9d10af5a..be87760720 100644 --- a/server/routes/oauth/schema.ts +++ b/server/routes/oauth/schema.ts @@ -1,6 +1,20 @@ import z from "zod"; import { BaseSchema } from "../api/schema"; +export const TokenSchema = BaseSchema.extend({ + body: z.object({ + grant_type: z.string(), + code: z.string().optional(), + redirect_uri: z.string().optional(), + client_id: z.string().optional(), + client_secret: z.string().optional(), + refresh_token: z.string().optional(), + scope: z.string().optional(), + }), +}); + +export type TokenReq = z.infer; + export const TokenRevokeSchema = BaseSchema.extend({ body: z.object({ token: z.string(), diff --git a/server/test/support.ts b/server/test/support.ts index e61c511b91..806e045d92 100644 --- a/server/test/support.ts +++ b/server/test/support.ts @@ -54,3 +54,20 @@ export function withAPIContext( } as APIContext); }); } + +/** + * Helper function to convert an object to form-urlencoded string. + * Useful for testing OAuth endpoints that expect application/x-www-form-urlencoded content type. + * + * @param obj Object to convert to form-urlencoded string + * @returns Form-urlencoded string representation of the object + */ +export function toFormData(obj: Record): string { + return Object.entries(obj) + .filter(([_, value]) => value !== undefined) + .map( + ([key, value]) => + `${encodeURIComponent(key)}=${encodeURIComponent(value)}` + ) + .join("&"); +} diff --git a/server/utils/oauth/OAuthInterface.ts b/server/utils/oauth/OAuthInterface.ts index c060ab1597..7f9e3830af 100644 --- a/server/utils/oauth/OAuthInterface.ts +++ b/server/utils/oauth/OAuthInterface.ts @@ -135,6 +135,7 @@ export const OAuthInterface: RefreshTokenModel & return { id: client.clientId, redirectUris: client.redirectUris, + clientType: client.clientType, databaseId: client.id, grants: this.grants, }; diff --git a/shared/i18n/locales/en_US/translation.json b/shared/i18n/locales/en_US/translation.json index 84ece13ee0..7c9a802ac2 100644 --- a/shared/i18n/locales/en_US/translation.json +++ b/shared/i18n/locales/en_US/translation.json @@ -353,6 +353,11 @@ "Unknown": "Unknown", "Mark all as read": "Mark all as read", "You're all caught up": "You're all caught up", + "Client type": "Client type", + "Confidential": "Confidential", + "Suitable for server-side applications": "Suitable for server-side applications", + "Public": "Public", + "Suitable for client-side or mobile applications": "Suitable for client-side or mobile applications", "Icon": "Icon", "OAuth client icon": "OAuth client icon", "My App": "My App", @@ -975,6 +980,7 @@ "Rotating the client secret will invalidate the current secret. Make sure to update any applications using these credentials.": "Rotating the client secret will invalidate the current secret. Make sure to update any applications using these credentials.", "Displayed to users when authorizing": "Displayed to users when authorizing", "Application icon": "Application icon", + "Confidential clients can securely store a secret": "Confidential clients can securely store a secret", "Developer information shown to users when authorizing": "Developer information shown to users when authorizing", "Developer name": "Developer name", "Developer URL": "Developer URL", diff --git a/shared/validations.ts b/shared/validations.ts index 1b2ba60fe0..d4dd983678 100644 --- a/shared/validations.ts +++ b/shared/validations.ts @@ -94,6 +94,9 @@ export const OAuthClientValidation = { /** The maximum length of an OAuth client redirect URI */ maxRedirectUriLength: 1000, + + /** The allowed OAuth client types */ + clientTypes: ["confidential", "public"] as const, }; export const RevisionValidation = {