diff --git a/backend/middlewares/__test__/expressGraphqlCache.test.ts b/backend/middlewares/__test__/expressGraphqlCache.test.ts new file mode 100644 index 000000000..fea829ed1 --- /dev/null +++ b/backend/middlewares/__test__/expressGraphqlCache.test.ts @@ -0,0 +1,291 @@ +import type { Request, Response } from "express" +import type { Logger } from "winston" + +import createMiddleware, { + invalidateAllGraphqlCachedQueries, +} from "../expressGraphqlCache" + +const mockRedisClient = { + isReady: true, + get: jest.fn(), + set: jest.fn(), + del: jest.fn(), + scan: jest.fn(), +} + +jest.mock("../../services/redis", () => { + return { + __esModule: true, + default: () => mockRedisClient, + } +}) + +jest.mock("../../server", () => ({ + GRAPHQL_ENDPOINT_PATH: "/graphql", +})) + +const makeLogger = (): Logger => + ({ + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }) as unknown as Logger + +type MockRes = Response & { + _sent?: any + _headers: Record + _status?: number +} + +const makeReq = (overrides: Partial = {}): Request => + ({ + method: "POST", + originalUrl: "/graphql", + baseUrl: "", + path: "/graphql", + headers: {}, + body: { query: "{ me { id } }", variables: {}, operationName: null }, + ...overrides, + }) as unknown as Request + +const makeRes = (): MockRes => { + const r: Partial = { + _headers: {}, + statusCode: 200, + setHeader(key: string, value: string) { + r._headers![key.toLowerCase()] = value + return undefined as any + }, + getHeader(key: string) { + return r._headers![key.toLowerCase()] + }, + status(code: number) { + r.statusCode = code + r._status = code + return r as MockRes + }, + send: jest.fn(function (this: MockRes, body?: any) { + r._sent = body + return this + }), + } + return r as MockRes +} + +const makeNext = () => jest.fn() + +describe("GraphQL response cache middleware", () => { + beforeEach(() => { + jest.clearAllMocks() + mockRedisClient.isReady = true + }) + + test("serves from cache on HIT and does not call next()", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + ;(mockRedisClient.get as jest.Mock).mockResolvedValueOnce( + JSON.stringify({ data: { me: { id: "1" } } }), + ) + + const req = makeReq() + const res = makeRes() + const next = makeNext() + + await mw(req, res, next) + + expect(mockRedisClient.get).toHaveBeenCalledTimes(1) + expect(res._status ?? res.statusCode).toBe(200) + expect(res._headers["content-type"]).toMatch(/application\/json/) + expect(res._headers["x-cache"]).toBe("HIT") + expect((res.send as jest.Mock).mock.calls[0][0]).toEqual( + JSON.stringify({ data: { me: { id: "1" } } }), + ) + expect(next).not.toHaveBeenCalled() + }) + + test("on MISS wraps res.send, stores only successful JSON and sets X-Cache=MISS", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + mockRedisClient.isReady = true + ;(mockRedisClient.get as jest.Mock).mockResolvedValueOnce(null) + ;(mockRedisClient.set as jest.Mock).mockResolvedValue(undefined) + + const req = makeReq() + const res = makeRes() + const next = makeNext() + + await mw(req, res, next) + + expect(next).toHaveBeenCalledTimes(1) + + res.setHeader("Content-Type", "application/json; charset=utf-8") + res.status(200) + const body = { data: { me: { id: "2" } } } + res.send(body) + + expect(mockRedisClient.set).toHaveBeenCalledTimes(1) + const [, payload, opts] = (mockRedisClient.set as jest.Mock).mock.calls[0] + expect(typeof payload).toBe("string") + expect(opts).toEqual(expect.objectContaining({ EX: 60 * 60 })) + expect(res._headers["x-cache"]).toBe("MISS") + }) + + test("does not cache if Authorization header present", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + + const req = makeReq({ headers: { authorization: "Bearer x" } }) + const res = makeRes() + const next = makeNext() + + await mw(req, res, next) + + expect(mockRedisClient.get).not.toHaveBeenCalled() + expect(next).toHaveBeenCalledTimes(1) + }) + + test("does not cache non-query operations (mutation)", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + + const req = makeReq({ + body: { query: "mutation { doThing }" }, + }) + const res = makeRes() + const next = makeNext() + + await mw(req, res, next) + + expect(mockRedisClient.get).not.toHaveBeenCalled() + expect(next).toHaveBeenCalledTimes(1) + }) + + test("does not cache if HTTP status is not 2xx", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + ;(mockRedisClient.get as jest.Mock).mockResolvedValueOnce(null) + + const req = makeReq() + const res = makeRes() + const next = makeNext() + + await mw(req, res, next) + + res.setHeader("Content-Type", "application/json") + res.status(500) + res.send({ errors: [{ message: "boom" }] }) + + expect(mockRedisClient.set).not.toHaveBeenCalled() + expect(res._headers["x-cache"]).toBe("BYPASS") + }) + + test("does not cache if GraphQL response contains errors[] even with 200", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + ;(mockRedisClient.get as jest.Mock).mockResolvedValueOnce(null) + + const req = makeReq() + const res = makeRes() + const next = makeNext() + + await mw(req, res, next) + + res.setHeader("Content-Type", "application/json") + res.status(200) + res.send({ data: null, errors: [{ message: "nope" }] }) + + expect(mockRedisClient.set).not.toHaveBeenCalled() + expect(res._headers["x-cache"]).toBe("BYPASS") + }) + + test("does not cache if response is non-JSON", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + ;(mockRedisClient.get as jest.Mock).mockResolvedValueOnce(null) + + const req = makeReq() + const res = makeRes() + const next = makeNext() + + await mw(req, res, next) + + res.setHeader("Content-Type", "text/plain") + res.status(200) + res.send("ok") + + expect(mockRedisClient.set).not.toHaveBeenCalled() + expect(res._headers["x-cache"]).toBe("BYPASS") + }) + + test("skips when redis is not ready", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + mockRedisClient.isReady = false + + const req = makeReq() + const res = makeRes() + const next = makeNext() + + await mw(req, res, next) + + expect(mockRedisClient.get).not.toHaveBeenCalled() + expect(next).toHaveBeenCalledTimes(1) + }) + + test("skips when path/method do not match GraphQL POST", async () => { + const logger = makeLogger() + const mw = createMiddleware(logger) + + const req1 = makeReq({ method: "GET" as any }) + const res1 = makeRes() + const next1 = makeNext() + + await mw(req1, res1, next1) + expect(next1).toHaveBeenCalledTimes(1) + + const req2 = makeReq({ originalUrl: "/not-graphql", path: "/not-graphql" }) + const res2 = makeRes() + const next2 = makeNext() + + await mw(req2, res2, next2) + expect(next2).toHaveBeenCalledTimes(1) + }) +}) + +describe("invalidateAllGraphqlCachedQueries", () => { + beforeEach(() => { + jest.clearAllMocks() + mockRedisClient.isReady = true + }) + + test("deletes all keys with the middleware prefix and returns count", async () => { + const logger = makeLogger() + const keys = [ + "express-graphql-response-cache:abc", + "express-graphql-response-cache:def", + "express-graphql-response-cache:ghi", + ] + ;(mockRedisClient.scan as jest.Mock).mockResolvedValueOnce({ + cursor: "0", + keys: keys, + }) + ;(mockRedisClient.del as jest.Mock).mockResolvedValue(1) + + const count = await invalidateAllGraphqlCachedQueries(logger as any) + + expect(mockRedisClient.scan).toHaveBeenCalledWith("0", { + MATCH: "express-graphql-response-cache:*", + COUNT: 1000, + }) + expect(mockRedisClient.del).toHaveBeenCalledTimes(keys.length) + expect(count).toBe(keys.length) + }) + + test("throws if redis is not ready", async () => { + mockRedisClient.isReady = false + await expect(invalidateAllGraphqlCachedQueries()).rejects.toThrow( + /not ready/i, + ) + }) +}) diff --git a/backend/middlewares/expressGraphqlCache.ts b/backend/middlewares/expressGraphqlCache.ts new file mode 100644 index 000000000..8e15273cd --- /dev/null +++ b/backend/middlewares/expressGraphqlCache.ts @@ -0,0 +1,302 @@ +import { createHash } from "crypto" + +import { NextFunction, Request, Response } from "express" +import { Logger } from "winston" + +import { GRAPHQL_ENDPOINT_PATH } from "../server" +import redisClient from "../services/redis" + +const CACHE_EXPIRE_TIME_SECONDS = 60 * 60 // 1 hour +const CACHE_PREFIX = "express-graphql-response-cache:" + +function normalizeQuery(query: string): string { + return query.replace(/\s+/g, " ").trim() +} + +function stableStringify(obj: unknown): string { + if (obj === null || typeof obj !== "object") return JSON.stringify(obj) + const seen = new WeakSet() + const sorter = (_key: string, value: any) => { + if (value && typeof value === "object") { + if (seen.has(value)) return + seen.add(value) + if (Array.isArray(value)) return value + return Object.keys(value) + .sort() + .reduce( + (acc, k) => { + acc[k] = value[k] + return acc + }, + {} as Record, + ) + } + return value + } + return JSON.stringify(obj, sorter) +} + +function isGraphQLQuery(body: any): boolean { + const src: string | undefined = body?.query + if (!src) return false + const q = src.trim().toLowerCase() + return q.startsWith("query") || q.startsWith("{") +} + +function buildCacheKey(req: Request): string { + const body = req.body ?? {} + const query = typeof body.query === "string" ? normalizeQuery(body.query) : "" + const variables = body.variables ? stableStringify(body.variables) : "" + const operationName = body.operationName ?? "" + + const payload = stableStringify({ query, variables, operationName }) + const hash = createHash("sha512").update(payload).digest("hex") + return `${CACHE_PREFIX}${hash}` +} + +// Detect GraphQL errors in a response payload +function hasGraphQLErrors(body: any): boolean { + try { + const parsed = + typeof body === "string" + ? JSON.parse(body) + : Buffer.isBuffer(body) + ? JSON.parse(body.toString("utf8")) + : body + + return Boolean( + parsed && Array.isArray(parsed.errors) && parsed.errors.length > 0, + ) + } catch { + // If it isn't valid JSON, treat as unknown -> don't cache + return true + } +} + +const createExpressGraphqlCacheMiddleware = (logger: Logger) => { + const expressCacheMiddleware = async ( + req: Request, + res: Response, + next: NextFunction, + ): Promise => { + // Only handle GraphQL endpoint POSTs + const isGraphqlPath = + req.originalUrl?.startsWith(GRAPHQL_ENDPOINT_PATH) || + `${req.baseUrl || ""}${req.path || ""}` === GRAPHQL_ENDPOINT_PATH + + if (!isGraphqlPath || req.method !== "POST") { + logger.debug("GraphQL cache: skip (not GraphQL POST request)", { + method: req.method, + path: req.originalUrl || req.path, + }) + return next() + } + + const client = redisClient() + if (!client?.isReady) { + logger.warn("GraphQL cache: skip (Redis client not ready)") + return next() + } + + // Skip if authenticated via Authorization header (per your setup) + if (req.headers.authorization !== undefined) { + logger.debug("GraphQL cache: skip (Authorization header present)") + return next() + } + + // Only cache queries + if (!isGraphQLQuery(req.body)) { + logger.debug("GraphQL cache: skip (not a GraphQL query)", { + bodyType: typeof req.body, + hasQuery: !!req.body?.query, + }) + return next() + } + + try { + const key = buildCacheKey(req) + const operationName = req.body?.operationName || "unnamed" + const cached = await client.get(key) + + if (cached) { + logger.info("GraphQL cache: HIT", { + key, + operationName, + cacheKey: key.substring(CACHE_PREFIX.length), + }) + res.status(200) + res.setHeader("Content-Type", "application/json; charset=utf-8") + res.setHeader("X-Cache", "HIT") + res.send(cached) + return + } + + logger.debug("GraphQL cache: MISS", { + key, + operationName, + cacheKey: key.substring(CACHE_PREFIX.length), + }) + + // Cache MISS: wrap send to store only successful, error-free JSON responses + const originalSend = res.send.bind(res) + + res.send = (body?: any): Response => { + try { + const status = res.statusCode + const is2xx = status >= 200 && status < 300 + const contentType = (res.getHeader("Content-Type") ?? "").toString() + const isJson = + contentType.includes("application/json") || typeof body === "object" + const hasErrors = hasGraphQLErrors(body) + + // Only cache if: + // - HTTP 2xx + // - JSON response + // - NO GraphQL errors + if (is2xx && isJson && client?.isReady && !hasErrors) { + const payload = + typeof body === "string" + ? body + : Buffer.isBuffer(body) + ? body.toString("utf8") + : JSON.stringify(body) + + client + .set(key, payload, { EX: CACHE_EXPIRE_TIME_SECONDS }) + .then(() => { + logger.info("GraphQL cache: STORED", { + key, + operationName, + cacheKey: key.substring(CACHE_PREFIX.length), + ttl: CACHE_EXPIRE_TIME_SECONDS, + payloadSize: payload.length, + }) + }) + .catch((e: any) => { + logger.error("GraphQL cache: failed to store", { + key, + operationName, + error: e instanceof Error ? e.message : String(e), + }) + }) + + res.setHeader("X-Cache", "MISS") + } else { + const reason = !is2xx + ? `status=${status}` + : !isJson + ? "not JSON" + : hasErrors + ? "GraphQL errors present" + : "unknown" + logger.debug("GraphQL cache: not stored", { + key, + operationName, + reason, + status, + isJson, + hasErrors, + }) + res.setHeader("X-Cache", "BYPASS") + } + } catch (e) { + logger.error("GraphQL cache: error during response handling", { + key, + error: e instanceof Error ? e.message : String(e), + stack: e instanceof Error ? e.stack : undefined, + }) + } + return originalSend(body) + } + + return next() + } catch (e) { + logger.error("GraphQL cache: middleware error", { + error: e instanceof Error ? e.message : String(e), + stack: e instanceof Error ? e.stack : undefined, + path: req.originalUrl || req.path, + }) + return next() + } + } + + return expressCacheMiddleware +} + +export default createExpressGraphqlCacheMiddleware + +/** + * Invalidates all cached GraphQL query responses created by this middleware. + * Returns the number of keys deleted. + */ +export async function invalidateAllGraphqlCachedQueries( + logger?: Logger, +): Promise { + const client = redisClient() + if (!client?.isReady) { + logger?.error("GraphQL cache: invalidation failed (Redis client not ready)") + throw new Error("Redis client is not ready") + } + + logger?.info("GraphQL cache: starting invalidation of all cached queries") + + let deleted = 0 + let failed = 0 + let scanIterations = 0 + + try { + let cursor = "0" + do { + scanIterations += 1 + const reply = await client.scan(cursor, { + MATCH: `${CACHE_PREFIX}*`, + COUNT: 1000, + }) + cursor = reply.cursor + const keys = reply.keys + logger?.debug("GraphQL cache: scan iteration", { + iteration: scanIterations, + keysFound: keys.length, + cursor, + }) + + for (const key of keys) { + try { + await client.del(key) + deleted += 1 + } catch (e) { + failed += 1 + logger?.warn("GraphQL cache: failed to delete key", { + key, + error: e instanceof Error ? e.message : String(e), + }) + } + } + } while (cursor !== "0") + + if (failed > 0) { + logger?.warn("GraphQL cache: invalidation completed with errors", { + deleted, + failed, + total: deleted + failed, + scanIterations, + }) + } else { + logger?.info("GraphQL cache: invalidation completed successfully", { + deleted, + scanIterations, + }) + } + } catch (e) { + logger?.error("GraphQL cache: invalidation error", { + error: e instanceof Error ? e.message : String(e), + stack: e instanceof Error ? e.stack : undefined, + deleted, + failed, + scanIterations, + }) + throw e + } + + return deleted +} diff --git a/backend/schema/Course/mutations.ts b/backend/schema/Course/mutations.ts index 2d90fb9ce..790c687be 100644 --- a/backend/schema/Course/mutations.ts +++ b/backend/schema/Course/mutations.ts @@ -8,6 +8,7 @@ import { Course, CourseSponsor, Prisma, StudyModule, Tag } from "@prisma/client" import { isAdmin } from "../../accessControl" import { Context } from "../../context" import { GraphQLUserInputError } from "../../lib/errors" +import { invalidateAllGraphqlCachedQueries } from "../../middlewares/expressGraphqlCache" import KafkaProducer, { ProducerMessage } from "../../services/kafkaProducer" import { invalidate } from "../../services/redis" import { emptyOrNullToUndefined, isDefined } from "../../util" @@ -149,6 +150,12 @@ export const CourseMutations = extendType({ await kafkaProducer.queueProducerMessage(producerMessage) await kafkaProducer.disconnect() + await invalidateAllGraphqlCachedQueries(ctx.logger).catch((e) => { + ctx.logger.warn( + `Failed to invalidate GraphQL cache after course creation: ${e}`, + ) + }) + return newCourse }, }) @@ -283,6 +290,12 @@ export const CourseMutations = extendType({ }, }) + await invalidateAllGraphqlCachedQueries(ctx.logger).catch((e) => { + ctx.logger.warn( + `Failed to invalidate GraphQL cache after course update: ${e}`, + ) + }) + return updatedCourse }, }) diff --git a/backend/server.ts b/backend/server.ts index 1b4538263..210b8f706 100644 --- a/backend/server.ts +++ b/backend/server.ts @@ -23,8 +23,11 @@ import { DEBUG, isProduction, isTest } from "./config" import { createDefaultData } from "./config/defaultData" import { ServerContext } from "./context" import { createLoaders } from "./loaders/createLoaders" +import createExpressGraphqlCacheMiddleware from "./middlewares/expressGraphqlCache" import { createSchema } from "./schema/common" +export const GRAPHQL_ENDPOINT_PATH = isProduction ? "/api" : "/" + // wrapped so that the context isn't cached between test instances const createExpressAppWithContext = ({ prisma, @@ -55,8 +58,10 @@ const addExpressMiddleware = async ( ) => { const { prisma, logger, knex, extraContext } = serverContext await createDefaultData(prisma) + // cache middleware first so that it's the first to run + app.use(createExpressGraphqlCacheMiddleware(logger)) app.use( - isProduction ? "/api" : "/", + GRAPHQL_ENDPOINT_PATH, expressMiddleware(apolloServer, { context: async (ctx) => { const loaders = createLoaders(prisma) diff --git a/backend/tests/index.ts b/backend/tests/index.ts index 1c7c8e0a8..52e4950a0 100644 --- a/backend/tests/index.ts +++ b/backend/tests/index.ts @@ -42,6 +42,7 @@ export const logger = { }, createLogger: jest.fn().mockImplementation(function () { return { + debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn(),