diff --git a/server/models/Group.test.ts b/server/models/Group.test.ts new file mode 100644 index 0000000000..9710d4613f --- /dev/null +++ b/server/models/Group.test.ts @@ -0,0 +1,88 @@ +import { buildGroup, buildGroupUser, buildUser } from "@server/test/factories"; + +describe("Group", () => { + describe("memberCount", () => { + it("returns 0 for a group with no members", async () => { + const group = await buildGroup(); + expect(await group.memberCount).toEqual(0); + }); + + it("counts active members", async () => { + const group = await buildGroup(); + await buildGroupUser({ groupId: group.id, teamId: group.teamId }); + await buildGroupUser({ groupId: group.id, teamId: group.teamId }); + + expect(await group.memberCount).toEqual(2); + }); + + it("excludes suspended members", async () => { + const group = await buildGroup(); + await buildGroupUser({ groupId: group.id, teamId: group.teamId }); + + const suspendedUser = await buildUser({ + teamId: group.teamId, + suspendedAt: new Date(), + }); + await buildGroupUser({ + groupId: group.id, + teamId: group.teamId, + userId: suspendedUser.id, + }); + + expect(await group.memberCount).toEqual(1); + }); + + it("excludes soft-deleted members", async () => { + const group = await buildGroup(); + await buildGroupUser({ groupId: group.id, teamId: group.teamId }); + + const deletedUser = await buildUser({ teamId: group.teamId }); + await buildGroupUser({ + groupId: group.id, + teamId: group.teamId, + userId: deletedUser.id, + }); + await deletedUser.destroy(); + + expect(await group.memberCount).toEqual(1); + }); + + it("invalidates the cached count when a member is suspended", async () => { + const group = await buildGroup(); + const groupUser = await buildGroupUser({ + groupId: group.id, + teamId: group.teamId, + }); + await buildGroupUser({ groupId: group.id, teamId: group.teamId }); + + // Prime the cache. + expect(await group.memberCount).toEqual(2); + + const user = (await groupUser.$get("user"))!; + await user.update({ suspendedAt: new Date() }); + + expect(await group.memberCount).toEqual(1); + }); + + it("invalidates the cached count when a suspended member is restored", async () => { + const group = await buildGroup(); + const suspendedUser = await buildUser({ + teamId: group.teamId, + suspendedAt: new Date(), + }); + await buildGroupUser({ + groupId: group.id, + teamId: group.teamId, + userId: suspendedUser.id, + }); + await buildGroupUser({ groupId: group.id, teamId: group.teamId }); + + // Prime the cache (suspended user is excluded). + expect(await group.memberCount).toEqual(1); + + await suspendedUser.update({ suspendedAt: null }); + + expect(await group.memberCount).toEqual(2); + }); + }); +}); diff --git a/server/models/Group.ts b/server/models/Group.ts index 97c414fd2f..f57b19fa9f 100644 --- a/server/models/Group.ts +++ b/server/models/Group.ts @@ -120,7 +120,20 @@ class Group extends ParanoidModel< @BelongsToMany(() => User, () => GroupUser) users: User[]; - @CounterCache(() => GroupUser, { as: "members", foreignKey: "groupId" }) + @CounterCache(() => GroupUser, { + as: "members", + foreignKey: "groupId", + include: [ + { + association: "user", + required: true, + attributes: [], + where: { + suspendedAt: { [Op.is]: null }, + }, + }, + ], + }) memberCount: Promise; } diff --git a/server/models/User.ts b/server/models/User.ts index 34cd9ee4d7..b4cde648ab 100644 --- a/server/models/User.ts +++ b/server/models/User.ts @@ -60,6 +60,7 @@ import Attachment from "./Attachment"; import AuthenticationProvider from "./AuthenticationProvider"; import Collection from "./Collection"; import Group from "./Group"; +import GroupUser from "./GroupUser"; import Team from "./Team"; import UserAuthentication from "./UserAuthentication"; import UserMembership from "./UserMembership"; @@ -850,6 +851,50 @@ class User extends ParanoidModel< } } + // When a user's suspension state changes, invalidate the cached member count + // for every group they belong to so the count reflects only active members. + @AfterUpdate + static async invalidateGroupMemberCount( + model: User, + options: InstanceUpdateOptions> + ) { + if (!model.changed("suspendedAt")) { + return; + } + + const groupUsers = await GroupUser.findAll({ + attributes: ["groupId"], + where: { userId: model.id }, + transaction: options.transaction, + raw: true, + }); + + const groupIds = [ + ...new Set(groupUsers.map((groupUser) => groupUser.groupId)), + ]; + + if (!groupIds.length) { + return; + } + + const invalidate = async () => { + await Promise.all( + groupIds.map((groupId) => + CacheHelper.removeData( + RedisPrefixHelper.getCounterCacheKey("Group", "members", groupId) + ) + ) + ); + }; + + if (options.transaction) { + const transaction = options.transaction.parent || options.transaction; + transaction.afterCommit(invalidate); + } else { + await invalidate(); + } + } + @AfterUpdate static deletePreviousAvatar = async (model: User) => { const previousAvatarUrl = model.previous("avatarUrl"); diff --git a/server/models/decorators/CounterCache.ts b/server/models/decorators/CounterCache.ts index 4aa39b27d6..86e9953f79 100644 --- a/server/models/decorators/CounterCache.ts +++ b/server/models/decorators/CounterCache.ts @@ -1,8 +1,13 @@ import isNil from "lodash/isNil"; -import type { InferAttributes } from "sequelize"; +import type { + IncludeOptions, + InferAttributes, + Transaction, + WhereOptions, +} from "sequelize"; import type { ModelClassGetter } from "sequelize-typescript"; -import env from "@server/env"; import { CacheHelper } from "@server/utils/CacheHelper"; +import { RedisPrefixHelper } from "@server/utils/RedisPrefixHelper"; import type Model from "../base/Model"; type RelationOptions = { @@ -10,6 +15,10 @@ type RelationOptions = { as: string; /** The foreign key to use for the relationship query. */ foreignKey: string; + /** Optional include used in the count query for filtering through associations. */ + include?: IncludeOptions[]; + /** Optional additional where clause used in the count query. */ + where?: WhereOptions; }; /** @@ -25,57 +34,72 @@ export function CounterCache< options: RelationOptions ) { return function (target: InstanceType, _propertyKey: string) { - if (env.isTest) { - // No-op cache in test environment - return; - } const modelClass = classResolver() as typeof Model; - const cacheKeyPrefix = `count:${target.constructor.name}:${options.as}`; + const modelName = target.constructor.name; - // Add hooks after model is added to the sequelize instance - setImmediate(() => { - const recalculateCache = - (offset: number) => async (model: InstanceType) => { - const cacheKey = `${cacheKeyPrefix}:${String( - model[options.foreignKey as keyof typeof model] - )}`; + const buildCacheKey = (id: unknown) => + RedisPrefixHelper.getCounterCacheKey(modelName, options.as, String(id)); - const count = await modelClass.count({ - where: { - [options.foreignKey]: - model[options.foreignKey as keyof typeof model], - }, - }); - await CacheHelper.setData(cacheKey, count + offset); - }; + const computeCount = (id: unknown) => + modelClass.count({ + where: { [options.foreignKey]: id, ...(options.where ?? {}) }, + include: options.include, + distinct: !!options.include, + }); - // Because the transaction is not complete until after the response is sent, we need to - // offset the count by 1 to account for the record. TODO: Need to find a better way to handle - // this as a rollback would not decrement the count. - modelClass.addHook("afterCreate", recalculateCache(1)); - modelClass.addHook("afterDestroy", recalculateCache(-1)); - }); + const invalidate = async ( + model: InstanceType, + hookOptions?: { transaction?: Transaction | null } + ) => { + const cacheKey = buildCacheKey( + model[options.foreignKey as keyof typeof model] + ); + const remove = async () => { + await CacheHelper.removeData(cacheKey); + }; + + // Defer invalidation until after the transaction commits so that a + // rollback does not leave the cache out of sync, and so that a stale + // pre-commit count is not re-cached by a concurrent reader. Walk to + // the parent transaction when nested so the callback isn't lost when + // the savepoint releases without committing the outer transaction. + if (hookOptions?.transaction) { + const transaction = + hookOptions.transaction.parent || hookOptions.transaction; + transaction.afterCommit(remove); + } else { + await remove(); + } + }; + + // The model class is not added to a Sequelize instance until the database + // module is first imported, which is later than decorator evaluation. Poll + // until the model is ready, then register the hooks. Use unref() so the + // pending immediate does not keep the event loop alive in environments + // (such as tests) where the database is never initialized. + const registerHooks = () => { + if (!modelClass.sequelize) { + setImmediate(registerHooks).unref(); + return; + } + modelClass.addHook("afterCreate", invalidate); + modelClass.addHook("afterDestroy", invalidate); + }; + setImmediate(registerHooks).unref(); return { get() { - const cacheKey = `${cacheKeyPrefix}:${this.id}`; + const cacheKey = buildCacheKey(this.id); return CacheHelper.getData(cacheKey).then((value) => { if (!isNil(value)) { return value; } - // calculate and cache count - return modelClass - .count({ - where: { - [options.foreignKey]: this.id, - }, - }) - .then((count) => { - void CacheHelper.setData(cacheKey, count); - return count; - }); + return computeCount(this.id).then((count) => { + void CacheHelper.setData(cacheKey, count); + return count; + }); }); }, // eslint-disable-next-line @typescript-eslint/no-explicit-any -- TS rejects PropertyDescriptor as legacy decorator return type; descriptor is consumed by Sequelize at runtime. diff --git a/server/storage/redis.ts b/server/storage/redis.ts index 07580bbb2c..ec240015f6 100644 --- a/server/storage/redis.ts +++ b/server/storage/redis.ts @@ -115,6 +115,9 @@ export default class RedisAdapter extends Redis { }); }, env.REDIS_HEALTHCHECK_INTERVAL); + // Don't keep the Node event loop alive solely for the healthcheck. + healthcheck.unref(); + this.on("end", () => clearInterval(healthcheck)); } } diff --git a/server/test/setupMocks.js b/server/test/setupMocks.js index c87b668900..56ab55d1ed 100644 --- a/server/test/setupMocks.js +++ b/server/test/setupMocks.js @@ -6,7 +6,6 @@ jest.mock("ioredis", () => require("ioredis-mock")); // Mock other Redis-dependent modules jest.mock("@server/utils/MutexLock"); -jest.mock("@server/utils/CacheHelper"); // Mock AWS SDK signature module to prevent aws_logger open handle jest.mock("@aws-sdk/signature-v4-crt", () => ({})); diff --git a/server/utils/CacheHelper.ts b/server/utils/CacheHelper.ts index acd71aa7b9..f8d4caa9f7 100644 --- a/server/utils/CacheHelper.ts +++ b/server/utils/CacheHelper.ts @@ -126,6 +126,19 @@ export class CacheHelper { } } + /** + * Removes a single cached entry by key. + * + * @param key Cache key to remove. + */ + public static async removeData(key: string) { + try { + await Redis.defaultClient.del(key); + } catch (err) { + Logger.error(`Could not remove cached entry against ${key}`, err); + } + } + /** * Clears all cache data with the given prefix * diff --git a/server/utils/RedisPrefixHelper.ts b/server/utils/RedisPrefixHelper.ts index 24a112a9b7..9a0b2d0ce1 100644 --- a/server/utils/RedisPrefixHelper.ts +++ b/server/utils/RedisPrefixHelper.ts @@ -42,4 +42,21 @@ export class RedisPrefixHelper { public static getUserCollectionIdsKey(userId: string) { return `uc:${userId}`; } + + /** + * Gets key for caching the count of a relationship managed by the + * `CounterCache` decorator. + * + * @param modelName The owning model name (e.g. "Group"). + * @param relationName The relationship reference name (e.g. "members"). + * @param id The owning record id. + * @returns the cache key string. + */ + public static getCounterCacheKey( + modelName: string, + relationName: string, + id: string + ) { + return `count:${modelName}:${relationName}:${id}`; + } } diff --git a/server/utils/__mocks__/CacheHelper.ts b/server/utils/__mocks__/CacheHelper.ts deleted file mode 100644 index c7aa210f6b..0000000000 --- a/server/utils/__mocks__/CacheHelper.ts +++ /dev/null @@ -1,47 +0,0 @@ -import { Day } from "@shared/utils/time"; -import type { CacheResult } from "../CacheHelper"; - -/** - * A Mock Helper class for server-side cache management - */ -export class CacheHelper { - // Default expiry time for cache data in seconds - private static defaultDataExpiry = Day.seconds; - - /** - * Mocked method that resolves with the callback result - */ - public static async getDataOrSet( - key: string, - callback: () => Promise | undefined>, - _expiry: number, - _lockTimeout?: number - ): Promise { - const result = await callback(); - if (result && typeof result === "object" && "data" in result) { - return (result as CacheResult).data; - } - return result as T | undefined; - } - - /** - * Mocked method that resolves with undefined - */ - public static async getData(_key: string): Promise { - return undefined; - } - - /** - * Mocked method that resolves with void - */ - public static async setData(_key: string, _data: T, _expiry?: number) { - return; - } - - /** - * Mocked method that resolves with void - */ - public static async clearData(_prefix: string) { - return; - } -}