Refactor of OAuth account linking flows (#12246)

* Refactor of OAuth account linking flows

* PR feedback
This commit is contained in:
Tom Moor
2026-05-02 18:54:38 -04:00
committed by GitHub
parent 8c716b173a
commit f50bb00b29
30 changed files with 590 additions and 84 deletions
@@ -2,21 +2,21 @@ import * as React from "react";
import { useTranslation } from "react-i18next";
import Button, { type Props } from "~/components/Button";
import useCurrentTeam from "~/hooks/useCurrentTeam";
import { generateOAuthStateNonce } from "~/utils/oauth";
import { redirectTo } from "~/utils/urls";
import { FigmaUtils } from "../../shared/FigmaUtils";
import { FigmaOAuthNonceCookie, FigmaUtils } from "../../shared/FigmaUtils";
export function FigmaConnectButton(props: Props<HTMLButtonElement>) {
const { t } = useTranslation();
const team = useCurrentTeam();
const handleConnect = React.useCallback(() => {
const nonce = generateOAuthStateNonce(FigmaOAuthNonceCookie);
redirectTo(FigmaUtils.authUrl({ state: { teamId: team.id, nonce } }));
}, [team.id]);
return (
<Button
onClick={() =>
redirectTo(FigmaUtils.authUrl({ state: { teamId: team.id } }))
}
neutral
{...props}
>
<Button onClick={handleConnect} neutral {...props}>
{t("Connect")}
</Button>
);
+44
View File
@@ -0,0 +1,44 @@
import { buildUser } from "@server/test/factories";
import { getTestServer } from "@server/test/support";
const server = getTestServer();
describe("#figma.callback", () => {
it("should reject callback when state nonce does not match cookie", async () => {
const user = await buildUser();
const state = JSON.stringify({
teamId: user.teamId,
nonce: "attacker-nonce",
});
const res = await server.get(
`/api/figma.callback?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.error).toEqual("state_mismatch");
});
it("should reject callback when nonce is missing from state", async () => {
const user = await buildUser();
const state = JSON.stringify({ teamId: user.teamId });
const res = await server.get(
`/api/figma.callback?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
it("should fail when state is not valid JSON", async () => {
const user = await buildUser();
const res = await server.get(
`/api/figma.callback?state=bad&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
});
+14 -2
View File
@@ -4,11 +4,16 @@ import * as T from "./schema";
import apexAuthRedirect from "@server/middlewares/apexAuthRedirect";
import type { APIContext } from "@server/types";
import validate from "@server/middlewares/validate";
import { FigmaUtils } from "plugins/figma/shared/FigmaUtils";
import {
FigmaOAuthNonceCookie,
FigmaUtils,
} from "plugins/figma/shared/FigmaUtils";
import { transaction } from "@server/middlewares/transaction";
import Logger from "@server/logging/Logger";
import { IntegrationService, IntegrationType } from "@shared/types";
import { ValidationError } from "@server/errors";
import { Integration, IntegrationAuthentication } from "@server/models";
import { verifyOAuthStateNonce } from "@server/utils/oauth";
import { addSeconds } from "date-fns";
import { Figma } from "../figma";
import UploadIntegrationLogoTask from "@server/queues/tasks/UploadIntegrationLogoTask";
@@ -30,7 +35,7 @@ router.get(
}),
transaction(),
async (ctx: APIContext<T.FigmaCallbackReq>) => {
const { code, error } = ctx.input.query;
const { code, error, state } = ctx.input.query;
// Check error after any sub-domain redirection. Otherwise, the user will be redirected to the root domain.
if (error) {
@@ -38,6 +43,13 @@ router.get(
return;
}
const parsedState = FigmaUtils.parseState(state);
if (!parsedState) {
throw ValidationError("Invalid state");
}
verifyOAuthStateNonce(ctx, FigmaOAuthNonceCookie, parsedState.nonce);
const { user } = ctx.state.auth;
const { transaction } = ctx.state;
+9 -2
View File
@@ -2,8 +2,11 @@ import queryString from "query-string";
import env from "@shared/env";
import { integrationSettingsPath } from "@shared/utils/routeHelpers";
export const FigmaOAuthNonceCookie = "figmaOAuthNonce";
export type OAuthState = {
teamId: string;
nonce: string;
};
export class FigmaUtils {
@@ -16,8 +19,12 @@ export class FigmaUtils {
private static settingsUrl = integrationSettingsPath("figma");
static parseState(state: string): OAuthState {
return JSON.parse(state);
static parseState(state: string): OAuthState | undefined {
try {
return JSON.parse(state);
} catch {
return undefined;
}
}
static successUrl() {
@@ -2,19 +2,21 @@ import * as React from "react";
import { useTranslation } from "react-i18next";
import Button, { type Props } from "~/components/Button";
import useCurrentTeam from "~/hooks/useCurrentTeam";
import { generateOAuthStateNonce } from "~/utils/oauth";
import { redirectTo } from "~/utils/urls";
import { GitHubUtils } from "../../shared/GitHubUtils";
import { GitHubOAuthNonceCookie, GitHubUtils } from "../../shared/GitHubUtils";
export function GitHubConnectButton(props: Props<HTMLButtonElement>) {
const { t } = useTranslation();
const team = useCurrentTeam();
const handleConnect = React.useCallback(() => {
const nonce = generateOAuthStateNonce(GitHubOAuthNonceCookie);
redirectTo(GitHubUtils.authUrl({ teamId: team.id, nonce }));
}, [team.id]);
return (
<Button
onClick={() => redirectTo(GitHubUtils.authUrl(team.id))}
neutral
{...props}
>
<Button onClick={handleConnect} neutral {...props}>
{t("Connect")}
</Button>
);
+45
View File
@@ -0,0 +1,45 @@
import { buildUser } from "@server/test/factories";
import { getTestServer } from "@server/test/support";
import { SetupAction } from "./schema";
const server = getTestServer();
describe("#github.callback", () => {
it("should reject callback when state nonce does not match cookie", async () => {
const user = await buildUser();
const state = JSON.stringify({
teamId: user.teamId,
nonce: "attacker-nonce",
});
const res = await server.get(
`/api/github.callback?state=${encodeURIComponent(
state
)}&code=123&setup_action=${SetupAction.install}&installation_id=1&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.error).toEqual("state_mismatch");
});
it("should reject callback when nonce is missing from state", async () => {
const user = await buildUser();
const state = JSON.stringify({ teamId: user.teamId });
const res = await server.get(
`/api/github.callback?state=${encodeURIComponent(
state
)}&code=123&setup_action=${SetupAction.install}&installation_id=1&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
it("should fail when state is not valid JSON", async () => {
const user = await buildUser();
const res = await server.get(
`/api/github.callback?state=bad&code=123&setup_action=${SetupAction.install}&installation_id=1&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
});
+13 -4
View File
@@ -2,6 +2,7 @@ import Router from "koa-router";
import find from "lodash/find";
import { IntegrationService, IntegrationType } from "@shared/types";
import { createContext } from "@server/context";
import { ValidationError } from "@server/errors";
import apexAuthRedirect from "@server/middlewares/apexAuthRedirect";
import auth from "@server/middlewares/authentication";
import { transaction } from "@server/middlewares/transaction";
@@ -9,7 +10,8 @@ import validate from "@server/middlewares/validate";
import validateWebhook from "@server/middlewares/validateWebhook";
import { IntegrationAuthentication, Integration } from "@server/models";
import type { APIContext } from "@server/types";
import { GitHubUtils } from "../../shared/GitHubUtils";
import { verifyOAuthStateNonce } from "@server/utils/oauth";
import { GitHubOAuthNonceCookie, GitHubUtils } from "../../shared/GitHubUtils";
import env from "../env";
import { GitHub } from "../github";
import GitHubWebhookTask from "../tasks/GitHubWebhookTask";
@@ -22,7 +24,7 @@ router.get(
auth({ optional: true }),
validate(T.GitHubCallbackSchema),
apexAuthRedirect<T.GitHubCallbackReq>({
getTeamId: (ctx) => ctx.input.query.state,
getTeamId: (ctx) => GitHubUtils.parseState(ctx.input.query.state)?.teamId,
getRedirectPath: (ctx, team) =>
GitHubUtils.callbackUrl({
baseUrl: team.url,
@@ -34,7 +36,7 @@ router.get(
async (ctx: APIContext<T.GitHubCallbackReq>) => {
const {
code,
state: teamId,
state,
error,
installation_id: installationId,
setup_action: setupAction,
@@ -52,7 +54,14 @@ router.get(
return;
}
const client = await GitHub.authenticateAsUser(code!, teamId);
const parsedState = GitHubUtils.parseState(state);
if (!parsedState) {
throw ValidationError("Invalid state");
}
verifyOAuthStateNonce(ctx, GitHubOAuthNonceCookie, parsedState.nonce);
const client = await GitHub.authenticateAsUser(code!, state);
const installationsByUser = await client.requestAppInstallations();
const installation = find(
installationsByUser,
+1 -1
View File
@@ -12,7 +12,7 @@ export const GitHubCallbackSchema = BaseSchema.extend({
query: z
.object({
code: z.string().nullish(),
state: z.uuid().nullish(),
state: z.string(),
error: z.string().nullish(),
installation_id: z.coerce.number().optional(),
setup_action: z.enum(SetupAction),
+17 -2
View File
@@ -2,6 +2,13 @@ import queryString from "query-string";
import env from "@shared/env";
import { integrationSettingsPath } from "@shared/utils/routeHelpers";
export const GitHubOAuthNonceCookie = "githubOAuthNonce";
export type OAuthState = {
teamId: string;
nonce: string;
};
export class GitHubUtils {
public static clientId = env.GITHUB_CLIENT_ID;
@@ -31,16 +38,24 @@ export class GitHubUtils {
: `${baseUrl}/api/github.callback`;
}
static authUrl(state: string): string {
static authUrl(state: OAuthState): string {
const baseUrl = `https://github.com/apps/${env.GITHUB_APP_NAME}/installations/new`;
const params = {
client_id: this.clientId,
redirect_uri: this.callbackUrl(),
state,
state: JSON.stringify(state),
};
return `${baseUrl}?${queryString.stringify(params)}`;
}
static parseState(state: string): OAuthState | undefined {
try {
return JSON.parse(state);
} catch {
return undefined;
}
}
static installRequestUrl(): string {
return `${this.url}?install_request=true`;
}
+44
View File
@@ -0,0 +1,44 @@
import { buildUser } from "@server/test/factories";
import { getTestServer } from "@server/test/support";
const server = getTestServer();
describe("#gitlab.callback", () => {
it("should reject callback when state nonce does not match cookie", async () => {
const user = await buildUser();
const state = JSON.stringify({
teamId: user.teamId,
nonce: "attacker-nonce",
});
const res = await server.get(
`/api/gitlab.callback?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.error).toEqual("state_mismatch");
});
it("should reject callback when nonce is missing from state", async () => {
const user = await buildUser();
const state = JSON.stringify({ teamId: user.teamId });
const res = await server.get(
`/api/gitlab.callback?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
it("should fail when state is not valid JSON", async () => {
const user = await buildUser();
const res = await server.get(
`/api/gitlab.callback?state=bad&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
});
+22 -5
View File
@@ -2,6 +2,7 @@ import Router from "koa-router";
import { Op } from "sequelize";
import { IntegrationService, IntegrationType } from "@shared/types";
import { createContext } from "@server/context";
import { ValidationError } from "@server/errors";
import apexAuthRedirect from "@server/middlewares/apexAuthRedirect";
import auth from "@server/middlewares/authentication";
import { transaction } from "@server/middlewares/transaction";
@@ -10,14 +11,18 @@ import validateWebhook from "@server/middlewares/validateWebhook";
import { IntegrationAuthentication, Integration } from "@server/models";
import { authorize } from "@server/policies";
import type { APIContext } from "@server/types";
import {
generateOAuthStateNonce,
verifyOAuthStateNonce,
} from "@server/utils/oauth";
import { validateUrlNotPrivate } from "@server/utils/url";
import { addSeconds } from "date-fns";
import Logger from "@server/logging/Logger";
import { GitLabUtils } from "../../shared/GitLabUtils";
import { GitLabOAuthNonceCookie, GitLabUtils } from "../../shared/GitLabUtils";
import { GitLab } from "../gitlab";
import env from "../env";
import GitLabWebhookTask from "../tasks/GitLabWebhookTask";
import * as T from "../schema";
import * as T from "./schema";
const router = new Router();
@@ -111,7 +116,12 @@ router.post(
}
}
const redirectUrl = GitLabUtils.authUrl(user.teamId, url, clientId);
const nonce = generateOAuthStateNonce(ctx, GitLabOAuthNonceCookie);
const redirectUrl = GitLabUtils.authUrl(
{ teamId: user.teamId, nonce },
url,
clientId
);
ctx.body = {
data: { redirectUrl },
};
@@ -123,7 +133,7 @@ router.get(
auth({ optional: true }),
validate(T.GitLabCallbackSchema),
apexAuthRedirect<T.GitLabCallbackReq>({
getTeamId: (ctx) => ctx.input.query.state,
getTeamId: (ctx) => GitLabUtils.parseState(ctx.input.query.state)?.teamId,
getRedirectPath: (ctx, team) =>
GitLabUtils.callbackUrl({
baseUrl: team.url,
@@ -133,7 +143,7 @@ router.get(
}),
transaction(),
async (ctx: APIContext<T.GitLabCallbackReq>) => {
const { code, error } = ctx.input.query;
const { code, error, state } = ctx.input.query;
const { user } = ctx.state.auth;
const { transaction } = ctx.state;
@@ -142,6 +152,13 @@ router.get(
return;
}
const parsedState = GitLabUtils.parseState(state);
if (!parsedState) {
throw ValidationError("Invalid state");
}
verifyOAuthStateNonce(ctx, GitLabOAuthNonceCookie, parsedState.nonce);
try {
// Check for a pending IntegrationAuthentication with custom credentials
const pendingAuth = await IntegrationAuthentication.findOne({
@@ -6,7 +6,7 @@ export const GitLabCallbackSchema = BaseSchema.extend({
query: z
.object({
code: z.string().nullish(),
state: z.string().uuid().nullish(),
state: z.string(),
error: z.string().nullish(),
})
.refine((req) => !(isEmpty(req.code) && isEmpty(req.error)), {
+24 -3
View File
@@ -2,6 +2,13 @@ import env from "@shared/env";
import { integrationSettingsPath } from "@shared/utils/routeHelpers";
import { UnfurlResourceType } from "@shared/types";
export const GitLabOAuthNonceCookie = "gitlabOAuthNonce";
export type OAuthState = {
teamId: string;
nonce: string;
};
export class GitLabUtils {
public static defaultGitlabUrl = "https://gitlab.com";
@@ -67,13 +74,13 @@ export class GitLabUtils {
/**
* Generates the authorization URL for GitLab OAuth.
*
* @param state - A unique state string to prevent CSRF attacks.
* @param state - The OAuth state with teamId for routing and nonce for CSRF.
* @param customUrl - Optional custom GitLab URL from integration settings.
* @param customClientId - Optional custom OAuth client ID from integration settings.
* @returns The full URL to redirect the user to GitLab's OAuth authorization page.
*/
public static authUrl(
state: string,
state: OAuthState,
customUrl?: string,
customClientId?: string
): string {
@@ -81,13 +88,27 @@ export class GitLabUtils {
client_id: customClientId || env.GITLAB_CLIENT_ID,
redirect_uri: this.callbackUrl(),
response_type: "code",
state,
state: JSON.stringify(state),
scope: "read_api read_user",
});
return `${this.getOauthUrl(customUrl)}/authorize?${params.toString()}`;
}
/**
* Parses an OAuth state string from a GitLab callback.
*
* @param state - The state string carried in the callback query.
* @returns The parsed OAuth state.
*/
public static parseState(state: string): OAuthState | undefined {
try {
return JSON.parse(state);
} catch {
return undefined;
}
}
/**
* Generates the installation request URL.
*
@@ -2,21 +2,21 @@ import * as React from "react";
import { useTranslation } from "react-i18next";
import Button, { type Props } from "~/components/Button";
import useCurrentTeam from "~/hooks/useCurrentTeam";
import { generateOAuthStateNonce } from "~/utils/oauth";
import { redirectTo } from "~/utils/urls";
import { LinearUtils } from "../../shared/LinearUtils";
import { LinearOAuthNonceCookie, LinearUtils } from "../../shared/LinearUtils";
export function LinearConnectButton(props: Props<HTMLButtonElement>) {
const { t } = useTranslation();
const team = useCurrentTeam();
const handleConnect = React.useCallback(() => {
const nonce = generateOAuthStateNonce(LinearOAuthNonceCookie);
redirectTo(LinearUtils.authUrl({ state: { teamId: team.id, nonce } }));
}, [team.id]);
return (
<Button
onClick={() =>
redirectTo(LinearUtils.authUrl({ state: { teamId: team.id } }))
}
neutral
{...props}
>
<Button onClick={handleConnect} neutral {...props}>
{t("Connect")}
</Button>
);
+44
View File
@@ -0,0 +1,44 @@
import { buildUser } from "@server/test/factories";
import { getTestServer } from "@server/test/support";
const server = getTestServer();
describe("#linear.callback", () => {
it("should reject callback when state nonce does not match cookie", async () => {
const user = await buildUser();
const state = JSON.stringify({
teamId: user.teamId,
nonce: "attacker-nonce",
});
const res = await server.get(
`/api/linear.callback?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.error).toEqual("state_mismatch");
});
it("should reject callback when nonce is missing from state", async () => {
const user = await buildUser();
const state = JSON.stringify({ teamId: user.teamId });
const res = await server.get(
`/api/linear.callback?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
it("should fail when state is not valid JSON", async () => {
const user = await buildUser();
const res = await server.get(
`/api/linear.callback?state=bad&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
});
+14 -2
View File
@@ -1,5 +1,6 @@
import Router from "koa-router";
import { IntegrationService, IntegrationType } from "@shared/types";
import { ValidationError } from "@server/errors";
import Logger from "@server/logging/Logger";
import apexAuthRedirect from "@server/middlewares/apexAuthRedirect";
import auth from "@server/middlewares/authentication";
@@ -7,10 +8,14 @@ import { transaction } from "@server/middlewares/transaction";
import validate from "@server/middlewares/validate";
import { IntegrationAuthentication, Integration } from "@server/models";
import type { APIContext } from "@server/types";
import { verifyOAuthStateNonce } from "@server/utils/oauth";
import { Linear } from "../linear";
import UploadIntegrationLogoTask from "@server/queues/tasks/UploadIntegrationLogoTask";
import * as T from "./schema";
import { LinearUtils } from "plugins/linear/shared/LinearUtils";
import {
LinearOAuthNonceCookie,
LinearUtils,
} from "plugins/linear/shared/LinearUtils";
import { addSeconds } from "date-fns";
const router = new Router();
@@ -32,7 +37,7 @@ router.get(
}),
transaction(),
async (ctx: APIContext<T.LinearCallbackReq>) => {
const { code, error } = ctx.input.query;
const { code, error, state } = ctx.input.query;
const { user } = ctx.state.auth;
const { transaction } = ctx.state;
@@ -42,6 +47,13 @@ router.get(
return;
}
const parsedState = LinearUtils.parseState(state);
if (!parsedState) {
throw ValidationError("Invalid state");
}
verifyOAuthStateNonce(ctx, LinearOAuthNonceCookie, parsedState.nonce);
try {
// validation middleware ensures that code is non-null at this point.
const oauth = await Linear.oauthAccess(code!);
+9 -2
View File
@@ -2,8 +2,11 @@ import queryString from "query-string";
import env from "@shared/env";
import { integrationSettingsPath } from "@shared/utils/routeHelpers";
export const LinearOAuthNonceCookie = "linearOAuthNonce";
export type OAuthState = {
teamId: string;
nonce: string;
};
export class LinearUtils {
@@ -15,8 +18,12 @@ export class LinearUtils {
private static settingsUrl = integrationSettingsPath("linear");
static parseState(state: string): OAuthState {
return JSON.parse(state);
static parseState(state: string): OAuthState | undefined {
try {
return JSON.parse(state);
} catch {
return undefined;
}
}
static successUrl() {
+8 -3
View File
@@ -9,8 +9,9 @@ import Button from "~/components/Button";
import useCurrentTeam from "~/hooks/useCurrentTeam";
import useQuery from "~/hooks/useQuery";
import useStores from "~/hooks/useStores";
import { generateOAuthStateNonce } from "~/utils/oauth";
import { redirectTo } from "~/utils/urls";
import { NotionUtils } from "../shared/NotionUtils";
import { NotionOAuthNonceCookie, NotionUtils } from "../shared/NotionUtils";
import { ImportDialog } from "./components/ImportDialog";
export const Notion = observer(() => {
@@ -22,7 +23,6 @@ export const Notion = observer(() => {
const queryParams = useQuery();
const appName = env.APP_NAME;
const authUrl = NotionUtils.authUrl({ state: { teamId: team.id } });
const service = queryParams.get("service");
const oauthSuccess = queryParams.get("success") === "";
@@ -88,10 +88,15 @@ export const Notion = observer(() => {
}
}, [t, appName, oauthError]);
const handleConnect = React.useCallback(() => {
const nonce = generateOAuthStateNonce(NotionOAuthNonceCookie);
redirectTo(NotionUtils.authUrl({ state: { teamId: team.id, nonce } }));
}, [team.id]);
return (
<Button
type="submit"
onClick={() => redirectTo(authUrl)}
onClick={handleConnect}
disabled={!env.NOTION_CLIENT_ID}
neutral
>
+44
View File
@@ -0,0 +1,44 @@
import { buildUser } from "@server/test/factories";
import { getTestServer } from "@server/test/support";
const server = getTestServer();
describe("#notion.callback", () => {
it("should reject callback when state nonce does not match cookie", async () => {
const user = await buildUser();
const state = JSON.stringify({
teamId: user.teamId,
nonce: "attacker-nonce",
});
const res = await server.get(
`/api/notion.callback?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.error).toEqual("state_mismatch");
});
it("should reject callback when nonce is missing from state", async () => {
const user = await buildUser();
const state = JSON.stringify({ teamId: user.teamId });
const res = await server.get(
`/api/notion.callback?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
it("should fail when state is not valid JSON", async () => {
const user = await buildUser();
const res = await server.get(
`/api/notion.callback?state=bad&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
});
+14 -2
View File
@@ -1,14 +1,19 @@
import Router from "koa-router";
import { IntegrationService, IntegrationType } from "@shared/types";
import { ValidationError } from "@server/errors";
import apexAuthRedirect from "@server/middlewares/apexAuthRedirect";
import auth from "@server/middlewares/authentication";
import { transaction } from "@server/middlewares/transaction";
import validate from "@server/middlewares/validate";
import { Integration, IntegrationAuthentication } from "@server/models";
import type { APIContext } from "@server/types";
import { verifyOAuthStateNonce } from "@server/utils/oauth";
import { NotionClient } from "../notion";
import * as T from "./schema";
import { NotionUtils } from "plugins/notion/shared/NotionUtils";
import {
NotionOAuthNonceCookie,
NotionUtils,
} from "plugins/notion/shared/NotionUtils";
const router = new Router();
@@ -27,7 +32,7 @@ router.get(
}),
transaction(),
async (ctx: APIContext<T.NotionCallbackReq>) => {
const { code, error } = ctx.input.query;
const { code, error, state } = ctx.input.query;
const { user } = ctx.state.auth;
const { transaction } = ctx.state;
@@ -37,6 +42,13 @@ router.get(
return;
}
const parsedState = NotionUtils.parseState(state);
if (!parsedState) {
throw ValidationError("Invalid state");
}
verifyOAuthStateNonce(ctx, NotionOAuthNonceCookie, parsedState.nonce);
// validation middleware ensures that code is non-null at this point.
const data = await NotionClient.oauthAccess(code!);
+9 -2
View File
@@ -3,8 +3,11 @@ import env from "@shared/env";
import { IntegrationService } from "@shared/types";
import { settingsPath } from "@shared/utils/routeHelpers";
export const NotionOAuthNonceCookie = "notionOAuthNonce";
export type OAuthState = {
teamId: string;
nonce: string;
};
export class NotionUtils {
@@ -13,8 +16,12 @@ export class NotionUtils {
private static settingsUrl = settingsPath("import");
static parseState(state: string): OAuthState {
return JSON.parse(state);
static parseState(state: string): OAuthState | undefined {
try {
return JSON.parse(state);
} catch {
return undefined;
}
}
static successUrl(integrationId: string) {
+9 -13
View File
@@ -101,11 +101,9 @@ function Slack() {
/>
) : (
<SlackButton
type={IntegrationType.LinkedAccount}
state={{ teamId: team.id }}
redirectUri={SlackUtils.connectUrl()}
state={SlackUtils.createState(
team.id,
IntegrationType.LinkedAccount
)}
label={t("Connect")}
/>
)}
@@ -141,12 +139,10 @@ function Slack() {
/>
) : (
<SlackButton
type={IntegrationType.Command}
scopes={["commands", "links:read", "links:write"]}
state={{ teamId: team.id }}
redirectUri={SlackUtils.connectUrl()}
state={SlackUtils.createState(
team.id,
IntegrationType.Command
)}
icon={<SlackIcon />}
/>
)}
@@ -183,13 +179,13 @@ function Slack() {
image={<CollectionIcon collection={collection} />}
actions={
<SlackButton
type={IntegrationType.Post}
scopes={["incoming-webhook"]}
state={{
teamId: team.id,
collectionId: collection.id,
}}
redirectUri={SlackUtils.connectUrl()}
state={SlackUtils.createState(
team.id,
IntegrationType.Post,
{ collectionId: collection.id }
)}
label={t("Connect")}
/>
}
@@ -1,21 +1,35 @@
import * as React from "react";
import { useTranslation } from "react-i18next";
import type { IntegrationType } from "@shared/types";
import Button from "~/components/Button";
import { SlackUtils } from "../../shared/SlackUtils";
import { generateOAuthStateNonce } from "~/utils/oauth";
import { redirectTo } from "~/utils/urls";
import { SlackOAuthNonceCookie, SlackUtils } from "../../shared/SlackUtils";
type Props = {
type: IntegrationType;
scopes?: string[];
redirectUri: string;
state: { teamId: string; collectionId?: string };
redirectUri?: string;
icon?: React.ReactNode;
state?: string;
label?: string;
};
function SlackButton({ state = "", scopes, redirectUri, label, icon }: Props) {
function SlackButton({
type,
scopes,
state: stateData,
redirectUri,
label,
icon,
}: Props) {
const { t } = useTranslation();
const handleClick = () => {
window.location.href = SlackUtils.authUrl(state, scopes, redirectUri);
const nonce = generateOAuthStateNonce(SlackOAuthNonceCookie);
const { teamId, ...rest } = stateData;
const state = SlackUtils.createState(teamId, type, { nonce, ...rest });
redirectTo(SlackUtils.authUrl(state, scopes, redirectUri));
};
return (
+34
View File
@@ -1,3 +1,4 @@
import { IntegrationType } from "@shared/types";
import { buildUser } from "@server/test/factories";
import { getTestServer } from "@server/test/support";
import { parseEmail } from "@shared/utils/email";
@@ -31,6 +32,39 @@ describe("#slack.post", () => {
expect(res.status).toEqual(400);
expect(body.message).toEqual("query: one of code or error is required");
});
it("should reject callback when state nonce does not match cookie", async () => {
const user = await buildUser();
const state = JSON.stringify({
type: IntegrationType.LinkedAccount,
teamId: user.teamId,
nonce: "attacker-nonce",
});
const res = await server.get(
`/auth/slack.post?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.error).toEqual("state_mismatch");
});
it("should reject callback when nonce is missing from state", async () => {
const user = await buildUser();
const state = JSON.stringify({
type: IntegrationType.LinkedAccount,
teamId: user.teamId,
});
const res = await server.get(
`/auth/slack.post?state=${encodeURIComponent(
state
)}&code=123&token=${user.getJwtToken()}`,
{ redirect: "manual" }
);
expect(res.status).toEqual(400);
});
});
describe("Slack authentication domain extraction", () => {
+12 -7
View File
@@ -19,6 +19,7 @@ import {
import { authorize } from "@server/policies";
import { sequelize } from "@server/storage/database";
import type { APIContext, AuthenticationResult } from "@server/types";
import { verifyOAuthStateNonce } from "@server/utils/oauth";
import {
getClientFromOAuthState,
getTeamFromContext,
@@ -29,7 +30,10 @@ import { parseEmail } from "@shared/utils/email";
import env from "../env";
import * as Slack from "../slack";
import * as T from "./schema";
import { SlackUtils } from "plugins/slack/shared/SlackUtils";
import {
SlackUtils,
SlackOAuthNonceCookie,
} from "plugins/slack/shared/SlackUtils";
import { createContext } from "@server/context";
type SlackProfile = Profile & {
@@ -155,19 +159,20 @@ if (env.SLACK_CLIENT_ID && env.SLACK_CLIENT_SECRET) {
return;
}
let parsedState;
try {
parsedState = SlackUtils.parseState<{
collectionId: string;
}>(state);
} catch (_err) {
const parsedState = SlackUtils.parseState(state);
if (!parsedState) {
throw ValidationError("Invalid state");
}
verifyOAuthStateNonce(ctx, SlackOAuthNonceCookie, parsedState.nonce);
const { collectionId, type } = parsedState;
switch (type) {
case IntegrationType.Post: {
if (!collectionId) {
throw ValidationError("collectionId is required");
}
const collection = await Collection.findByPk(collectionId, {
userId: user.id,
});
+15 -4
View File
@@ -2,6 +2,15 @@ import env from "@shared/env";
import type { IntegrationType } from "@shared/types";
import { integrationSettingsPath } from "@shared/utils/routeHelpers";
export const SlackOAuthNonceCookie = "slackOAuthNonce";
export type OAuthState = {
teamId: string;
type: IntegrationType;
nonce: string;
collectionId?: string;
};
export class SlackUtils {
private static authBaseUrl = "https://slack.com/oauth/authorize";
@@ -27,10 +36,12 @@ export class SlackUtils {
* @param state The state string
* @returns The parsed state
*/
static parseState<T>(
state: string
): { teamId: string; type: IntegrationType } & T {
return JSON.parse(state);
static parseState(state: string): OAuthState | undefined {
try {
return JSON.parse(state);
} catch {
return undefined;
}
}
static callbackUrl(