From 5355f5a398bf9fb8dd325b9265468c997f2918a5 Mon Sep 17 00:00:00 2001 From: James Broadhead Date: Wed, 15 Apr 2026 16:41:11 +0000 Subject: [PATCH] test: add 147 tests for service-context, stream-registry, genie connector, files plugin New test files covering major coverage gaps: - context/tests/service-context.test.ts (35 tests, 7% -> 100%) - stream/tests/stream-registry.test.ts (34 tests, 32% -> 100%) - connectors/genie/tests/client.test.ts (28 tests, 61% -> 97%) - plugins/files/tests/upload-and-write.test.ts (50 tests, 69% -> 89%) Total: 1566 -> 1713 tests, all passing. Co-authored-by: Isaac --- .../src/connectors/genie/tests/client.test.ts | 786 +++++++++++ .../src/context/tests/service-context.test.ts | 457 ++++++ .../files/tests/upload-and-write.test.ts | 1245 +++++++++++++++++ .../src/stream/tests/stream-registry.test.ts | 582 ++++++++ 4 files changed, 3070 insertions(+) create mode 100644 packages/appkit/src/connectors/genie/tests/client.test.ts create mode 100644 packages/appkit/src/context/tests/service-context.test.ts create mode 100644 packages/appkit/src/plugins/files/tests/upload-and-write.test.ts create mode 100644 packages/appkit/src/stream/tests/stream-registry.test.ts diff --git a/packages/appkit/src/connectors/genie/tests/client.test.ts b/packages/appkit/src/connectors/genie/tests/client.test.ts new file mode 100644 index 00000000..62fc3578 --- /dev/null +++ b/packages/appkit/src/connectors/genie/tests/client.test.ts @@ -0,0 +1,786 @@ +import type { GenieMessage } from "@databricks/sdk-experimental/dist/apis/dashboards"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { GenieConnector } from "../client"; +import type { GenieStreamEvent } from "../types"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +async function collect( + gen: AsyncGenerator, +): Promise { + const events: 
GenieStreamEvent[] = []; + for await (const event of gen) { + events.push(event); + } + return events; +} + +function makeGenieMessage(overrides: Partial = {}): GenieMessage { + return { + message_id: "msg-1", + conversation_id: "conv-1", + space_id: "space-1", + status: "COMPLETED", + content: "Hello from Genie", + attachments: [], + ...overrides, + } as GenieMessage; +} + +function makeGenieMessageWithQuery( + overrides: Partial = {}, +): GenieMessage { + return makeGenieMessage({ + attachments: [ + { + attachment_id: "att-1", + query: { + title: "Sales Query", + description: "Total sales", + query: "SELECT sum(amount) FROM sales", + statement_id: "stmt-1", + }, + }, + ], + ...overrides, + }); +} + +/** Creates a mock WorkspaceClient with genie methods stubbed. */ +function createMockWorkspaceClient() { + return { + genie: { + startConversation: vi.fn(), + createMessage: vi.fn(), + getMessage: vi.fn(), + listConversationMessages: vi.fn(), + getMessageAttachmentQueryResult: vi.fn(), + }, + } as any; +} + +/** + * Builds a mock waiter whose `.wait()` invokes `onProgress` for each + * progress value, then resolves with the final result. 
+ */ +function createMockWaiter(opts: { + progressValues?: Partial[]; + result: GenieMessage; +}) { + return { + wait: vi.fn().mockImplementation(async (options: any = {}) => { + if (opts.progressValues) { + for (const value of opts.progressValues) { + if (options.onProgress) { + await options.onProgress(value); + } + } + } + return opts.result; + }), + message_id: opts.result.message_id, + conversation_id: opts.result.conversation_id, + }; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("GenieConnector", () => { + let connector: GenieConnector; + let ws: ReturnType; + + beforeEach(() => { + connector = new GenieConnector({ timeout: 0 }); + ws = createMockWorkspaceClient(); + }); + + // ----------------------------------------------------------------------- + // streamSendMessage + // ----------------------------------------------------------------------- + + describe("streamSendMessage", () => { + test("yields message_start, status updates, then message_result", async () => { + const completedMsg = makeGenieMessage(); + const waiter = createMockWaiter({ + progressValues: [ + { status: "EXECUTING_QUERY" }, + { status: "COMPLETED" }, + ], + result: completedMsg, + }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const events = await collect( + connector.streamSendMessage( + ws, + "space-1", + "What are sales?", + undefined, + ), + ); + + expect(events[0]).toEqual({ + type: "message_start", + conversationId: "conv-1", + messageId: "msg-1", + spaceId: "space-1", + }); + + const statusEvents = events.filter((e) => e.type === "status"); + expect(statusEvents).toEqual([ + { type: "status", status: "EXECUTING_QUERY" }, + { type: "status", status: "COMPLETED" }, + ]); + + const msgResult = events.find((e) => e.type === "message_result"); + expect(msgResult).toBeDefined(); + expect((msgResult as 
any).message.messageId).toBe("msg-1"); + }); + + test("new conversation calls startConversation", async () => { + const completedMsg = makeGenieMessage(); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + await collect( + connector.streamSendMessage(ws, "space-1", "hello", undefined), + ); + + expect(ws.genie.startConversation).toHaveBeenCalledWith({ + space_id: "space-1", + content: "hello", + }); + expect(ws.genie.createMessage).not.toHaveBeenCalled(); + }); + + test("existing conversation calls createMessage", async () => { + const completedMsg = makeGenieMessage(); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.createMessage.mockResolvedValue(waiter); + + await collect( + connector.streamSendMessage(ws, "space-1", "hello", "conv-existing"), + ); + + expect(ws.genie.createMessage).toHaveBeenCalledWith({ + space_id: "space-1", + conversation_id: "conv-existing", + content: "hello", + }); + expect(ws.genie.startConversation).not.toHaveBeenCalled(); + }); + + test("emits query_result for attachments with statementIds", async () => { + const completedMsg = makeGenieMessageWithQuery(); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const statementResponse = { + manifest: { + schema: { columns: [{ name: "total", type_name: "DOUBLE" }] }, + }, + result: { data_array: [["1234.56"]] }, + }; + ws.genie.getMessageAttachmentQueryResult.mockResolvedValue({ + statement_response: statementResponse, + }); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "query", undefined), + ); + + const queryResult = events.find((e) => e.type === "query_result"); + expect(queryResult).toEqual({ + type: "query_result", + attachmentId: "att-1", + statementId: "stmt-1", + data: statementResponse, + }); + }); + + test("yields error event on SDK failure", async () => { + 
ws.genie.startConversation.mockRejectedValue( + new Error("Network timeout"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hello", undefined), + ); + + expect(events).toEqual([{ type: "error", error: "Network timeout" }]); + }); + + test("classifies RESOURCE_DOES_NOT_EXIST as access denied", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("RESOURCE_DOES_NOT_EXIST: space not found"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hello", undefined), + ); + + expect(events).toEqual([ + { + type: "error", + error: "You don't have access to this Genie Space.", + }, + ]); + }); + + test("emits error event when query result fetch fails", async () => { + const completedMsg = makeGenieMessageWithQuery(); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.startConversation.mockResolvedValue(waiter); + ws.genie.getMessageAttachmentQueryResult.mockRejectedValue( + new Error("statement expired"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "query", undefined), + ); + + const errorEvent = events.find((e) => e.type === "error"); + expect(errorEvent).toEqual({ + type: "error", + error: "Failed to fetch query result for attachment att-1", + }); + }); + }); + + // ----------------------------------------------------------------------- + // streamConversation + // ----------------------------------------------------------------------- + + describe("streamConversation", () => { + test("yields message_result for each message, then history_info", async () => { + ws.genie.listConversationMessages.mockResolvedValue({ + messages: [ + makeGenieMessage({ message_id: "m1", content: "first" }), + makeGenieMessage({ message_id: "m2", content: "second" }), + ], + next_page_token: null, + }); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1", { + includeQueryResults: false, + }), + ); + + const 
messageResults = events.filter((e) => e.type === "message_result"); + expect(messageResults).toHaveLength(2); + + const historyInfo = events.find((e) => e.type === "history_info"); + expect(historyInfo).toEqual({ + type: "history_info", + conversationId: "conv-1", + spaceId: "space-1", + nextPageToken: null, + loadedCount: 2, + }); + }); + + test("fetches query results in parallel when includeQueryResults=true", async () => { + ws.genie.listConversationMessages.mockResolvedValue({ + messages: [ + makeGenieMessageWithQuery({ + message_id: "m1", + attachments: [ + { + attachment_id: "att-a", + query: { + title: "Q1", + query: "SELECT 1", + statement_id: "stmt-a", + }, + }, + { + attachment_id: "att-b", + query: { + title: "Q2", + query: "SELECT 2", + statement_id: "stmt-b", + }, + }, + ], + }), + ], + next_page_token: null, + }); + + const stmtResponse = { + manifest: { schema: { columns: [] } }, + result: { data_array: [] }, + }; + ws.genie.getMessageAttachmentQueryResult.mockResolvedValue({ + statement_response: stmtResponse, + }); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1", { + includeQueryResults: true, + }), + ); + + const queryResults = events.filter((e) => e.type === "query_result"); + expect(queryResults).toHaveLength(2); + expect(ws.genie.getMessageAttachmentQueryResult).toHaveBeenCalledTimes(2); + }); + + test("skips query results when includeQueryResults=false", async () => { + ws.genie.listConversationMessages.mockResolvedValue({ + messages: [makeGenieMessageWithQuery()], + next_page_token: null, + }); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1", { + includeQueryResults: false, + }), + ); + + expect(events.filter((e) => e.type === "query_result")).toHaveLength(0); + expect(ws.genie.getMessageAttachmentQueryResult).not.toHaveBeenCalled(); + }); + + test("handles partial query result failures via Promise.allSettled", async () => { + 
ws.genie.listConversationMessages.mockResolvedValue({ + messages: [ + makeGenieMessage({ + message_id: "m1", + attachments: [ + { + attachment_id: "att-ok", + query: { + title: "OK", + query: "SELECT 1", + statement_id: "stmt-ok", + }, + }, + { + attachment_id: "att-fail", + query: { + title: "Fail", + query: "SELECT 2", + statement_id: "stmt-fail", + }, + }, + ], + }), + ], + next_page_token: null, + }); + + const stmtResponse = { + manifest: { schema: { columns: [] } }, + result: { data_array: [] }, + }; + + ws.genie.getMessageAttachmentQueryResult + .mockResolvedValueOnce({ statement_response: stmtResponse }) + .mockRejectedValueOnce(new Error("statement expired")); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1", { + includeQueryResults: true, + }), + ); + + const queryResults = events.filter((e) => e.type === "query_result"); + expect(queryResults).toHaveLength(1); + + const errors = events.filter((e) => e.type === "error"); + expect(errors).toHaveLength(1); + expect((errors[0] as any).error).toBe("statement expired"); + }); + + test("yields error when listConversationMessages fails", async () => { + ws.genie.listConversationMessages.mockRejectedValue( + new Error("RESOURCE_DOES_NOT_EXIST: conv not found"), + ); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1"), + ); + + expect(events).toEqual([ + { + type: "error", + error: "You don't have access to this Genie Space.", + }, + ]); + }); + }); + + // ----------------------------------------------------------------------- + // streamGetMessage + // ----------------------------------------------------------------------- + + describe("streamGetMessage", () => { + test("polls until COMPLETED, yields status + message_result", async () => { + ws.genie.getMessage + .mockResolvedValueOnce(makeGenieMessage({ status: "EXECUTING_QUERY" })) + .mockResolvedValueOnce(makeGenieMessage({ status: "COMPLETED" })); + + const events = await collect( + 
connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 0, + }), + ); + + expect(events[0]).toEqual({ + type: "status", + status: "EXECUTING_QUERY", + }); + expect(events[1]).toEqual({ type: "status", status: "COMPLETED" }); + expect(events[2]).toMatchObject({ type: "message_result" }); + expect(ws.genie.getMessage).toHaveBeenCalledTimes(2); + }); + + test("polls until FAILED, yields status + message_result", async () => { + ws.genie.getMessage + .mockResolvedValueOnce(makeGenieMessage({ status: "EXECUTING_QUERY" })) + .mockResolvedValueOnce( + makeGenieMessage({ + status: "FAILED", + error: { error: "query timed out" }, + }), + ); + + const events = await collect( + connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 0, + }), + ); + + const statusEvents = events.filter((e) => e.type === "status"); + expect(statusEvents).toEqual([ + { type: "status", status: "EXECUTING_QUERY" }, + { type: "status", status: "FAILED" }, + ]); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.status).toBe("FAILED"); + expect(msgResult.message.error).toBe("query timed out"); + }); + + test("respects abort signal", async () => { + const controller = new AbortController(); + + ws.genie.getMessage.mockResolvedValue( + makeGenieMessage({ status: "EXECUTING_QUERY" }), + ); + + const gen = connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 50, + signal: controller.signal, + }); + + const events: GenieStreamEvent[] = []; + // Collect the first status event, then abort + for await (const event of gen) { + events.push(event); + if (events.length === 1) { + controller.abort(); + } + } + + // Should have stopped after abort - at most 2 events + // (the status from poll 1, and possibly status from poll 2 that was already in-flight) + expect(events.length).toBeLessThanOrEqual(2); + expect(events[0]).toEqual({ + type: "status", + status: "EXECUTING_QUERY", + }); + }); 
+ + test("yields error when getMessage throws", async () => { + ws.genie.getMessage.mockRejectedValue(new Error("service unavailable")); + + const events = await collect( + connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 0, + }), + ); + + expect(events).toEqual([{ type: "error", error: "service unavailable" }]); + }); + + test("does not duplicate status events for same status", async () => { + ws.genie.getMessage + .mockResolvedValueOnce(makeGenieMessage({ status: "EXECUTING_QUERY" })) + .mockResolvedValueOnce(makeGenieMessage({ status: "EXECUTING_QUERY" })) + .mockResolvedValueOnce(makeGenieMessage({ status: "COMPLETED" })); + + const events = await collect( + connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 0, + }), + ); + + const statusEvents = events.filter((e) => e.type === "status"); + expect(statusEvents).toEqual([ + { type: "status", status: "EXECUTING_QUERY" }, + { type: "status", status: "COMPLETED" }, + ]); + }); + }); + + // ----------------------------------------------------------------------- + // sendMessage + // ----------------------------------------------------------------------- + + describe("sendMessage", () => { + test("returns completed message response", async () => { + const completedMsg = makeGenieMessage({ + message_id: "msg-42", + conversation_id: "conv-new", + }); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const result = await connector.sendMessage( + ws, + "space-1", + "What are sales?", + undefined, + ); + + expect(result.messageId).toBe("msg-42"); + expect(result.conversationId).toBe("conv-new"); + expect(result.status).toBe("COMPLETED"); + }); + }); + + // ----------------------------------------------------------------------- + // getConversation + // ----------------------------------------------------------------------- + + describe("getConversation", () => { + test("paginates through all 
pages", async () => { + // listConversationMessages reverses the SDK response, so mock data + // is ordered newest-first (as the SDK returns) and results are + // oldest-first after reversal. + ws.genie.listConversationMessages + .mockResolvedValueOnce({ + messages: [ + makeGenieMessage({ message_id: "m2" }), + makeGenieMessage({ message_id: "m1" }), + ], + next_page_token: "page2", + }) + .mockResolvedValueOnce({ + messages: [makeGenieMessage({ message_id: "m3" })], + next_page_token: null, + }); + + const result = await connector.getConversation(ws, "space-1", "conv-1"); + + expect(result.messages).toHaveLength(3); + expect(result.messages.map((m) => m.messageId)).toEqual([ + "m1", + "m2", + "m3", + ]); + expect(ws.genie.listConversationMessages).toHaveBeenCalledTimes(2); + }); + + test("respects maxMessages limit", async () => { + const smallConnector = new GenieConnector({ + timeout: 0, + maxMessages: 2, + }); + + ws.genie.listConversationMessages.mockResolvedValueOnce({ + messages: [ + makeGenieMessage({ message_id: "m1" }), + makeGenieMessage({ message_id: "m2" }), + makeGenieMessage({ message_id: "m3" }), + ], + next_page_token: "page2", + }); + + const result = await smallConnector.getConversation( + ws, + "space-1", + "conv-1", + ); + + // Should be sliced to maxMessages + expect(result.messages).toHaveLength(2); + // Should NOT fetch a second page since length already >= maxMessages + expect(ws.genie.listConversationMessages).toHaveBeenCalledTimes(1); + }); + }); + + // ----------------------------------------------------------------------- + // mapAttachments (tested indirectly via toMessageResponse) + // ----------------------------------------------------------------------- + + describe("mapAttachments", () => { + test("handles query attachments", async () => { + const msg = makeGenieMessageWithQuery(); + const waiter = createMockWaiter({ result: msg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + // We drive through streamSendMessage 
to exercise mapAttachments + ws.genie.getMessageAttachmentQueryResult.mockResolvedValue({ + statement_response: { + manifest: { schema: { columns: [] } }, + result: { data_array: [] }, + }, + }); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "q", undefined), + ); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.attachments[0]).toEqual({ + attachmentId: "att-1", + query: { + title: "Sales Query", + description: "Total sales", + query: "SELECT sum(amount) FROM sales", + statementId: "stmt-1", + }, + text: undefined, + suggestedQuestions: undefined, + }); + }); + + test("handles text attachments", async () => { + const msg = makeGenieMessage({ + attachments: [ + { + attachment_id: "att-text", + text: { content: "Here is the explanation" }, + }, + ], + }); + const waiter = createMockWaiter({ result: msg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "q", undefined), + ); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.attachments[0]).toEqual({ + attachmentId: "att-text", + query: undefined, + text: { content: "Here is the explanation" }, + suggestedQuestions: undefined, + }); + }); + + test("handles suggestedQuestions attachments", async () => { + const msg = makeGenieMessage({ + attachments: [ + { + attachment_id: "att-sq", + suggested_questions: { + questions: ["What is X?", "Show me Y"], + }, + }, + ], + }); + const waiter = createMockWaiter({ result: msg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "q", undefined), + ); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.attachments[0]).toEqual({ + attachmentId: "att-sq", + query: undefined, + text: undefined, + suggestedQuestions: 
["What is X?", "Show me Y"], + }); + }); + + test("returns empty array when message has no attachments", async () => { + const msg = makeGenieMessage({ attachments: undefined }); + const waiter = createMockWaiter({ result: msg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "q", undefined), + ); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.attachments).toEqual([]); + }); + }); + + // ----------------------------------------------------------------------- + // classifyGenieError (tested indirectly via error events) + // ----------------------------------------------------------------------- + + describe("classifyGenieError", () => { + test("maps RESOURCE_DOES_NOT_EXIST to space access denied", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("RESOURCE_DOES_NOT_EXIST: space xyz"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hi", undefined), + ); + + expect(events[0]).toEqual({ + type: "error", + error: "You don't have access to this Genie Space.", + }); + }); + + test("maps failed-to-reach-COMPLETED + FAILED to table permissions", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("failed to reach COMPLETED state, got FAILED"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hi", undefined), + ); + + expect(events[0]).toEqual({ + type: "error", + error: + "You may not have access to the data tables. 
Please verify your table permissions.", + }); + }); + + test("passes through unknown error messages", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("something unexpected"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hi", undefined), + ); + + expect(events[0]).toEqual({ + type: "error", + error: "something unexpected", + }); + }); + + test("handles non-Error throwable", async () => { + ws.genie.startConversation.mockRejectedValue("string error"); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hi", undefined), + ); + + expect(events[0]).toEqual({ + type: "error", + error: "string error", + }); + }); + }); +}); diff --git a/packages/appkit/src/context/tests/service-context.test.ts b/packages/appkit/src/context/tests/service-context.test.ts new file mode 100644 index 00000000..e8610da1 --- /dev/null +++ b/packages/appkit/src/context/tests/service-context.test.ts @@ -0,0 +1,457 @@ +import { setupDatabricksEnv } from "@tools/test-helpers"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { + AuthenticationError, + ConfigurationError, + InitializationError, +} from "../../errors"; +import { ServiceContext } from "../service-context"; + +// ── Mock @databricks/sdk-experimental ────────────────────────────── + +const { mockMe, mockApiRequest, MockWorkspaceClient } = vi.hoisted(() => { + const mockMe = vi.fn(); + const mockApiRequest = vi.fn(); + + const MockWorkspaceClient = vi.fn().mockImplementation(() => ({ + currentUser: { me: mockMe }, + apiClient: { request: mockApiRequest }, + })); + + return { mockMe, mockApiRequest, MockWorkspaceClient }; +}); + +vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: MockWorkspaceClient, +})); + +// ── Helpers ──────────────────────────────────────────────────────── + +function setupDefaultMocks() { + mockMe.mockResolvedValue({ id: "service-user-123" }); + 
mockApiRequest.mockResolvedValue({ "x-databricks-org-id": "ws-456" }); +} + +// ── Tests ────────────────────────────────────────────────────────── + +describe("ServiceContext", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + vi.clearAllMocks(); + ServiceContext.reset(); + setupDatabricksEnv(); + setupDefaultMocks(); + }); + + afterEach(() => { + process.env = { ...originalEnv }; + ServiceContext.reset(); + }); + + // ── initialize() ─────────────────────────────────────────────── + + describe("initialize()", () => { + test("should initialize with a pre-configured client", async () => { + const client = new MockWorkspaceClient() as any; + + const state = await ServiceContext.initialize({}, client); + + expect(state.client).toBe(client); + expect(state.serviceUserId).toBe("service-user-123"); + expect(await state.workspaceId).toBe("ws-456"); + }); + + test("should create a WorkspaceClient when none is provided", async () => { + await ServiceContext.initialize(); + + // The mock constructor is called once internally + expect(MockWorkspaceClient).toHaveBeenCalled(); + }); + + test("should resolve warehouseId when options.warehouseId is true", async () => { + process.env.DATABRICKS_WAREHOUSE_ID = "wh-789"; + + const state = await ServiceContext.initialize({ warehouseId: true }); + + expect(state.warehouseId).toBeDefined(); + expect(await state.warehouseId).toBe("wh-789"); + }); + + test("should not set warehouseId when options.warehouseId is false", async () => { + const state = await ServiceContext.initialize({ warehouseId: false }); + + expect(state.warehouseId).toBeUndefined(); + }); + + test("should not set warehouseId when options are omitted", async () => { + const state = await ServiceContext.initialize(); + + expect(state.warehouseId).toBeUndefined(); + }); + + test("should throw when currentUser.me() returns no id", async () => { + mockMe.mockResolvedValue({}); + + await expect(ServiceContext.initialize()).rejects.toThrow( + 
ConfigurationError, + ); + }); + + test("should be idempotent - calling twice returns same instance", async () => { + const state1 = await ServiceContext.initialize(); + const state2 = await ServiceContext.initialize(); + + expect(state1).toBe(state2); + }); + + test("concurrent calls return the same promise", async () => { + const p1 = ServiceContext.initialize(); + const p2 = ServiceContext.initialize(); + + const [state1, state2] = await Promise.all([p1, p2]); + + expect(state1).toBe(state2); + // currentUser.me should only be called once regardless of concurrent calls + expect(mockMe).toHaveBeenCalledTimes(1); + }); + }); + + // ── get() ────────────────────────────────────────────────────── + + describe("get()", () => { + test("should throw InitializationError when not initialized", () => { + expect(() => ServiceContext.get()).toThrow(InitializationError); + expect(() => ServiceContext.get()).toThrow( + /ServiceContext not initialized/, + ); + }); + + test("should return state after initialization", async () => { + const state = await ServiceContext.initialize(); + const retrieved = ServiceContext.get(); + + expect(retrieved).toBe(state); + }); + }); + + // ── isInitialized() ──────────────────────────────────────────── + + describe("isInitialized()", () => { + test("should return false before initialization", () => { + expect(ServiceContext.isInitialized()).toBe(false); + }); + + test("should return true after initialization", async () => { + await ServiceContext.initialize(); + + expect(ServiceContext.isInitialized()).toBe(true); + }); + + test("should return false after reset()", async () => { + await ServiceContext.initialize(); + ServiceContext.reset(); + + expect(ServiceContext.isInitialized()).toBe(false); + }); + }); + + // ── createUserContext() ──────────────────────────────────────── + + describe("createUserContext()", () => { + beforeEach(async () => { + await ServiceContext.initialize({ warehouseId: true }); + }); + + test("should create a user 
context with correct properties", () => { + const userCtx = ServiceContext.createUserContext( + "user-token-abc", + "user-42", + "Alice", + ); + + expect(userCtx.userId).toBe("user-42"); + expect(userCtx.userName).toBe("Alice"); + expect(userCtx.isUserContext).toBe(true); + expect(userCtx.client).toBeDefined(); + }); + + test("should share warehouseId and workspaceId from service context", async () => { + process.env.DATABRICKS_WAREHOUSE_ID = "wh-shared"; + + // Re-initialize with the new env + ServiceContext.reset(); + mockApiRequest.mockResolvedValue({ "x-databricks-org-id": "ws-shared" }); + await ServiceContext.initialize({ warehouseId: true }); + + const userCtx = ServiceContext.createUserContext("user-token", "user-1"); + + const serviceCtx = ServiceContext.get(); + expect(userCtx.warehouseId).toBe(serviceCtx.warehouseId); + expect(userCtx.workspaceId).toBe(serviceCtx.workspaceId); + }); + + test("should create user client with PAT authType", () => { + ServiceContext.createUserContext("user-token", "user-1"); + + // The last call to MockWorkspaceClient should be for the user client + const lastCall = + MockWorkspaceClient.mock.calls[ + MockWorkspaceClient.mock.calls.length - 1 + ]; + expect(lastCall[0]).toMatchObject({ + token: "user-token", + host: process.env.DATABRICKS_HOST, + authType: "pat", + }); + }); + + test("should handle missing userName gracefully", () => { + const userCtx = ServiceContext.createUserContext("user-token", "user-1"); + + expect(userCtx.userName).toBeUndefined(); + }); + + test("should throw AuthenticationError on missing token", () => { + expect(() => ServiceContext.createUserContext("", "user-1")).toThrow( + AuthenticationError, + ); + }); + + test("should throw ConfigurationError when DATABRICKS_HOST is not set", () => { + delete process.env.DATABRICKS_HOST; + + expect(() => ServiceContext.createUserContext("token", "user-1")).toThrow( + ConfigurationError, + ); + }); + + test("should throw InitializationError when service context 
is not initialized", () => { + ServiceContext.reset(); + + expect(() => ServiceContext.createUserContext("token", "user-1")).toThrow( + InitializationError, + ); + }); + }); + + // ── reset() ──────────────────────────────────────────────────── + + describe("reset()", () => { + test("should clear the singleton state", async () => { + await ServiceContext.initialize(); + expect(ServiceContext.isInitialized()).toBe(true); + + ServiceContext.reset(); + + expect(ServiceContext.isInitialized()).toBe(false); + expect(() => ServiceContext.get()).toThrow(InitializationError); + }); + + test("should allow re-initialization after reset", async () => { + await ServiceContext.initialize(); + ServiceContext.reset(); + + mockMe.mockResolvedValue({ id: "new-service-user" }); + const state = await ServiceContext.initialize(); + + expect(state.serviceUserId).toBe("new-service-user"); + }); + }); + + // ── getWorkspaceId() (private, tested via initialize) ───────── + + describe("getWorkspaceId()", () => { + test("should use DATABRICKS_WORKSPACE_ID env var when set", async () => { + process.env.DATABRICKS_WORKSPACE_ID = "env-ws-123"; + + const state = await ServiceContext.initialize(); + + expect(await state.workspaceId).toBe("env-ws-123"); + // Should not call the SCIM API when env var is set + expect(mockApiRequest).not.toHaveBeenCalledWith( + expect.objectContaining({ path: "/api/2.0/preview/scim/v2/Me" }), + ); + }); + + test("should call SCIM API when env var is not set", async () => { + delete process.env.DATABRICKS_WORKSPACE_ID; + mockApiRequest.mockResolvedValue({ + "x-databricks-org-id": "scim-ws-789", + }); + + const state = await ServiceContext.initialize(); + + expect(await state.workspaceId).toBe("scim-ws-789"); + expect(mockApiRequest).toHaveBeenCalledWith( + expect.objectContaining({ + path: "/api/2.0/preview/scim/v2/Me", + method: "GET", + responseHeaders: ["x-databricks-org-id"], + }), + ); + }); + + test("should throw when SCIM API returns no workspace ID", async () 
=> { + delete process.env.DATABRICKS_WORKSPACE_ID; + mockApiRequest.mockResolvedValue({}); + + const state = await ServiceContext.initialize(); + + await expect(state.workspaceId).rejects.toThrow(ConfigurationError); + }); + }); + + // ── getWarehouseId() (private, tested via initialize) ───────── + + describe("getWarehouseId()", () => { + test("should use DATABRICKS_WAREHOUSE_ID env var when set", async () => { + process.env.DATABRICKS_WAREHOUSE_ID = "env-wh-abc"; + + const state = await ServiceContext.initialize({ warehouseId: true }); + + expect(await state.warehouseId).toBe("env-wh-abc"); + }); + + test("should auto-discover warehouse in development mode", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ + warehouses: [ + { id: "wh-stopped", state: "STOPPED" }, + { id: "wh-running", state: "RUNNING" }, + { id: "wh-starting", state: "STARTING" }, + ], + }); + } + // SCIM response for workspaceId + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + // Should pick RUNNING warehouse (highest priority) + expect(await state.warehouseId).toBe("wh-running"); + }); + + test("should sort warehouses by state priority in dev mode", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ + warehouses: [ + { id: "wh-stopping", state: "STOPPING" }, + { id: "wh-starting", state: "STARTING" }, + { id: "wh-stopped", state: "STOPPED" }, + ], + }); + } + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + // STOPPED 
(priority 1) < STARTING (priority 2) < STOPPING (priority 3) + expect(await state.warehouseId).toBe("wh-stopped"); + }); + + test("should throw in dev mode when no warehouses are available", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ warehouses: [] }); + } + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + await expect(state.warehouseId).rejects.toThrow(ConfigurationError); + }); + + test("should throw in dev mode when all warehouses are deleted", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ + warehouses: [ + { id: "wh-deleted", state: "DELETED" }, + { id: "wh-deleting", state: "DELETING" }, + ], + }); + } + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + await expect(state.warehouseId).rejects.toThrow(ConfigurationError); + }); + + test("should throw in dev mode when best warehouse has no id", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ + warehouses: [{ state: "RUNNING" }], + }); + } + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + await expect(state.warehouseId).rejects.toThrow(ConfigurationError); + }); + + test("should throw in production when DATABRICKS_WAREHOUSE_ID is not set", async () 
=> { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "production"; + + const state = await ServiceContext.initialize({ warehouseId: true }); + + await expect(state.warehouseId).rejects.toThrow(ConfigurationError); + await expect(state.warehouseId).rejects.toThrow( + /DATABRICKS_WAREHOUSE_ID/, + ); + }); + }); + + // ── getClientOptions() ───────────────────────────────────────── + + describe("getClientOptions()", () => { + test("should return product name and version", () => { + const options = ServiceContext.getClientOptions(); + + expect(options.product).toBe("@databricks/appkit"); + expect(options.productVersion).toBeDefined(); + }); + + test("should include dev mode user agent extra in development", () => { + process.env.NODE_ENV = "development"; + + const options = ServiceContext.getClientOptions(); + + expect(options.userAgentExtra).toEqual({ mode: "dev" }); + }); + + test("should not include dev mode user agent extra in production", () => { + process.env.NODE_ENV = "production"; + + const options = ServiceContext.getClientOptions(); + + expect(options.userAgentExtra).toBeUndefined(); + }); + }); +}); diff --git a/packages/appkit/src/plugins/files/tests/upload-and-write.test.ts b/packages/appkit/src/plugins/files/tests/upload-and-write.test.ts new file mode 100644 index 00000000..8da3f021 --- /dev/null +++ b/packages/appkit/src/plugins/files/tests/upload-and-write.test.ts @@ -0,0 +1,1245 @@ +import { Readable } from "node:stream"; +import { mockServiceContext, setupDatabricksEnv } from "@tools/test-helpers"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ServiceContext } from "../../../context/service-context"; +import { AuthenticationError } from "../../../errors"; +import { FilesPlugin } from "../plugin"; + +const { mockClient, MockApiError, mockCacheInstance } = vi.hoisted(() => { + const mockFilesApi = { + listDirectoryContents: vi.fn(), + download: vi.fn(), + getMetadata: vi.fn(), + upload: 
vi.fn(), + createDirectory: vi.fn(), + delete: vi.fn(), + }; + + const mockClient = { + files: mockFilesApi, + config: { + host: "https://test.databricks.com", + authenticate: vi.fn(), + }, + }; + + class MockApiError extends Error { + statusCode: number; + constructor(message: string, statusCode: number) { + super(message); + this.name = "ApiError"; + this.statusCode = statusCode; + } + } + + const mockCacheInstance = { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_key: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn((...args: unknown[]) => JSON.stringify(args)), + }; + + return { mockFilesApi, mockClient, MockApiError, mockCacheInstance }; +}); + +vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn(() => mockClient), + ApiError: MockApiError, +})); + +vi.mock("../../../context", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + getWorkspaceClient: vi.fn(() => mockClient), + isInUserContext: vi.fn(() => true), + }; +}); + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => mockCacheInstance), + }, +})); + +const VOLUMES_CONFIG = { + volumes: { + uploads: { maxUploadSize: 100_000_000 }, + exports: {}, + }, +}; + +/** + * Helper to get a route handler from the plugin. Registers routes on a mock + * router and returns the handler matching the given method + path suffix. 
+ */ +function getRouteHandler( + plugin: FilesPlugin, + method: "get" | "post" | "delete", + pathSuffix: string, +) { + const mockRouter = { + use: vi.fn(), + get: vi.fn(), + post: vi.fn(), + put: vi.fn(), + delete: vi.fn(), + patch: vi.fn(), + } as any; + + plugin.injectRoutes(mockRouter); + + const call = mockRouter[method].mock.calls.find( + (c: unknown[]) => + typeof c[0] === "string" && (c[0] as string).endsWith(pathSuffix), + ); + if (!call) throw new Error(`No route found for ${method} ...${pathSuffix}`); + return call[call.length - 1] as (req: any, res: any) => Promise; +} + +/** + * Creates a mock Express response with all methods needed by the route handlers. + */ +function mockRes() { + const res: any = { + headersSent: false, + }; + res.status = vi.fn().mockReturnValue(res); + res.json = vi.fn().mockReturnValue(res); + res.type = vi.fn().mockReturnValue(res); + res.send = vi.fn().mockReturnValue(res); + res.setHeader = vi.fn().mockReturnValue(res); + res.write = vi.fn().mockReturnValue(true); + res.destroy = vi.fn(); + res.end = vi.fn(); + res.on = vi.fn().mockReturnValue(res); + res.once = vi.fn().mockReturnValue(res); + res.emit = vi.fn().mockReturnValue(true); + res.removeListener = vi.fn().mockReturnValue(res); + res.pipe = vi.fn().mockReturnValue(res); + return res; +} + +/** + * Creates a mock Express request with the auth headers needed by the plugin's + * `asUser()` proxy. + */ +function mockReq(volumeKey: string, overrides: Record = {}): any { + const headers: Record = { + "x-forwarded-access-token": "test-token", + "x-forwarded-user": "test-user", + ...(overrides.headers ?? {}), + }; + + const req: any = { + params: { volumeKey }, + query: {}, + ...overrides, + headers, + header: (name: string) => headers[name.toLowerCase()], + }; + + return req; +} + +/** + * Creates a mock Express request that behaves as a Node Readable stream, + * suitable for the upload handler which calls Readable.toWeb(req). 
+ */ +function mockUploadReq( + volumeKey: string, + bodyChunks: Buffer[], + overrides: Record = {}, +): any { + const headers: Record = { + "x-forwarded-access-token": "test-token", + "x-forwarded-user": "test-user", + ...(overrides.headers ?? {}), + }; + + // Create a real Node Readable so Readable.toWeb() works + let chunkIndex = 0; + const stream = new Readable({ + read() { + if (chunkIndex < bodyChunks.length) { + this.push(bodyChunks[chunkIndex++]); + } else { + this.push(null); + } + }, + }); + + // Patch stream with Express request properties + (stream as any).params = { volumeKey }; + (stream as any).query = overrides.query ?? {}; + (stream as any).headers = headers; + (stream as any).header = (name: string) => headers[name.toLowerCase()]; + (stream as any).body = overrides.body; + + return stream; +} + +describe("FilesPlugin - Upload, Write, and Error Handling", () => { + let serviceContextMock: Awaited>; + + beforeEach(async () => { + vi.clearAllMocks(); + setupDatabricksEnv(); + ServiceContext.reset(); + process.env.DATABRICKS_VOLUME_UPLOADS = "/Volumes/catalog/schema/uploads"; + process.env.DATABRICKS_VOLUME_EXPORTS = "/Volumes/catalog/schema/exports"; + serviceContextMock = await mockServiceContext(); + }); + + afterEach(() => { + serviceContextMock?.restore(); + delete process.env.DATABRICKS_VOLUME_UPLOADS; + delete process.env.DATABRICKS_VOLUME_EXPORTS; + }); + + // ────────────────────────────────────────────────────────────────────── + // 1. 
_handleApiError: AuthenticationError -> 401, ApiError variants, + // non-ApiError -> 500 + // ────────────────────────────────────────────────────────────────────── + describe("_handleApiError", () => { + test("AuthenticationError returns 401 with error message", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._handleApiError( + res, + new AuthenticationError("Missing token"), + "fallback msg", + ); + + expect(res.status).toHaveBeenCalledWith(401); + expect(res.json).toHaveBeenCalledWith({ + error: "Missing token", + plugin: "files", + }); + }); + + test("ApiError with 4xx status preserves status and message", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._handleApiError( + res, + new MockApiError("Forbidden", 403), + "fallback msg", + ); + + expect(res.status).toHaveBeenCalledWith(403); + expect(res.json).toHaveBeenCalledWith({ + error: "Forbidden", + statusCode: 403, + plugin: "files", + }); + }); + + test("ApiError with 404 preserves status", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._handleApiError( + res, + new MockApiError("Not found", 404), + "fallback msg", + ); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Not found", + statusCode: 404, + plugin: "files", + }); + }); + + test("ApiError with 409 Conflict preserves status", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._handleApiError( + res, + new MockApiError("Conflict", 409), + "fallback msg", + ); + + expect(res.status).toHaveBeenCalledWith(409); + expect(res.json).toHaveBeenCalledWith({ + error: "Conflict", + statusCode: 409, + plugin: "files", + }); + }); + + test("ApiError with 5xx returns 500 with fallback message", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as 
any)._handleApiError( + res, + new MockApiError("Bad Gateway", 502), + "Operation failed", + ); + + expect(res.status).toHaveBeenCalledWith(500); + expect(res.json).toHaveBeenCalledWith({ + error: "Operation failed", + plugin: "files", + }); + }); + + test("ApiError with statusCode 500 returns 500 with fallback", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._handleApiError( + res, + new MockApiError("Internal error", 500), + "Fallback", + ); + + expect(res.status).toHaveBeenCalledWith(500); + expect(res.json).toHaveBeenCalledWith({ + error: "Fallback", + plugin: "files", + }); + }); + + test("non-ApiError falls back to 500 with fallback message", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._handleApiError(res, new Error("unknown"), "Fallback"); + + expect(res.status).toHaveBeenCalledWith(500); + expect(res.json).toHaveBeenCalledWith({ + error: "Fallback", + plugin: "files", + }); + }); + + test("non-ApiError exception returns 500 with fallback message", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._handleApiError( + res, + new TypeError("Cannot read properties of undefined"), + "Internal Server Error", + ); + + expect(res.status).toHaveBeenCalledWith(500); + expect(res.json).toHaveBeenCalledWith({ + error: "Internal Server Error", + plugin: "files", + }); + }); + + test("AuthenticationError via route (missing token in production)", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/list"); + const res = mockRes(); + + const originalEnv = process.env.NODE_ENV; + process.env.NODE_ENV = "production"; + + try { + await handler( + { + params: { volumeKey: "uploads" }, + query: {}, + headers: {}, + header: () => undefined, + }, + res, + ); + + expect(res.status).toHaveBeenCalledWith(401); + expect(res.json).toHaveBeenCalledWith( + 
expect.objectContaining({ + error: expect.stringContaining("token"), + plugin: "files", + }), + ); + } finally { + process.env.NODE_ENV = originalEnv; + } + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 2. Upload path: TransformStream size enforcement during streaming + // ────────────────────────────────────────────────────────────────────── + describe("Upload stream mid-transfer size enforcement", () => { + test("upload exceeding size mid-stream is caught by execute and returns error", async () => { + const plugin = new FilesPlugin({ + volumes: { + uploads: { maxUploadSize: 50 }, + }, + }); + const handler = getRouteHandler(plugin, "post", "/upload"); + const res = mockRes(); + + // Two chunks: 30 + 30 = 60 > maxSize of 50 + const req = mockUploadReq( + "uploads", + [Buffer.alloc(30), Buffer.alloc(30)], + { + query: { path: "/Volumes/catalog/schema/uploads/file.bin" }, + // No content-length header so the pre-check does not catch it + }, + ); + + // Spy on the connector's upload to consume the stream (the + // TransformStream size limiter fires when chunks are read). + const connector = (plugin as any).volumeConnectors.uploads; + vi.spyOn(connector, "upload").mockImplementation( + async (_client: any, _path: string, contents: any) => { + const reader = (contents as ReadableStream).getReader(); + while (true) { + const { done } = await reader.read(); + if (done) break; + } + }, + ); + + await handler(req, res); + + // The stream size error is caught by execute() and returned as + // {ok: false, status: 500}. The Content-Length pre-check (tested + // separately) catches oversized uploads before streaming starts. + const statusCalls = res.status.mock.calls.flat(); + expect(statusCalls).toContain(500); + }); + + test("outer catch returns 413 for stream size error escaping execute", async () => { + // The outer catch in _handleUpload has a specific check for the + // "exceeds maximum allowed size" message. 
This tests that path by + // making execute() re-throw instead of catching. + const plugin = new FilesPlugin({ + volumes: { + uploads: { maxUploadSize: 50 }, + }, + }); + const handler = getRouteHandler(plugin, "post", "/upload"); + const res = mockRes(); + + const req = mockUploadReq("uploads", [Buffer.from("data")], { + query: { path: "/Volumes/catalog/schema/uploads/file.bin" }, + }); + + // Override trackWrite to throw the size error directly + vi.spyOn(plugin as any, "trackWrite").mockRejectedValue( + new Error("Upload stream exceeds maximum allowed size (50 bytes)"), + ); + + await handler(req, res); + + expect(res.status).toHaveBeenCalledWith(413); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: expect.stringContaining("exceeds maximum allowed size"), + plugin: "files", + }), + ); + }); + + test("upload within size limit succeeds", async () => { + const plugin = new FilesPlugin({ + volumes: { + uploads: { maxUploadSize: 100 }, + }, + }); + const handler = getRouteHandler(plugin, "post", "/upload"); + const res = mockRes(); + + const req = mockUploadReq( + "uploads", + [Buffer.from("small file content")], + { + query: { path: "/Volumes/catalog/schema/uploads/small.txt" }, + }, + ); + + const connector = (plugin as any).volumeConnectors.uploads; + vi.spyOn(connector, "upload").mockImplementation( + async (_client: any, _path: string, contents: any) => { + const reader = (contents as ReadableStream).getReader(); + while (true) { + const { done } = await reader.read(); + if (done) break; + } + }, + ); + + await handler(req, res); + + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ success: true }), + ); + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 3. 
Upload: cache invalidation after successful upload + // ────────────────────────────────────────────────────────────────────── + describe("Upload cache invalidation", () => { + test("successful upload calls cache.delete for parent directory", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "post", "/upload"); + const res = mockRes(); + + const req = mockUploadReq("uploads", [Buffer.from("file content")], { + query: { path: "/Volumes/catalog/schema/uploads/dir/file.txt" }, + }); + + const connector = (plugin as any).volumeConnectors.uploads; + vi.spyOn(connector, "upload").mockImplementation( + async (_client: any, _path: string, contents: any) => { + const reader = (contents as ReadableStream).getReader(); + while (true) { + const { done } = await reader.read(); + if (done) break; + } + }, + ); + + await handler(req, res); + + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ success: true }), + ); + // _invalidateListCache should call generateKey and then delete + expect(mockCacheInstance.generateKey).toHaveBeenCalled(); + expect(mockCacheInstance.delete).toHaveBeenCalled(); + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 4. 
Raw endpoint: CSP sandbox header and safe vs unsafe content type + // ────────────────────────────────────────────────────────────────────── + describe("Raw endpoint security headers", () => { + function makeStreamResponse(content: string) { + const stream = new ReadableStream({ + start(controller) { + controller.enqueue(new TextEncoder().encode(content)); + controller.close(); + }, + }); + return { contents: stream }; + } + + test("raw endpoint sets CSP sandbox header", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/raw"); + const res = mockRes(); + + mockClient.files.download.mockResolvedValue(makeStreamResponse("data")); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/data.json" }, + }), + res, + ); + + expect(res.setHeader).toHaveBeenCalledWith( + "Content-Security-Policy", + "sandbox", + ); + }); + + test("raw endpoint with safe content type (image/png) does not set Content-Disposition", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/raw"); + const res = mockRes(); + + mockClient.files.download.mockResolvedValue( + makeStreamResponse("PNG data"), + ); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/image.png" }, + }), + res, + ); + + expect(res.setHeader).toHaveBeenCalledWith("Content-Type", "image/png"); + expect(res.setHeader).toHaveBeenCalledWith( + "Content-Security-Policy", + "sandbox", + ); + + // Content-Disposition should NOT be set for safe inline types + const dispositionCalls = res.setHeader.mock.calls.filter( + (c: string[]) => c[0] === "Content-Disposition", + ); + expect(dispositionCalls).toHaveLength(0); + }); + + test("raw endpoint with unsafe content type (text/html) forces download", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/raw"); + const res = 
mockRes(); + + mockClient.files.download.mockResolvedValue( + makeStreamResponse(""), + ); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/page.html" }, + }), + res, + ); + + expect(res.setHeader).toHaveBeenCalledWith("Content-Type", "text/html"); + expect(res.setHeader).toHaveBeenCalledWith( + "Content-Security-Policy", + "sandbox", + ); + expect(res.setHeader).toHaveBeenCalledWith( + "Content-Disposition", + 'attachment; filename="page.html"', + ); + }); + + test("raw endpoint with SVG forces download", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/raw"); + const res = mockRes(); + + mockClient.files.download.mockResolvedValue( + makeStreamResponse(""), + ); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/icon.svg" }, + }), + res, + ); + + expect(res.setHeader).toHaveBeenCalledWith( + "Content-Disposition", + 'attachment; filename="icon.svg"', + ); + }); + + test("raw endpoint sets X-Content-Type-Options: nosniff", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/raw"); + const res = mockRes(); + + mockClient.files.download.mockResolvedValue( + makeStreamResponse("content"), + ); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/file.txt" }, + }), + res, + ); + + expect(res.setHeader).toHaveBeenCalledWith( + "X-Content-Type-Options", + "nosniff", + ); + }); + + test("raw endpoint with missing path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/raw"); + const res = mockRes(); + + await handler(mockReq("uploads", { query: {} }), res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: "path is required", + plugin: "files", + }), + ); + }); + }); + + 
// ────────────────────────────────────────────────────────────────────── + // 5. Download endpoint: Content-Disposition with sanitized filename + // ────────────────────────────────────────────────────────────────────── + describe("Download endpoint Content-Disposition", () => { + function makeStreamResponse(content: string) { + const stream = new ReadableStream({ + start(controller) { + controller.enqueue(new TextEncoder().encode(content)); + controller.close(); + }, + }); + return { contents: stream }; + } + + test("download sets Content-Disposition: attachment with filename", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/download"); + const res = mockRes(); + + mockClient.files.download.mockResolvedValue( + makeStreamResponse("file data"), + ); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/report.pdf" }, + }), + res, + ); + + expect(res.setHeader).toHaveBeenCalledWith( + "Content-Disposition", + 'attachment; filename="report.pdf"', + ); + }); + + test("download sanitizes filename with special characters", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/download"); + const res = mockRes(); + + mockClient.files.download.mockResolvedValue(makeStreamResponse("data")); + + await handler( + mockReq("uploads", { + query: { path: '/Volumes/catalog/schema/uploads/my "file".txt' }, + }), + res, + ); + + // Quotes in filenames should be escaped + expect(res.setHeader).toHaveBeenCalledWith( + "Content-Disposition", + 'attachment; filename="my \\"file\\".txt"', + ); + }); + + test("download always sets Content-Disposition even for safe types", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/download"); + const res = mockRes(); + + mockClient.files.download.mockResolvedValue(makeStreamResponse("{}")); + + await handler( + 
mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/data.json" }, + }), + res, + ); + + // Download mode always forces attachment, even for safe types + expect(res.setHeader).toHaveBeenCalledWith( + "Content-Disposition", + 'attachment; filename="data.json"', + ); + }); + + test("download with missing path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/download"); + const res = mockRes(); + + await handler(mockReq("uploads", { query: {} }), res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: "path is required", + plugin: "files", + }), + ); + }); + + test("download with response having no contents calls res.end()", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/download"); + const res = mockRes(); + + // Response with no contents field (empty file) + mockClient.files.download.mockResolvedValue({}); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/empty.txt" }, + }), + res, + ); + + expect(res.end).toHaveBeenCalled(); + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 6. 
Delete endpoint: cache invalidation + // ────────────────────────────────────────────────────────────────────── + describe("Delete cache invalidation", () => { + test("successful delete invalidates list cache", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "delete", ""); + const res = mockRes(); + + mockClient.files.delete.mockResolvedValue(undefined); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/dir/file.txt" }, + }), + res, + ); + + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ success: true }), + ); + expect(mockCacheInstance.generateKey).toHaveBeenCalled(); + expect(mockCacheInstance.delete).toHaveBeenCalled(); + }); + + test("delete without path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "delete", ""); + const res = mockRes(); + + await handler(mockReq("uploads", { query: {} }), res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ error: "path is required" }), + ); + }); + + test("delete that throws ApiError returns proper status", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "delete", ""); + const res = mockRes(); + + mockClient.files.delete.mockRejectedValue( + new MockApiError("Not found", 404), + ); + + await handler( + mockReq("uploads", { + query: { path: "/Volumes/catalog/schema/uploads/missing.txt" }, + }), + res, + ); + + // SDK errors go through execute() which returns {ok: false, status: 404} + // then _sendStatusError is called with STATUS_CODES[404] = "Not Found" + expect(res.status).toHaveBeenCalledWith(404); + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 7. 
Mkdir endpoint: cache invalidation + // ────────────────────────────────────────────────────────────────────── + describe("Mkdir cache invalidation", () => { + test("successful mkdir invalidates list cache", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "post", "/mkdir"); + const res = mockRes(); + + mockClient.files.createDirectory.mockResolvedValue(undefined); + + await handler( + mockReq("uploads", { + body: { path: "/Volumes/catalog/schema/uploads/newdir" }, + }), + res, + ); + + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ success: true }), + ); + expect(mockCacheInstance.generateKey).toHaveBeenCalled(); + expect(mockCacheInstance.delete).toHaveBeenCalled(); + }); + + test("mkdir without path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "post", "/mkdir"); + const res = mockRes(); + + await handler(mockReq("uploads", { body: {} }), res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ error: "path is required" }), + ); + }); + + test("mkdir that throws ApiError 409 is handled via execute", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "post", "/mkdir"); + const res = mockRes(); + + mockClient.files.createDirectory.mockRejectedValue( + new MockApiError("Conflict", 409), + ); + + await handler( + mockReq("uploads", { + body: { path: "/Volumes/catalog/schema/uploads/existing" }, + }), + res, + ); + + // SDK errors go through execute() -> _sendStatusError with status 409 + expect(res.status).toHaveBeenCalledWith(409); + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 8. 
Shutdown: trackWrite waits for in-flight writes, deadline timeout + // ────────────────────────────────────────────────────────────────────── + describe("Shutdown and trackWrite", () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + test("shutdown waits for in-flight writes to complete", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + + // Simulate an in-flight write + (plugin as any).inflightWrites = 1; + + const shutdownPromise = plugin.shutdown(); + + // After 500ms the shutdown loop should still be waiting + await vi.advanceTimersByTimeAsync(500); + + // Simulate the write completing + (plugin as any).inflightWrites = 0; + + await vi.advanceTimersByTimeAsync(500); + await shutdownPromise; + + // Shutdown should have completed + expect((plugin as any).inflightWrites).toBe(0); + }); + + test("shutdown times out after 10 seconds with pending writes", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const abortAllSpy = vi.spyOn((plugin as any).streamManager, "abortAll"); + + // Simulate an in-flight write that never completes + (plugin as any).inflightWrites = 2; + + const shutdownPromise = plugin.shutdown(); + + // Advance past the 10-second deadline + await vi.advanceTimersByTimeAsync(11_000); + await shutdownPromise; + + // Should still call abortAll even after timeout + expect(abortAllSpy).toHaveBeenCalled(); + // inflightWrites remains > 0 since the writes never completed + expect((plugin as any).inflightWrites).toBe(2); + }); + + test("shutdown completes immediately when no in-flight writes", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const abortAllSpy = vi.spyOn((plugin as any).streamManager, "abortAll"); + + (plugin as any).inflightWrites = 0; + + const shutdownPromise = plugin.shutdown(); + await vi.advanceTimersByTimeAsync(0); + await shutdownPromise; + + expect(abortAllSpy).toHaveBeenCalled(); + }); + + test("trackWrite increments and 
decrements inflightWrites correctly", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + expect((plugin as any).inflightWrites).toBe(0); + + let resolveInner!: (value: string) => void; + const innerPromise = new Promise((r) => { + resolveInner = r; + }); + + const trackPromise = (plugin as any).trackWrite(() => innerPromise); + + // While the tracked fn is running, inflightWrites should be 1 + expect((plugin as any).inflightWrites).toBe(1); + + resolveInner("done"); + const result = await trackPromise; + + expect(result).toBe("done"); + expect((plugin as any).inflightWrites).toBe(0); + }); + + test("trackWrite decrements inflightWrites even on rejection", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + + const trackPromise = (plugin as any).trackWrite(() => + Promise.reject(new Error("write failed")), + ); + + await expect(trackPromise).rejects.toThrow("write failed"); + expect((plugin as any).inflightWrites).toBe(0); + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 9. 
Volume discovery: merging explicit config with env vars + // ────────────────────────────────────────────────────────────────────── + describe("Volume discovery merging", () => { + test("explicit config takes priority over env vars", () => { + const volumes = FilesPlugin.discoverVolumes({ + volumes: { + uploads: { maxUploadSize: 42 }, + custom: { maxUploadSize: 99 }, + }, + }); + + // uploads: explicit config wins (maxUploadSize: 42), not {} from env + expect(volumes.uploads).toEqual({ maxUploadSize: 42 }); + // exports: discovered from env with default empty config + expect(volumes.exports).toEqual({}); + // custom: explicit only, no env var + expect(volumes.custom).toEqual({ maxUploadSize: 99 }); + }); + + test("discovered volumes get empty config objects", () => { + process.env.DATABRICKS_VOLUME_DATA = "/Volumes/catalog/schema/data"; + + try { + const volumes = FilesPlugin.discoverVolumes({}); + expect(volumes.data).toEqual({}); + } finally { + delete process.env.DATABRICKS_VOLUME_DATA; + } + }); + + test("explicit volumes without env vars still appear", () => { + delete process.env.DATABRICKS_VOLUME_UPLOADS; + delete process.env.DATABRICKS_VOLUME_EXPORTS; + + const volumes = FilesPlugin.discoverVolumes({ + volumes: { + private: { maxUploadSize: 10 }, + }, + }); + + expect(Object.keys(volumes)).toEqual(["private"]); + expect(volumes.private).toEqual({ maxUploadSize: 10 }); + }); + + test("env var volume is not added when explicit config has the same key", () => { + process.env.DATABRICKS_VOLUME_SPECIAL = "/Volumes/catalog/schema/special"; + + try { + const volumes = FilesPlugin.discoverVolumes({ + volumes: { + special: { maxUploadSize: 500 }, + }, + }); + + // Explicit wins; should not be overwritten with {} + expect(volumes.special).toEqual({ maxUploadSize: 500 }); + } finally { + delete process.env.DATABRICKS_VOLUME_SPECIAL; + } + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 10. 
Path validation edge cases + // ────────────────────────────────────────────────────────────────────── + describe("Path validation", () => { + test("path with null bytes returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/read"); + const res = mockRes(); + + await handler( + mockReq("uploads", { query: { path: "/Volumes/test/\0evil" } }), + res, + ); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: "path must not contain null bytes", + }), + ); + }); + + test("path exceeding 4096 characters returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/read"); + const res = mockRes(); + + const longPath = "/Volumes/test/" + "a".repeat(4100); + + await handler(mockReq("uploads", { query: { path: longPath } }), res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: expect.stringContaining("exceeds maximum length"), + }), + ); + }); + + test("exists without path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/exists"); + const res = mockRes(); + + await handler(mockReq("uploads", { query: {} }), res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: "path is required", + plugin: "files", + }), + ); + }); + + test("metadata without path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/metadata"); + const res = mockRes(); + + await handler(mockReq("uploads", { query: {} }), res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: "path is required", + plugin: 
"files", + }), + ); + }); + + test("preview without path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "get", "/preview"); + const res = mockRes(); + + await handler(mockReq("uploads", { query: {} }), res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: "path is required", + plugin: "files", + }), + ); + }); + + test("upload without path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "post", "/upload"); + const res = mockRes(); + + const req = mockUploadReq("uploads", [Buffer.from("data")], { + query: {}, + }); + + await handler(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: "path is required", + plugin: "files", + }), + ); + }); + + test("delete with null bytes in path returns 400", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const handler = getRouteHandler(plugin, "delete", ""); + const res = mockRes(); + + await handler( + mockReq("uploads", { query: { path: "/Volumes/test/\0evil" } }), + res, + ); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: "path must not contain null bytes", + }), + ); + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 11. 
clientConfig returns volume keys + // ────────────────────────────────────────────────────────────────────── + describe("clientConfig", () => { + test("returns configured volume keys", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const config = plugin.clientConfig(); + + expect(config).toEqual({ volumes: ["uploads", "exports"] }); + }); + + test("returns empty volumes when none configured and no env vars", () => { + delete process.env.DATABRICKS_VOLUME_UPLOADS; + delete process.env.DATABRICKS_VOLUME_EXPORTS; + + const plugin = new FilesPlugin({ volumes: {} }); + const config = plugin.clientConfig(); + + expect(config).toEqual({ volumes: [] }); + }); + }); + + // ────────────────────────────────────────────────────────────────────── + // 12. _sendStatusError uses HTTP status code text + // ────────────────────────────────────────────────────────────────────── + describe("_sendStatusError", () => { + test("sends standard HTTP status text for known codes", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._sendStatusError(res, 404); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Not Found", + plugin: "files", + }); + }); + + test("sends 'Unknown Error' for non-standard status codes", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const res = mockRes(); + + (plugin as any)._sendStatusError(res, 999); + + expect(res.status).toHaveBeenCalledWith(999); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown Error", + plugin: "files", + }); + }); + }); +}); diff --git a/packages/appkit/src/stream/tests/stream-registry.test.ts b/packages/appkit/src/stream/tests/stream-registry.test.ts new file mode 100644 index 00000000..d3f70e95 --- /dev/null +++ b/packages/appkit/src/stream/tests/stream-registry.test.ts @@ -0,0 +1,582 @@ +import type { Context } from "@opentelemetry/api"; +import { beforeEach, describe, expect, test, vi } from 
"vitest"; +import { EventRingBuffer } from "../buffers"; +import { StreamRegistry } from "../stream-registry"; +import type { StreamEntry } from "../types"; +import { SSEErrorCode } from "../types"; + +/** Create a minimal mock StreamEntry for testing. */ +function createMockStreamEntry( + streamId: string, + overrides: Partial = {}, +): StreamEntry { + return { + streamId, + generator: (async function* () {})(), + eventBuffer: new EventRingBuffer(10), + clients: new Set(), + isCompleted: false, + lastAccess: Date.now(), + abortController: new AbortController(), + traceContext: {} as Context, + ...overrides, + }; +} + +/** Create a mock response object that mimics express.Response for SSE writes. */ +function createMockClient(writableEnded = false) { + return { + write: vi.fn().mockReturnValue(true), + writableEnded, + } as unknown as import("express").Response; +} + +describe("StreamRegistry", () => { + let registry: StreamRegistry; + + beforeEach(() => { + registry = new StreamRegistry(3); + }); + + describe("add and get", () => { + test("should add a stream and retrieve it by id", () => { + const entry = createMockStreamEntry("stream-1"); + registry.add(entry); + + const result = registry.get("stream-1"); + + expect(result).toBe(entry); + }); + + test("should return null for a non-existent stream", () => { + const result = registry.get("non-existent"); + + expect(result).toBeNull(); + }); + + test("should add multiple streams and retrieve each", () => { + const entry1 = createMockStreamEntry("stream-1"); + const entry2 = createMockStreamEntry("stream-2"); + const entry3 = createMockStreamEntry("stream-3"); + + registry.add(entry1); + registry.add(entry2); + registry.add(entry3); + + expect(registry.get("stream-1")).toBe(entry1); + expect(registry.get("stream-2")).toBe(entry2); + expect(registry.get("stream-3")).toBe(entry3); + }); + }); + + describe("has", () => { + test("should return true for an existing stream", () => { + const entry = 
createMockStreamEntry("stream-1"); + registry.add(entry); + + expect(registry.has("stream-1")).toBe(true); + }); + + test("should return false for a non-existent stream", () => { + expect(registry.has("non-existent")).toBe(false); + }); + + test("should return false after a stream is removed", () => { + const entry = createMockStreamEntry("stream-1"); + registry.add(entry); + registry.remove("stream-1"); + + expect(registry.has("stream-1")).toBe(false); + }); + }); + + describe("remove", () => { + test("should remove an existing stream", () => { + const entry = createMockStreamEntry("stream-1"); + registry.add(entry); + + registry.remove("stream-1"); + + expect(registry.get("stream-1")).toBeNull(); + expect(registry.size()).toBe(0); + }); + + test("should not throw when removing a non-existent stream", () => { + expect(() => registry.remove("non-existent")).not.toThrow(); + }); + + test("should only remove the specified stream", () => { + const entry1 = createMockStreamEntry("stream-1"); + const entry2 = createMockStreamEntry("stream-2"); + registry.add(entry1); + registry.add(entry2); + + registry.remove("stream-1"); + + expect(registry.get("stream-1")).toBeNull(); + expect(registry.get("stream-2")).toBe(entry2); + expect(registry.size()).toBe(1); + }); + }); + + describe("size", () => { + test("should return 0 for an empty registry", () => { + expect(registry.size()).toBe(0); + }); + + test("should track size as streams are added", () => { + registry.add(createMockStreamEntry("stream-1")); + expect(registry.size()).toBe(1); + + registry.add(createMockStreamEntry("stream-2")); + expect(registry.size()).toBe(2); + + registry.add(createMockStreamEntry("stream-3")); + expect(registry.size()).toBe(3); + }); + + test("should decrease when streams are removed", () => { + registry.add(createMockStreamEntry("stream-1")); + registry.add(createMockStreamEntry("stream-2")); + + registry.remove("stream-1"); + + expect(registry.size()).toBe(1); + }); + + test("should not 
exceed capacity after eviction", () => {
+      registry.add(createMockStreamEntry("stream-1", { lastAccess: 100 }));
+      registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 }));
+      registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 }));
+
+      // Adding a fourth stream to a capacity-3 registry triggers eviction
+      registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 }));
+
+      expect(registry.size()).toBe(3);
+    });
+  });
+
+  describe("capacity enforcement and eviction", () => {
+    test("should evict the oldest stream when at capacity", () => {
+      registry.add(createMockStreamEntry("stream-1", { lastAccess: 100 }));
+      registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 }));
+      registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 }));
+
+      // Adding a fourth should evict stream-1 (oldest lastAccess=100)
+      registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 }));
+
+      expect(registry.has("stream-1")).toBe(false);
+      expect(registry.has("stream-2")).toBe(true);
+      expect(registry.has("stream-3")).toBe(true);
+      expect(registry.has("stream-4")).toBe(true);
+    });
+
+    test("should evict the stream with the smallest lastAccess and abort it", () => {
+      // lastAccess order deliberately differs from insertion order here
+      // (100, 300, 200): eviction must target the stream with the smallest
+      // lastAccess, not the first one inserted. That stream is aborted.
+      const ac1 = new AbortController();
+      const ac2 = new AbortController();
+      const ac3 = new AbortController();
+
+      registry.add(
+        createMockStreamEntry("stream-1", {
+          lastAccess: 100,
+          abortController: ac1,
+        }),
+      );
+      registry.add(
+        createMockStreamEntry("stream-2", {
+          lastAccess: 300,
+          abortController: ac2,
+        }),
+      );
+      registry.add(
+        createMockStreamEntry("stream-3", {
+          lastAccess: 200,
+          abortController: ac3,
+        }),
+      );
+
+      // Adding stream-4 triggers eviction. stream-1 has the smallest
+      // lastAccess (100) so it should be targeted. 
+ registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + expect(ac1.signal.aborted).toBe(true); + expect(ac2.signal.aborted).toBe(false); + expect(ac3.signal.aborted).toBe(false); + expect(registry.has("stream-1")).toBe(false); + expect(registry.has("stream-4")).toBe(true); + }); + + test("should exclude the stream being added from eviction", () => { + // This tests the excludeStreamId parameter: if a stream with the same + // ID as the one being added already exists and is the oldest, it should + // still be excluded from eviction. In practice, the new stream won't be + // in the registry yet when eviction runs, so excludeStreamId prevents + // misidentification. + registry.add(createMockStreamEntry("stream-1", { lastAccess: 100 })); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Add stream with id "stream-1" again; eviction should skip "stream-1" + // even though stream-1 has the oldest lastAccess, because it's the + // excludeStreamId. stream-2 should be evicted instead. 
+ registry.add(createMockStreamEntry("stream-1", { lastAccess: 400 })); + + // stream-1 is updated (RingBuffer updates existing keys in place) + expect(registry.has("stream-1")).toBe(true); + // stream-2 should have been evicted as it was the oldest non-excluded + expect(registry.has("stream-2")).toBe(false); + expect(registry.has("stream-3")).toBe(true); + }); + + test("should abort the evicted stream's AbortController", () => { + const abortController1 = new AbortController(); + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + abortController: abortController1, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + expect(abortController1.signal.aborted).toBe(true); + }); + + test("should abort with 'Stream evicted' reason", () => { + const abortController1 = new AbortController(); + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + abortController: abortController1, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + expect(abortController1.signal.reason).toBe("Stream evicted"); + }); + }); + + describe("eviction SSE broadcast", () => { + test("should send STREAM_EVICTED error to all clients of evicted stream", () => { + const client1 = createMockClient(); + const client2 = createMockClient(); + + const clients = new Set([client1, client2]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Trigger eviction of stream-1 + registry.add(createMockStreamEntry("stream-4", { 
lastAccess: 400 })); + + // Each client should have received the SSE error event + for (const client of [client1, client2]) { + expect(client.write).toHaveBeenCalledWith("event: error\n"); + expect(client.write).toHaveBeenCalledWith( + `data: ${JSON.stringify({ error: "Stream evicted", code: SSEErrorCode.STREAM_EVICTED })}\n\n`, + ); + } + }); + + test("should skip clients with writableEnded=true during eviction broadcast", () => { + const activeClient = createMockClient(false); + const endedClient = createMockClient(true); + + const clients = new Set([activeClient, endedClient]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // Active client should receive the error + expect(activeClient.write).toHaveBeenCalledWith("event: error\n"); + + // Ended client should NOT receive any writes + expect(endedClient.write).not.toHaveBeenCalled(); + }); + + test("should handle client.write throwing an error gracefully", () => { + const throwingClient = createMockClient(false); + (throwingClient.write as ReturnType).mockImplementation( + () => { + throw new Error("Connection reset"); + }, + ); + + const normalClient = createMockClient(false); + + const clients = new Set([throwingClient, normalClient]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Should not throw despite the throwing client + expect(() => { + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + }).not.toThrow(); + + // The normal client should still receive the error despite the other + // client throwing. 
Note: both clients are in a Set, iteration order is + // insertion order. The throwing client's error is caught per-client. + // We verify the abort still happened (the overall eviction completed). + expect(registry.has("stream-1")).toBe(false); + expect(registry.has("stream-4")).toBe(true); + }); + + test("should send correct SSE error format with STREAM_EVICTED code", () => { + const client = createMockClient(); + const clients = new Set([client]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // Verify the exact data payload + const dataCall = ( + client.write as ReturnType + ).mock.calls.find((call: unknown[]) => + (call[0] as string).startsWith("data:"), + ); + expect(dataCall).toBeDefined(); + + const payload = JSON.parse( + (dataCall![0] as string).replace("data: ", "").trim(), + ); + expect(payload).toEqual({ + error: "Stream evicted", + code: "STREAM_EVICTED", + }); + }); + + test("should broadcast to multiple clients on the same evicted stream", () => { + const client1 = createMockClient(); + const client2 = createMockClient(); + const client3 = createMockClient(); + + const clients = new Set([client1, client2, client3]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // All three clients should have received exactly 2 write calls each + // (one for "event: error\n" and one for the data line) + for (const client of [client1, client2, client3]) { + expect(client.write).toHaveBeenCalledTimes(2); + } + }); + + test("should 
not broadcast if evicted stream has no clients", () => { + const abortController = new AbortController(); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients: new Set(), + abortController, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Should not throw even with no clients + expect(() => { + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + }).not.toThrow(); + + // Stream should still be evicted and aborted + expect(registry.has("stream-1")).toBe(false); + expect(abortController.signal.aborted).toBe(true); + }); + }); + + describe("clear", () => { + test("should abort all streams and clear the registry", () => { + const ac1 = new AbortController(); + const ac2 = new AbortController(); + const ac3 = new AbortController(); + + registry.add(createMockStreamEntry("stream-1", { abortController: ac1 })); + registry.add(createMockStreamEntry("stream-2", { abortController: ac2 })); + registry.add(createMockStreamEntry("stream-3", { abortController: ac3 })); + + registry.clear(); + + expect(registry.size()).toBe(0); + expect(ac1.signal.aborted).toBe(true); + expect(ac2.signal.aborted).toBe(true); + expect(ac3.signal.aborted).toBe(true); + }); + + test("should abort with 'Server shutdown' reason", () => { + const ac = new AbortController(); + registry.add(createMockStreamEntry("stream-1", { abortController: ac })); + + registry.clear(); + + expect(ac.signal.reason).toBe("Server shutdown"); + }); + + test("should handle clearing an empty registry", () => { + expect(() => registry.clear()).not.toThrow(); + expect(registry.size()).toBe(0); + }); + + test("should make all streams inaccessible after clear", () => { + registry.add(createMockStreamEntry("stream-1")); + registry.add(createMockStreamEntry("stream-2")); + + registry.clear(); + + expect(registry.get("stream-1")).toBeNull(); + 
expect(registry.get("stream-2")).toBeNull(); + expect(registry.has("stream-1")).toBe(false); + expect(registry.has("stream-2")).toBe(false); + }); + + test("should allow adding new streams after clear", () => { + registry.add(createMockStreamEntry("stream-1")); + registry.clear(); + + const newEntry = createMockStreamEntry("stream-new"); + registry.add(newEntry); + + expect(registry.get("stream-new")).toBe(newEntry); + expect(registry.size()).toBe(1); + }); + }); + + describe("edge cases", () => { + test("should work with capacity of 1", () => { + const smallRegistry = new StreamRegistry(1); + const ac1 = new AbortController(); + + smallRegistry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + abortController: ac1, + }), + ); + expect(smallRegistry.size()).toBe(1); + + smallRegistry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + + expect(smallRegistry.size()).toBe(1); + expect(smallRegistry.has("stream-1")).toBe(false); + expect(smallRegistry.has("stream-2")).toBe(true); + expect(ac1.signal.aborted).toBe(true); + }); + + test("should handle adding a stream with the same id (update)", () => { + const entry1 = createMockStreamEntry("stream-1", { + lastAccess: 100, + }); + const entry2 = createMockStreamEntry("stream-1", { + lastAccess: 200, + }); + + registry.add(entry1); + registry.add(entry2); + + // The RingBuffer updates in place for same key + expect(registry.size()).toBe(1); + const retrieved = registry.get("stream-1"); + expect(retrieved?.lastAccess).toBe(200); + }); + + test("should handle sequential evictions correctly", () => { + registry.add(createMockStreamEntry("stream-1", { lastAccess: 100 })); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // First eviction: stream-1 evicted + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + expect(registry.has("stream-1")).toBe(false); + + // Second eviction: 
stream-2 evicted + registry.add(createMockStreamEntry("stream-5", { lastAccess: 500 })); + expect(registry.has("stream-2")).toBe(false); + + // stream-3, stream-4, stream-5 remain + expect(registry.has("stream-3")).toBe(true); + expect(registry.has("stream-4")).toBe(true); + expect(registry.has("stream-5")).toBe(true); + expect(registry.size()).toBe(3); + }); + + test("should not evict when under capacity", () => { + const ac1 = new AbortController(); + registry.add(createMockStreamEntry("stream-1", { abortController: ac1 })); + registry.add(createMockStreamEntry("stream-2")); + + // Only 2 streams in a capacity-3 registry, no eviction + expect(registry.size()).toBe(2); + expect(ac1.signal.aborted).toBe(false); + }); + + test("should handle mixed writable states during eviction", () => { + const activeClient = createMockClient(false); + const endedClient1 = createMockClient(true); + const endedClient2 = createMockClient(true); + + const clients = new Set([endedClient1, activeClient, endedClient2]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // Only the active client should receive writes + expect(activeClient.write).toHaveBeenCalledTimes(2); + expect(endedClient1.write).not.toHaveBeenCalled(); + expect(endedClient2.write).not.toHaveBeenCalled(); + }); + }); +});