diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index a4666a49..6c6c6f5b 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -7,11 +7,20 @@ // Types from shared export type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, BasePluginConfig, CacheConfig, IAppRouter, + Message, PluginData, StreamExecutionSettings, + Thread, + ThreadStore, + ToolProvider, } from "shared"; export { isSQLTypeMarker, sql } from "shared"; export { CacheManager } from "./cache"; @@ -54,6 +63,21 @@ export { toPlugin, } from "./plugin"; export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export { + type FunctionTool, + type HostedTool, + isFunctionTool, + isHostedTool, + mcpServer, + type ToolConfig, + tool, +} from "./plugins/agents/tools"; +export { + type AgentTool, + isToolkitEntry, + type ToolkitEntry, + type ToolkitOptions, +} from "./plugins/agents/types"; // Files plugin types (for custom policy authoring) export type { FileAction, diff --git a/packages/appkit/src/plugins/agents/build-toolkit.ts b/packages/appkit/src/plugins/agents/build-toolkit.ts new file mode 100644 index 00000000..0140425d --- /dev/null +++ b/packages/appkit/src/plugins/agents/build-toolkit.ts @@ -0,0 +1,63 @@ +import type { AgentToolDefinition } from "shared"; +import type { ToolRegistry } from "./tools/define-tool"; +import { toToolJSONSchema } from "./tools/json-schema"; +import type { ToolkitEntry, ToolkitOptions } from "./types"; + +/** + * Converts a plugin's internal `ToolRegistry` into a keyed record of + * `ToolkitEntry` markers suitable for spreading into an `AgentDefinition.tools` + * record. + * + * The `opts` record controls shape and filtering: + * - `prefix` — overrides the default `${pluginName}.` prefix; `""` drops it. + * - `only` — allowlist of local tool names to include (post-prefix). + * - `except` — denylist of local names. 
+ * - `rename` — per-tool key remapping (applied after prefix/filter). + * + * Each entry carries `pluginName` + `localName` so the agents plugin can + * dispatch back through `PluginContext.executeTool` for OBO + telemetry. + */ +export function buildToolkitEntries( + pluginName: string, + registry: ToolRegistry, + opts: ToolkitOptions = {}, +): Record<string, ToolkitEntry> { + const prefix = opts.prefix ?? `${pluginName}.`; + const only = opts.only ? new Set(opts.only) : null; + const except = opts.except ? new Set(opts.except) : null; + const rename = opts.rename ?? {}; + + const out: Record<string, ToolkitEntry> = {}; + + for (const [localName, entry] of Object.entries(registry)) { + if (only && !only.has(localName)) continue; + if (except?.has(localName)) continue; + + const keyAfterPrefix = `${prefix}${localName}`; + const key = rename[localName] ?? keyAfterPrefix; + + const parameters = toToolJSONSchema( + entry.schema, + ) as unknown as AgentToolDefinition["parameters"]; + + const def: AgentToolDefinition = { + name: key, + description: entry.description, + parameters, + }; + if (entry.annotations) { + def.annotations = entry.annotations; + } + + out[key] = { + __toolkitRef: true, + pluginName, + localName, + def, + annotations: entry.annotations, + autoInheritable: entry.autoInheritable, + }; + } + + return out; +} diff --git a/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts b/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts new file mode 100644 index 00000000..08f71da9 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/build-toolkit.test.ts @@ -0,0 +1,101 @@ +import { describe, expect, test } from "vitest"; +import { z } from "zod"; +import { buildToolkitEntries } from "../build-toolkit"; +import { defineTool, type ToolRegistry } from "../tools/define-tool"; +import { isToolkitEntry } from "../types"; + +const registry: ToolRegistry = { + query: defineTool({ + description: "Run a query", + schema: z.object({ sql: z.string() }), + handler: () => "ok", + }), + 
history: defineTool({ + description: "Get query history", + schema: z.object({}), + handler: () => [], + }), +}; + +describe("buildToolkitEntries", () => { + test("produces ToolkitEntry per registry item with default dotted prefix", () => { + const entries = buildToolkitEntries("analytics", registry); + expect(Object.keys(entries).sort()).toEqual([ + "analytics.history", + "analytics.query", + ]); + for (const entry of Object.values(entries)) { + expect(isToolkitEntry(entry)).toBe(true); + expect(entry.pluginName).toBe("analytics"); + } + }); + + test("respects prefix option (empty drops the namespace)", () => { + const entries = buildToolkitEntries("analytics", registry, { prefix: "" }); + expect(Object.keys(entries).sort()).toEqual(["history", "query"]); + }); + + test("respects custom prefix", () => { + const entries = buildToolkitEntries("analytics", registry, { + prefix: "db.", + }); + expect(Object.keys(entries).sort()).toEqual(["db.history", "db.query"]); + }); + + test("only filter keeps the listed local names", () => { + const entries = buildToolkitEntries("analytics", registry, { + only: ["query"], + }); + expect(Object.keys(entries)).toEqual(["analytics.query"]); + }); + + test("except filter drops the listed local names", () => { + const entries = buildToolkitEntries("analytics", registry, { + except: ["history"], + }); + expect(Object.keys(entries)).toEqual(["analytics.query"]); + }); + + test("rename remaps specific local names (overrides the prefix key)", () => { + const entries = buildToolkitEntries("analytics", registry, { + rename: { query: "sql" }, + }); + expect(Object.keys(entries).sort()).toEqual(["analytics.history", "sql"]); + }); + + test("exposes the original plugin+local name so dispatch can route", () => { + const entries = buildToolkitEntries("analytics", registry, { + prefix: "db.", + }); + const qEntry = entries["db.query"]; + expect(qEntry.pluginName).toBe("analytics"); + expect(qEntry.localName).toBe("query"); + 
expect(qEntry.def.name).toBe("db.query"); + }); + + test("propagates autoInheritable from the source registry", () => { + const mixed: ToolRegistry = { + readIt: defineTool({ + description: "safe read", + schema: z.object({}), + autoInheritable: true, + handler: () => "ok", + }), + writeIt: defineTool({ + description: "unsafe write", + schema: z.object({}), + autoInheritable: false, + handler: () => "ok", + }), + unmarked: defineTool({ + description: "default: not auto-inheritable", + schema: z.object({}), + handler: () => "ok", + }), + }; + const entries = buildToolkitEntries("p", mixed); + expect(entries["p.readIt"].autoInheritable).toBe(true); + expect(entries["p.writeIt"].autoInheritable).toBe(false); + expect(entries["p.unmarked"].autoInheritable).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/define-tool.test.ts b/packages/appkit/src/plugins/agents/tests/define-tool.test.ts new file mode 100644 index 00000000..ef61e8c4 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/define-tool.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, test, vi } from "vitest"; +import { z } from "zod"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../tools/define-tool"; + +describe("defineTool()", () => { + test("returns an entry matching the input config", () => { + const entry = defineTool({ + description: "echo", + schema: z.object({ msg: z.string() }), + annotations: { readOnly: true }, + handler: ({ msg }) => msg, + }); + + expect(entry.description).toBe("echo"); + expect(entry.annotations).toEqual({ readOnly: true }); + expect(typeof entry.handler).toBe("function"); + }); +}); + +describe("executeFromRegistry", () => { + const registry: ToolRegistry = { + echo: defineTool({ + description: "echo", + schema: z.object({ msg: z.string() }), + handler: ({ msg }) => `got ${msg}`, + }), + }; + + test("validates args and calls handler on success", async () => { + const result = await 
executeFromRegistry(registry, "echo", { msg: "hi" }); + expect(result).toBe("got hi"); + }); + + test("returns formatted error string on validation failure", async () => { + const result = await executeFromRegistry(registry, "echo", {}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for echo"); + expect(result).toContain("msg"); + }); + + test("throws for unknown tool names", async () => { + await expect(executeFromRegistry(registry, "missing", {})).rejects.toThrow( + /Unknown tool: missing/, + ); + }); + + test("forwards AbortSignal to the handler", async () => { + const handler = vi.fn(async (_args: { x: string }, signal?: AbortSignal) => + signal?.aborted ? "aborted" : "ok", + ); + const reg: ToolRegistry = { + t: defineTool({ + description: "t", + schema: z.object({ x: z.string() }), + handler, + }), + }; + + const controller = new AbortController(); + controller.abort(); + await executeFromRegistry(reg, "t", { x: "hi" }, controller.signal); + + expect(handler).toHaveBeenCalledTimes(1); + expect(handler.mock.calls[0][1]).toBe(controller.signal); + }); +}); + +describe("toolsFromRegistry", () => { + test("produces AgentToolDefinition[] with JSON Schema parameters", () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "Execute a SQL query", + schema: z.object({ + query: z.string().describe("SQL query"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: () => "ok", + }), + }; + + const defs = toolsFromRegistry(registry); + expect(defs).toHaveLength(1); + expect(defs[0].name).toBe("query"); + expect(defs[0].description).toBe("Execute a SQL query"); + expect(defs[0].parameters).toMatchObject({ + type: "object", + properties: { + query: { type: "string", description: "SQL query" }, + }, + required: ["query"], + }); + expect(defs[0].annotations).toEqual({ + readOnly: true, + requiresUserContext: true, + }); + }); + + test("preserves dotted names like uploads.list from the 
registry keys", () => { + const registry: ToolRegistry = { + "uploads.list": defineTool({ + description: "list uploads", + schema: z.object({}), + handler: () => [], + }), + "documents.list": defineTool({ + description: "list documents", + schema: z.object({}), + handler: () => [], + }), + }; + + const names = toolsFromRegistry(registry).map((d) => d.name); + expect(names).toContain("uploads.list"); + expect(names).toContain("documents.list"); + }); + + test("omits annotations when none are provided", () => { + const registry: ToolRegistry = { + plain: defineTool({ + description: "plain", + schema: z.object({}), + handler: () => "ok", + }), + }; + const [def] = toolsFromRegistry(registry); + expect(def.annotations).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/function-tool.test.ts b/packages/appkit/src/plugins/agents/tests/function-tool.test.ts new file mode 100644 index 00000000..8e668d69 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/function-tool.test.ts @@ -0,0 +1,110 @@ +import { describe, expect, test } from "vitest"; +import { + functionToolToDefinition, + isFunctionTool, +} from "../tools/function-tool"; + +describe("isFunctionTool", () => { + test("returns true for valid FunctionTool", () => { + expect( + isFunctionTool({ + type: "function", + name: "greet", + execute: async () => "hello", + }), + ).toBe(true); + }); + + test("returns true for minimal FunctionTool", () => { + expect( + isFunctionTool({ + type: "function", + name: "x", + execute: () => "y", + }), + ).toBe(true); + }); + + test("returns false for null", () => { + expect(isFunctionTool(null)).toBe(false); + }); + + test("returns false for non-object", () => { + expect(isFunctionTool("function")).toBe(false); + }); + + test("returns false for wrong type", () => { + expect( + isFunctionTool({ + type: "genie-space", + name: "x", + execute: () => "y", + }), + ).toBe(false); + }); + + test("returns false when execute is missing", () => { + 
expect(isFunctionTool({ type: "function", name: "x" })).toBe(false); + }); + + test("returns false when name is missing", () => { + expect(isFunctionTool({ type: "function", execute: () => "y" })).toBe( + false, + ); + }); +}); + +describe("functionToolToDefinition", () => { + test("converts a FunctionTool with all fields", () => { + const def = functionToolToDefinition({ + type: "function", + name: "getWeather", + description: "Get current weather", + parameters: { + type: "object", + properties: { city: { type: "string" } }, + required: ["city"], + }, + execute: async () => "sunny", + }); + + expect(def.name).toBe("getWeather"); + expect(def.description).toBe("Get current weather"); + expect(def.parameters).toEqual({ + type: "object", + properties: { city: { type: "string" } }, + required: ["city"], + }); + }); + + test("uses name as fallback description", () => { + const def = functionToolToDefinition({ + type: "function", + name: "myTool", + execute: async () => "result", + }); + + expect(def.description).toBe("myTool"); + }); + + test("uses empty object schema when parameters are null", () => { + const def = functionToolToDefinition({ + type: "function", + name: "noParams", + parameters: null, + execute: async () => "ok", + }); + + expect(def.parameters).toEqual({ type: "object", properties: {} }); + }); + + test("uses empty object schema when parameters are omitted", () => { + const def = functionToolToDefinition({ + type: "function", + name: "noParams", + execute: async () => "ok", + }); + + expect(def.parameters).toEqual({ type: "object", properties: {} }); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts b/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts new file mode 100644 index 00000000..d62b266b --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/hosted-tools.test.ts @@ -0,0 +1,131 @@ +import { describe, expect, test } from "vitest"; +import { isHostedTool, resolveHostedTools } from 
"../tools/hosted-tools"; + +describe("isHostedTool", () => { + test("returns true for genie-space", () => { + expect( + isHostedTool({ type: "genie-space", genie_space: { id: "abc" } }), + ).toBe(true); + }); + + test("returns true for vector_search_index", () => { + expect( + isHostedTool({ + type: "vector_search_index", + vector_search_index: { name: "cat.schema.idx" }, + }), + ).toBe(true); + }); + + test("returns true for custom_mcp_server", () => { + expect( + isHostedTool({ + type: "custom_mcp_server", + custom_mcp_server: { app_name: "my-app", app_url: "my-app-url" }, + }), + ).toBe(true); + }); + + test("returns true for external_mcp_server", () => { + expect( + isHostedTool({ + type: "external_mcp_server", + external_mcp_server: { connection_name: "conn1" }, + }), + ).toBe(true); + }); + + test("returns false for FunctionTool", () => { + expect( + isHostedTool({ type: "function", name: "x", execute: () => "y" }), + ).toBe(false); + }); + + test("returns false for null", () => { + expect(isHostedTool(null)).toBe(false); + }); + + test("returns false for unknown type", () => { + expect(isHostedTool({ type: "unknown" })).toBe(false); + }); + + test("returns false for non-object", () => { + expect(isHostedTool(42)).toBe(false); + }); +}); + +describe("resolveHostedTools", () => { + test("resolves genie-space to correct MCP endpoint", () => { + const configs = resolveHostedTools([ + { type: "genie-space", genie_space: { id: "space123" } }, + ]); + + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("genie-space123"); + expect(configs[0].url).toBe("/api/2.0/mcp/genie/space123"); + }); + + test("resolves vector_search_index with 3-part name", () => { + const configs = resolveHostedTools([ + { + type: "vector_search_index", + vector_search_index: { name: "catalog.schema.my_index" }, + }, + ]); + + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("vs-catalog-schema-my_index"); + expect(configs[0].url).toBe( + 
"/api/2.0/mcp/vector-search/catalog/schema/my_index", + ); + }); + + test("throws for invalid vector_search_index name", () => { + expect(() => + resolveHostedTools([ + { + type: "vector_search_index", + vector_search_index: { name: "bad.name" }, + }, + ]), + ).toThrow("3-part dotted"); + }); + + test("resolves custom_mcp_server", () => { + const configs = resolveHostedTools([ + { + type: "custom_mcp_server", + custom_mcp_server: { app_name: "my-app", app_url: "my-app-endpoint" }, + }, + ]); + + expect(configs[0].name).toBe("my-app"); + expect(configs[0].url).toBe("my-app-endpoint"); + }); + + test("resolves external_mcp_server", () => { + const configs = resolveHostedTools([ + { + type: "external_mcp_server", + external_mcp_server: { connection_name: "conn1" }, + }, + ]); + + expect(configs[0].name).toBe("conn1"); + expect(configs[0].url).toBe("/api/2.0/mcp/external/conn1"); + }); + + test("resolves multiple tools preserving order", () => { + const configs = resolveHostedTools([ + { type: "genie-space", genie_space: { id: "g1" } }, + { + type: "external_mcp_server", + external_mcp_server: { connection_name: "e1" }, + }, + ]); + + expect(configs).toHaveLength(2); + expect(configs[0].name).toBe("genie-g1"); + expect(configs[1].name).toBe("e1"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/mcp-client.test.ts b/packages/appkit/src/plugins/agents/tests/mcp-client.test.ts new file mode 100644 index 00000000..483fb5f4 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/mcp-client.test.ts @@ -0,0 +1,402 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { AppKitMcpClient } from "../tools/mcp-client"; +import type { DnsLookup, McpHostPolicy } from "../tools/mcp-host-policy"; + +const WORKSPACE = "https://test-workspace.cloud.databricks.com"; + +const workspacePolicy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(), + allowLocalhost: false, +}; + +const 
trustedExternalPolicy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(["mcp.example.com"]), + allowLocalhost: false, +}; + +const publicDnsLookup: DnsLookup = async () => [ + { address: "203.0.113.42", family: 4 }, +]; + +const workspaceAuth = async (): Promise<Record<string, string>> => ({ + Authorization: "Bearer SP-TOKEN", +}); + +type FetchCall = { + url: string; + init: RequestInit; +}; + +function recordingFetch( + responders: Array<(call: FetchCall) => Response | Promise<Response>>, ) { + const calls: FetchCall[] = []; + let n = 0; + const fetchImpl: typeof fetch = async (input, init) => { + const url = typeof input === "string" ? input : (input as URL).toString(); + const call: FetchCall = { url, init: init ?? {} }; + calls.push(call); + const responder = responders[n++] ?? responders[responders.length - 1]; + return Promise.resolve(responder(call)); + }; + return { fetchImpl, calls }; +} + +function jsonResponse(body: unknown, headers: Record<string, string> = {}) { + return new Response(JSON.stringify(body), { + status: 200, + headers: { "content-type": "application/json", ...headers }, + }); +} + +describe("AppKitMcpClient — host allowlist", () => { + let authSpy: ReturnType<typeof vi.fn<typeof workspaceAuth>>; + + beforeEach(() => { + authSpy = vi.fn(workspaceAuth); + }); + + test("connect rejects a URL whose host is not allowlisted without making any fetch", async () => { + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + await expect( + client.connect({ name: "evil", url: "https://attacker.example.com/mcp" }), + ).rejects.toThrow(/attacker\.example\.com/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect rejects plaintext http:// for remote hosts", async () => { + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient( + WORKSPACE, 
+ authSpy, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + await expect( + client.connect({ name: "plain", url: "http://mcp.example.com/mcp" }), + ).rejects.toThrow(/plaintext http/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect rejects a URL whose DNS resolves to a blocked IP and never sends SP token", async () => { + const ssrfLookup: DnsLookup = async () => [ + { address: "169.254.169.254", family: 4 }, + ]; + const policy: McpHostPolicy = { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(["evil.example.com"]), + allowLocalhost: false, + }; + const { fetchImpl, calls } = recordingFetch([() => jsonResponse({})]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, policy, { + fetchImpl, + dnsLookup: ssrfLookup, + }); + await expect( + client.connect({ name: "evil", url: "https://evil.example.com/mcp" }), + ).rejects.toThrow(/169\.254\.169\.254/); + expect(calls).toHaveLength(0); + expect(authSpy).not.toHaveBeenCalled(); + }); + + test("connect to same-origin workspace forwards SP token on initialize + tools/list", async () => { + const { fetchImpl, calls } = recordingFetch([ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "echo", description: "Echo" }] }, + }), + ]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + // initialize + notifications/initialized + tools/list all carry SP token + expect(calls.map((c) => c.url)).toEqual([ + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + `${WORKSPACE}/api/2.0/mcp/genie/abc`, + ]); + for (const call of 
calls) { + const headers = call.init.headers as Record<string, string>; + expect(headers.Authorization).toBe("Bearer SP-TOKEN"); + } + expect(client.canForwardWorkspaceAuth("genie-1")).toBe(true); + }); + + test("connect to trusted external host does NOT forward SP token on any RPC", async () => { + const { fetchImpl, calls } = recordingFetch([ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "help" }] }, + }), + ]); + const client = new AppKitMcpClient( + WORKSPACE, + authSpy, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + + await client.connect({ name: "ext", url: "https://mcp.example.com/mcp" }); + + for (const call of calls) { + const headers = call.init.headers as Record<string, string>; + expect(headers.Authorization).toBeUndefined(); + } + expect(authSpy).not.toHaveBeenCalled(); + expect(client.canForwardWorkspaceAuth("ext")).toBe(false); + }); +}); + +describe("AppKitMcpClient — callTool auth scoping", () => { + test("drops caller-supplied OBO token when destination is not workspace-origin", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + }); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + trustedExternalPolicy, + { fetchImpl, dnsLookup: publicDnsLookup }, + ); + await client.connect({ name: "ext", url: "https://mcp.example.com/mcp" }); + + const output = await 
client.callTool( + "mcp.ext.do", + { x: 1 }, + { + Authorization: "Bearer OBO-USER-TOKEN", + }, + ); + expect(output).toBe("ok"); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record<string, string>; + expect(headers.Authorization).toBeUndefined(); + }); + + test("forwards caller-supplied OBO token when destination is workspace-origin", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + }); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + workspacePolicy, + { + fetchImpl, + dnsLookup: publicDnsLookup, + }, + ); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + await client.callTool( + "mcp.genie-1.do", + {}, + { + Authorization: "Bearer OBO-USER-TOKEN", + }, + ); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record<string, string>; + expect(headers.Authorization).toBe("Bearer OBO-USER-TOKEN"); + }); + + test("falls back to SP auth when no OBO override is provided and destination is workspace", async () => { + const authSpy = vi.fn(workspaceAuth); + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { + "mcp-session-id": "sess-1", + }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "do" }] }, + }), + ]; + const callResponder = () => + jsonResponse({ + jsonrpc: "2.0", + id: 4, + result: { content: [{ type: "text", text: "ok" }] }, + 
}); + const { fetchImpl, calls } = recordingFetch([ + ...connectResponders, + callResponder, + ]); + const client = new AppKitMcpClient(WORKSPACE, authSpy, workspacePolicy, { + fetchImpl, + dnsLookup: publicDnsLookup, + }); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + await client.callTool("mcp.genie-1.do", {}, undefined); + + const toolCall = calls[calls.length - 1]; + const headers = toolCall.init.headers as Record<string, string>; + expect(headers.Authorization).toBe("Bearer SP-TOKEN"); + }); +}); + +describe("AppKitMcpClient — caller abort signal composition", () => { + test("callTool's fetch aborts when the caller signal fires", async () => { + const connectResponders = [ + () => + jsonResponse( + { jsonrpc: "2.0", id: 1, result: {} }, + { "mcp-session-id": "sess-1" }, + ), + () => jsonResponse({ jsonrpc: "2.0", result: null }), + () => + jsonResponse({ + jsonrpc: "2.0", + id: 3, + result: { tools: [{ name: "slow" }] }, + }), + ]; + const callResponder = (call: FetchCall): Promise<Response> => { + const signal = call.init.signal as AbortSignal | undefined; + return new Promise((_, reject) => { + if (signal?.aborted) { + reject( + new DOMException( + signal.reason?.toString() ?? "aborted", + "AbortError", + ), + ); + return; + } + signal?.addEventListener( + "abort", + () => { + reject( + new DOMException( + signal.reason?.toString() ?? 
"aborted", + "AbortError", + ), + ); + }, + { once: true }, + ); + }); + }; + const { fetchImpl } = recordingFetch([...connectResponders, callResponder]); + const client = new AppKitMcpClient( + WORKSPACE, + workspaceAuth, + workspacePolicy, + { + fetchImpl, + dnsLookup: publicDnsLookup, + }, + ); + await client.connect({ + name: "genie-1", + url: `${WORKSPACE}/api/2.0/mcp/genie/abc`, + }); + + const controller = new AbortController(); + const pending = client + .callTool("mcp.genie-1.slow", {}, undefined, controller.signal) + .catch((e) => e); + // Let the fetch start + register its abort listener before we abort. + await new Promise((r) => setTimeout(r, 10)); + controller.abort(new Error("user cancelled")); + const error = (await pending) as Error; + expect(error).toBeInstanceOf(Error); + expect(error.name).toBe("AbortError"); + }); +}); diff --git a/packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts b/packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts new file mode 100644 index 00000000..06d98627 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tests/mcp-host-policy.test.ts @@ -0,0 +1,354 @@ +import { describe, expect, test, vi } from "vitest"; +import { + assertResolvedHostSafe, + buildMcpHostPolicy, + checkMcpUrl, + type DnsLookup, + isBlockedIp, + isLoopbackHost, + type McpHostPolicy, + type McpHostPolicyConfig, +} from "../tools/mcp-host-policy"; + +function stubLookup( + addresses: Array<{ address: string; family?: number }>, +): DnsLookup { + return vi + .fn() + .mockResolvedValue(addresses.map((a) => ({ family: 4, ...a }))); +} + +function failingLookup(message: string): DnsLookup { + return vi.fn().mockRejectedValue(new Error(message)); +} + +const WORKSPACE = "https://test-workspace.cloud.databricks.com"; + +function policy(overrides: Partial = {}): McpHostPolicy { + return { + workspaceHostname: "test-workspace.cloud.databricks.com", + trustedHosts: new Set(), + allowLocalhost: false, + ...overrides, + }; +} + 
+describe("buildMcpHostPolicy", () => {
+  test("extracts hostname from workspace URL", () => {
+    const p = buildMcpHostPolicy(undefined, WORKSPACE);
+    expect(p.workspaceHostname).toBe("test-workspace.cloud.databricks.com");
+  });
+
+  test("lowercases and trims trustedHosts", () => {
+    const p = buildMcpHostPolicy(
+      { trustedHosts: ["Example.COM", " corp.internal ", "mcp.example.com"] },
+      WORKSPACE,
+    );
+    expect(p.trustedHosts).toEqual(
+      new Set(["example.com", "corp.internal", "mcp.example.com"]),
+    );
+  });
+
+  test("allowLocalhost defaults to false in production", () => {
+    const prev = process.env.NODE_ENV;
+    process.env.NODE_ENV = "production";
+    try {
+      const p = buildMcpHostPolicy(undefined, WORKSPACE);
+      expect(p.allowLocalhost).toBe(false);
+    } finally {
+      process.env.NODE_ENV = prev;
+    }
+  });
+
+  test("allowLocalhost defaults to true outside production", () => {
+    const prev = process.env.NODE_ENV;
+    process.env.NODE_ENV = "development";
+    try {
+      const p = buildMcpHostPolicy(undefined, WORKSPACE);
+      expect(p.allowLocalhost).toBe(true);
+    } finally {
+      process.env.NODE_ENV = prev;
+    }
+  });
+
+  test("allowLocalhost respects explicit override", () => {
+    const prev = process.env.NODE_ENV;
+    process.env.NODE_ENV = "production";
+    try {
+      const cfg: McpHostPolicyConfig = { allowLocalhost: true };
+      const p = buildMcpHostPolicy(cfg, WORKSPACE);
+      expect(p.allowLocalhost).toBe(true);
+    } finally {
+      process.env.NODE_ENV = prev;
+    }
+  });
+
+  test("throws on invalid workspace host", () => {
+    expect(() => buildMcpHostPolicy(undefined, "not-a-url")).toThrow(
+      /Invalid workspace host/,
+    );
+  });
+});
+
+describe("checkMcpUrl", () => {
+  test("admits same-origin workspace https URL and forwards auth", () => {
+    const result = checkMcpUrl(`${WORKSPACE}/api/2.0/mcp/genie/abc`, policy());
+    expect(result.ok).toBe(true);
+    if (result.ok) expect(result.forwardWorkspaceAuth).toBe(true);
+  });
+
+  test("admits trusted host but does NOT forward workspace auth", () => {
+    const p = policy({ trustedHosts: new Set(["mcp.example.com"]) });
+    const result = checkMcpUrl("https://mcp.example.com/mcp", p);
+    expect(result.ok).toBe(true);
+    if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false);
+  });
+
+  test("rejects host that is neither workspace nor trusted", () => {
+    const result = checkMcpUrl("https://attacker.example.com/mcp", policy());
+    expect(result.ok).toBe(false);
+    if (!result.ok) {
+      expect(result.reason).toMatch(/attacker\.example\.com/);
+      expect(result.reason).toMatch(/trustedHosts/);
+    }
+  });
+
+  test("rejects plaintext http:// for remote hosts even when trusted", () => {
+    const p = policy({ trustedHosts: new Set(["mcp.example.com"]) });
+    const result = checkMcpUrl("http://mcp.example.com/mcp", p);
+    expect(result.ok).toBe(false);
+    if (!result.ok) expect(result.reason).toMatch(/plaintext http/);
+  });
+
+  test("rejects plaintext http://localhost when allowLocalhost is false", () => {
+    const result = checkMcpUrl("http://localhost:4000/mcp", policy());
+    expect(result.ok).toBe(false);
+  });
+
+  test("admits http://localhost when allowLocalhost is true, no workspace auth", () => {
+    const p = policy({ allowLocalhost: true });
+    const result = checkMcpUrl("http://localhost:4000/mcp", p);
+    expect(result.ok).toBe(true);
+    if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false);
+  });
+
+  test("admits http://127.0.0.1 when allowLocalhost is true", () => {
+    const p = policy({ allowLocalhost: true });
+    const result = checkMcpUrl("http://127.0.0.1:4000/mcp", p);
+    expect(result.ok).toBe(true);
+    if (result.ok) expect(result.forwardWorkspaceAuth).toBe(false);
+  });
+
+  test("rejects non-http(s) schemes", () => {
+    for (const url of [
+      "file:///etc/passwd",
+      "ftp://host/x",
+      "gopher://host/x",
+      "javascript:alert(1)",
+    ]) {
+      const result = checkMcpUrl(url, policy());
+      expect(result.ok).toBe(false);
+    }
+  });
+
+  test("rejects obviously invalid URLs", () => {
+    const result = checkMcpUrl("not-a-url", policy());
+    expect(result.ok).toBe(false);
+  });
+
+  test("hostname comparison is case-insensitive", () => {
+    const result = checkMcpUrl(
+      "https://TEST-Workspace.CLOUD.Databricks.com/mcp",
+      policy(),
+    );
+    expect(result.ok).toBe(true);
+    if (result.ok) expect(result.forwardWorkspaceAuth).toBe(true);
+  });
+
+  test("rejects same hostname on different scheme (http) even for workspace", () => {
+    const result = checkMcpUrl(
+      "http://test-workspace.cloud.databricks.com/mcp",
+      policy(),
+    );
+    expect(result.ok).toBe(false);
+  });
+});
+
+describe("isBlockedIp", () => {
+  test("blocks RFC1918 IPv4 ranges", () => {
+    for (const addr of [
+      "10.0.0.1",
+      "10.255.255.255",
+      "172.16.0.1",
+      "172.31.255.255",
+      "192.168.0.1",
+      "192.168.255.255",
+    ]) {
+      expect(isBlockedIp(addr, true)).toBe(true);
+    }
+  });
+
+  test("blocks link-local 169.254.0.0/16 (covers cloud metadata 169.254.169.254)", () => {
+    expect(isBlockedIp("169.254.169.254", true)).toBe(true);
+    expect(isBlockedIp("169.254.0.1", true)).toBe(true);
+  });
+
+  test("blocks CGNAT 100.64.0.0/10", () => {
+    expect(isBlockedIp("100.64.0.1", true)).toBe(true);
+    expect(isBlockedIp("100.127.255.255", true)).toBe(true);
+  });
+
+  test("blocks 0.0.0.0/8 and multicast/reserved (>= 224.0.0.0)", () => {
+    expect(isBlockedIp("0.0.0.0", true)).toBe(true);
+    expect(isBlockedIp("0.1.2.3", true)).toBe(true);
+    expect(isBlockedIp("224.0.0.1", true)).toBe(true);
+    expect(isBlockedIp("255.255.255.255", true)).toBe(true);
+  });
+
+  test("blocks loopback when allowLocalhost is false", () => {
+    expect(isBlockedIp("127.0.0.1", false)).toBe(true);
+    expect(isBlockedIp("127.1.2.3", false)).toBe(true);
+    expect(isBlockedIp("::1", false)).toBe(true);
+  });
+
+  test("permits loopback when allowLocalhost is true", () => {
+    expect(isBlockedIp("127.0.0.1", true)).toBe(false);
+    expect(isBlockedIp("::1", true)).toBe(false);
+  });
+
+  test("blocks ULA (fc00::/7) and link-local (fe80::/10) IPv6", () => {
+    expect(isBlockedIp("fc00::1", true)).toBe(true);
+    expect(isBlockedIp("fd00::1", true)).toBe(true);
+    expect(isBlockedIp("fe80::1", true)).toBe(true);
+  });
+
+  test("blocks the full link-local /10 range fe80::–febf:: (regression: fea0/feb0)", () => {
+    // fe80::/10 spans 1111 1110 10.. — first hex pair `fe` + second nibble 8..b.
+    for (const addr of [
+      "fe80::1",
+      "fe90::1",
+      "fea0::1", // regression: was passing the filter before
+      "feaf::1", // regression
+      "feb0::1", // regression
+      "febf::1", // regression
+    ]) {
+      expect(isBlockedIp(addr, true)).toBe(true);
+    }
+    // fec0::1 lies outside fe80::/10 (it is the deprecated site-local range,
+    // RFC 3879) and must not be blocked by this rule; nothing else in the
+    // module should match it either.
+    expect(isBlockedIp("fec0::1", true)).toBe(false);
+  });
+
+  test("blocks IPv4-mapped IPv6 addresses in blocked ranges (dotted form)", () => {
+    expect(isBlockedIp("::ffff:169.254.169.254", true)).toBe(true);
+    expect(isBlockedIp("::ffff:10.0.0.1", true)).toBe(true);
+  });
+
+  test("blocks IPv4-mapped IPv6 addresses in colon-hex form (regression)", () => {
+    // ::ffff:a9fe:a9fe is the same destination as ::ffff:169.254.169.254.
+    // Before the fix this form slipped past the IPv4-mapped branch because
+    // isIPv4("a9fe:a9fe") is false and no other v6 rule matched.
+    expect(isBlockedIp("::ffff:a9fe:a9fe", true)).toBe(true); // 169.254.169.254
+    expect(isBlockedIp("::ffff:0a00:0001", true)).toBe(true); // 10.0.0.1
+    expect(isBlockedIp("::ffff:c0a8:0001", true)).toBe(true); // 192.168.0.1
+    // A public IPv4 mapped to colon-hex must still pass through: 8.8.8.8 = 0808:0808
+    expect(isBlockedIp("::ffff:0808:0808", true)).toBe(false);
+  });
+
+  test("allows public IPv4 and IPv6 addresses", () => {
+    expect(isBlockedIp("8.8.8.8", false)).toBe(false);
+    expect(isBlockedIp("1.1.1.1", false)).toBe(false);
+    expect(isBlockedIp("2001:4860:4860::8888", false)).toBe(false);
+  });
+
+  test("treats malformed IP strings as blocked (fail-closed)", () => {
+    expect(isBlockedIp("10.0.0", true)).toBe(true);
+    expect(isBlockedIp("abc.def.ghi.jkl", true)).toBe(true);
+  });
+});
+
+describe("isLoopbackHost", () => {
+  test.each([
+    "localhost",
+    "LOCALHOST",
+    "127.0.0.1",
+    "::1",
+    "[::1]",
+    "0:0:0:0:0:0:0:1",
+  ])("recognises %s as loopback", (host) => {
+    expect(isLoopbackHost(host)).toBe(true);
+  });
+
+  test("does not match other hosts", () => {
+    expect(isLoopbackHost("example.com")).toBe(false);
+    expect(isLoopbackHost("10.0.0.1")).toBe(false);
+  });
+});
+
+describe("assertResolvedHostSafe", () => {
+  test("passes workspace hostname when resolved address is public", async () => {
+    const lookup = stubLookup([{ address: "203.0.113.42" }]);
+    await expect(
+      assertResolvedHostSafe(
+        "test-workspace.cloud.databricks.com",
+        policy(),
+        lookup,
+      ),
+    ).resolves.toBeUndefined();
+    expect(lookup).toHaveBeenCalledWith("test-workspace.cloud.databricks.com", {
+      all: true,
+    });
+  });
+
+  test("rejects hostname that resolves to link-local cloud metadata IP", async () => {
+    const lookup = stubLookup([{ address: "169.254.169.254" }]);
+    await expect(
+      assertResolvedHostSafe("evil.example.com", policy(), lookup),
+    ).rejects.toThrow(/169\.254\.169\.254/);
+  });
+
+  test("rejects hostname that resolves to RFC1918 IP", async () => {
+    const lookup = stubLookup([{ address: "10.0.0.1" }]);
+    await expect(
+      assertResolvedHostSafe("internal.example.com", policy(), lookup),
+    ).rejects.toThrow(/10\.0\.0\.1/);
+  });
+
+  test("rejects IP literal in blocked range without DNS lookup", async () => {
+    const lookup = stubLookup([{ address: "8.8.8.8" }]);
+    await expect(
+      assertResolvedHostSafe("169.254.169.254", policy(), lookup),
+    ).rejects.toThrow(/blocked IP range/);
+    expect(lookup).not.toHaveBeenCalled();
+  });
+
+  test("rejects plain 'localhost' when allowLocalhost is false", async () => {
+    await expect(assertResolvedHostSafe("localhost", policy())).rejects.toThrow(
+      /localhost is not allowed/,
+    );
+  });
+
+  test("surfaces DNS resolution failures", async () => {
+    const lookup = failingLookup("ENOTFOUND");
+    await expect(
+      assertResolvedHostSafe("nonexistent.example.com", policy(), lookup),
+    ).rejects.toThrow(/could not be resolved/);
+  });
+
+  test("rejects if any resolved address is blocked (defense against split DNS)", async () => {
+    const lookup = stubLookup([
+      { address: "8.8.8.8" },
+      { address: "169.254.169.254" },
+    ]);
+    await expect(
+      assertResolvedHostSafe("mixed.example.com", policy(), lookup),
+    ).rejects.toThrow(/169\.254\.169\.254/);
+  });
+
+  test("rejects hostname that resolves to empty DNS result", async () => {
+    const lookup = stubLookup([]);
+    await expect(
+      assertResolvedHostSafe("empty.example.com", policy(), lookup),
+    ).rejects.toThrow(/no DNS addresses/);
+  });
+});
diff --git a/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts b/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts
new file mode 100644
index 00000000..96ad8e38
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tests/mcp-server-helper.test.ts
+import { describe, expect, test } from "vitest";
+import {
+  isHostedTool,
+  mcpServer,
+  resolveHostedTools,
+} from "../tools/hosted-tools";
+
+describe("mcpServer()", () => {
+  test("returns a CustomMcpServerTool with correct shape", () => {
+    const result = mcpServer("my-app", "https://example.com/mcp");
+
+    expect(result).toEqual({
+      type: "custom_mcp_server",
+      custom_mcp_server: {
+        app_name: "my-app",
+        app_url: "https://example.com/mcp",
+      },
+    });
+  });
+
+  test("isHostedTool recognizes mcpServer() output", () => {
+    expect(isHostedTool(mcpServer("x", "y"))).toBe(true);
+  });
+
+  test("resolveHostedTools resolves mcpServer() output to an endpoint config", () => {
+    const configs = resolveHostedTools([
+      mcpServer("vector-search", "https://host/mcp/vs"),
+    ]);
+
+    expect(configs).toHaveLength(1);
+    expect(configs[0].name).toBe("vector-search");
+    expect(configs[0].url).toBe("https://host/mcp/vs");
+  });
+});
diff --git a/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts b/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts
new file mode 100644
index 00000000..fb81493e
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tests/sql-policy.test.ts
+import { describe, expect, test } from "vitest";
+import {
+  assertReadOnlySql,
+  classifyReadOnly,
+  ReadOnlySqlViolation,
+} from "../tools/sql-policy";
+
+function ok(sql: string) {
+  const result = classifyReadOnly(sql);
+  if (!result.readOnly) {
+    throw new Error(
+      `Expected readOnly=true for ${JSON.stringify(sql)}, got reason: ${result.reason}`,
+    );
+  }
+  return result;
+}
+
+function rejected(sql: string) {
+  const result = classifyReadOnly(sql);
+  if (result.readOnly) {
+    throw new Error(
+      `Expected readOnly=false for ${JSON.stringify(sql)}, got readOnly=true`,
+    );
+  }
+  return result;
+}
+
+describe("classifyReadOnly: plain reads are admitted", () => {
+  test.each([
+    "SELECT 1",
+    "select 1",
+    "SELECT * FROM users",
+    "SELECT * FROM main.sales.orders WHERE created_at > now() - interval '7 days'",
+    "SELECT COUNT(*) FROM main.sales.orders",
+    "WITH a AS (SELECT 1) SELECT * FROM a",
+    "WITH RECURSIVE t AS (SELECT 1) SELECT * FROM t",
+    "SHOW TABLES",
+    "SHOW TABLES IN main.sales",
+    "DESCRIBE EXTENDED main.sales.orders",
+    "DESC main.sales.orders",
+    "EXPLAIN SELECT 1",
+    "EXPLAIN ANALYZE SELECT 1",
+  ])("admits %s", (sql) => {
+    expect(ok(sql).statements).toBe(1);
+  });
+});
+
+describe("classifyReadOnly: writes are rejected", () => {
+  test.each([
+    ["DROP TABLE users", "DROP"],
+    ["UPDATE users SET email = 'x@y.com'", "UPDATE"],
+    ["DELETE FROM orders WHERE id = 1", "DELETE"],
+    ["INSERT INTO x VALUES (1)", "INSERT"],
+    ["CREATE TABLE x (id INT)", "CREATE"],
+    ["ALTER TABLE x ADD COLUMN y INT", "ALTER"],
+    ["TRUNCATE TABLE orders", "TRUNCATE"],
+    ["GRANT SELECT ON t TO u", "GRANT"],
+    ["REVOKE ALL ON t FROM u", "REVOKE"],
+    ["CALL sp_do_thing()", "CALL"],
+    ["COPY t FROM '/tmp/x'", "COPY"],
+    ["MERGE INTO t USING s", "MERGE"],
+    ["REFRESH TABLE t", "REFRESH"],
+    ["VACUUM t", "VACUUM"],
+  ])("rejects %s", (sql, keyword) => {
+    const result = rejected(sql);
+    expect(result.reason).toContain(keyword);
+  });
+});
+
+describe("classifyReadOnly: stacked statements", () => {
+  test("rejects SELECT followed by DROP", () => {
+    const result = rejected("SELECT 1; DROP TABLE x");
+    expect(result.reason).toMatch(/DROP/);
+  });
+
+  test("rejects DROP followed by SELECT (write comes first)", () => {
+    const result = rejected("DROP TABLE x; SELECT 1");
+    expect(result.reason).toMatch(/DROP/);
+  });
+
+  test("admits multiple SELECTs", () => {
+    expect(ok("SELECT 1; SELECT 2").statements).toBe(2);
+  });
+
+  test("admits trailing semicolon on single statement", () => {
+    expect(ok("SELECT 1;").statements).toBe(1);
+  });
+
+  test("admits SELECT, SHOW, DESCRIBE batch", () => {
+    const result = ok("SELECT 1; SHOW TABLES; DESCRIBE x;");
+    expect(result.statements).toBe(3);
+  });
+});
+
+describe("classifyReadOnly: comment handling", () => {
+  test("admits SELECT with line comment hiding a write keyword", () => {
+    ok("SELECT 1 -- DROP TABLE x\n");
+  });
+
+  test("admits SELECT preceded by line comment with write keyword", () => {
+    ok("-- DROP TABLE x\nSELECT 1");
+  });
+
+  test("admits SELECT with block comment containing stacked write", () => {
+    ok("SELECT 1 /* ; DROP TABLE x */");
+  });
+
+  test("handles nested block comments (PostgreSQL style)", () => {
+    ok("SELECT 1 /* outer /* inner */ still inside */");
+  });
+
+  test("rejects when write is outside the comment", () => {
+    const result = rejected("/* SELECT 1 */ DROP TABLE x");
+    expect(result.reason).toMatch(/DROP/);
+  });
+
+  test("empty after stripping comments is rejected", () => {
+    rejected("-- only a comment");
+    rejected("/* nothing */");
+  });
+});
+
+describe("classifyReadOnly: string literal handling", () => {
+  test("admits SELECT with write keyword inside single-quoted string", () => {
+    ok("SELECT 'DROP TABLE x' AS msg");
+  });
+
+  test("admits SELECT with semicolon inside single-quoted string", () => {
+    ok("SELECT 'value; DROP TABLE x' AS msg");
+  });
+
+  test("admits SELECT with doubled-quote escape", () => {
+    ok("SELECT 'it''s ok; DROP' AS msg");
+  });
+
+  test("admits SELECT with backslash escape inside string", () => {
+    ok("SELECT E'line\\'s end; DROP' AS msg");
+  });
+
+  test("admits SELECT with dollar-quoted string hiding a write", () => {
+    ok("SELECT $body$ arbitrary ; DROP TABLE x $body$ AS msg");
+  });
+
+  test("admits SELECT with untagged dollar quote", () => {
+    ok("SELECT $$hello; DROP$$ AS msg");
+  });
+
+  test("admits SELECT with ANSI double-quoted identifier named drop", () => {
+    ok('SELECT * FROM "drop"');
+  });
+
+  test("admits SELECT with doubled-quote inside ANSI identifier", () => {
+    ok('SELECT * FROM "weird""name"');
+  });
+
+  test("admits SELECT with backtick identifier (Databricks)", () => {
+    ok("SELECT * FROM `my table`");
+  });
+});
+
+describe("classifyReadOnly: degenerate input", () => {
+  test("rejects empty string", () => {
+    rejected("");
+  });
+
+  test("rejects whitespace-only", () => {
+    rejected(" \n\t ");
+  });
+
+  test("rejects semicolons only", () => {
+    rejected(";;;");
+  });
+
+  test("rejects non-SQL garbage", () => {
+    rejected("-- this is just a comment\n-- nothing else");
+    rejected("random garbage text");
+  });
+
+  test("admits a single empty statement between two selects", () => {
+    // "SELECT 1;; SELECT 2" — the middle empty statement is dropped by
+    // splitter; the surviving two statements are both SELECT, admitted.
+    ok("SELECT 1;; SELECT 2");
+  });
+});
+
+describe("classifyReadOnly: evasion-resistance", () => {
+  test("cannot hide DROP after a comment-ended newline", () => {
+    const result = rejected("-- intent\nDROP TABLE x");
+    expect(result.reason).toMatch(/DROP/);
+  });
+
+  test("cannot hide DROP via concatenated strings (strings end cleanly)", () => {
+    rejected("'SELECT 1'; DROP TABLE x");
+  });
+
+  test("bare DROP after unclosed string is still considered part of the string (defensive)", () => {
+    // An unclosed single quote eats the rest of the input — classifier
+    // sees the whole thing as one stripped, empty-ish statement and rejects.
+    rejected("SELECT 'unterminated ; DROP TABLE x");
+  });
+
+  test("dollar-quoted literal with malicious tag is handled", () => {
+    ok("SELECT $tag$ DROP $tag$ AS harmless");
+  });
+
+  test("mismatched dollar-quote tag is treated as unterminated", () => {
+    rejected("SELECT $a$ DROP TABLE x $b$");
+  });
+});
+
+describe("assertReadOnlySql", () => {
+  test("returns void on read-only SQL", () => {
+    expect(() => assertReadOnlySql("SELECT 1")).not.toThrow();
+  });
+
+  test("throws ReadOnlySqlViolation with descriptive message on writes", () => {
+    expect(() => assertReadOnlySql("DROP TABLE x")).toThrow(
+      ReadOnlySqlViolation,
+    );
+    try {
+      assertReadOnlySql("DROP TABLE x");
+    } catch (e) {
+      expect((e as Error).message).toMatch(/SQL read-only policy violation/);
+      expect((e as Error).message).toMatch(/DROP/);
+    }
+  });
+});
diff --git a/packages/appkit/src/plugins/agents/tests/tool.test.ts b/packages/appkit/src/plugins/agents/tests/tool.test.ts
new file mode 100644
index 00000000..3d47f3a9
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tests/tool.test.ts
+import { describe, expect, test } from "vitest";
+import { z } from "zod";
+import { formatZodError, tool } from "../tools/tool";
+
+describe("tool()", () => {
+  test("produces a FunctionTool with JSON Schema parameters from the Zod schema", () => {
+    const weather = tool({
+      name: "get_weather",
+      description: "Get the current weather for a city",
+      schema: z.object({
+        city: z.string().describe("City name"),
+      }),
+      execute: async ({ city }) => `Sunny in ${city}`,
+    });
+
+    expect(weather.type).toBe("function");
+    expect(weather.name).toBe("get_weather");
+    expect(weather.description).toBe("Get the current weather for a city");
+    expect(weather.parameters).toMatchObject({
+      type: "object",
+      properties: {
+        city: { type: "string", description: "City name" },
+      },
+      required: ["city"],
+    });
+  });
+
+  test("execute receives typed args on valid input", async () => {
+    const echo = tool({
+      name: "echo",
+      schema: z.object({ message: z.string() }),
+      execute: async ({ message }) => {
+        const _typed: string = message;
+        return `got ${_typed}`;
+      },
+    });
+
+    const result = await echo.execute({ message: "hi" });
+    expect(result).toBe("got hi");
+  });
+
+  test("returns formatted error string (does not throw) when args are invalid", async () => {
+    const weather = tool({
+      name: "get_weather",
+      schema: z.object({ city: z.string() }),
+      execute: async ({ city }) => `Sunny in ${city}`,
+    });
+
+    const result = await weather.execute({});
+    expect(typeof result).toBe("string");
+    expect(result).toContain("Invalid arguments for get_weather");
+    expect(result).toContain("city");
+  });
+
+  test("joins multiple validation errors with '; '", async () => {
+    const t = tool({
+      name: "multi",
+      schema: z.object({ a: z.string(), b: z.number() }),
+      execute: async () => "ok",
+    });
+
+    const result = await t.execute({});
+    expect(result).toContain("a:");
+    expect(result).toContain("b:");
+    expect(result).toContain(";");
+  });
+
+  test("optional fields validate when absent", async () => {
+    const t = tool({
+      name: "opt",
+      schema: z.object({ note: z.string().optional() }),
+      execute: async ({ note }) => note ?? "(no note)",
+    });
+
+    expect(await t.execute({})).toBe("(no note)");
+    expect(await t.execute({ note: "hello" })).toBe("hello");
+  });
+
+  test("description falls back to the tool name when omitted", () => {
+    const t = tool({
+      name: "my_tool",
+      schema: z.object({}),
+      execute: async () => "ok",
+    });
+
+    expect(t.description).toBe("my_tool");
+    expect(t.parameters).toBeDefined();
+  });
+});
+
+describe("formatZodError", () => {
+  test("formats a single issue with the tool name", () => {
+    const schema = z.object({ city: z.string() });
+    const result = schema.safeParse({});
+    if (result.success) throw new Error("expected failure");
+
+    const msg = formatZodError(result.error, "get_weather");
+    expect(msg).toMatch(/^Invalid arguments for get_weather: /);
+    expect(msg).toContain("city:");
+  });
+
+  test("joins multiple issues with '; '", () => {
+    const schema = z.object({ a: z.string(), b: z.number() });
+    const result = schema.safeParse({});
+    if (result.success) throw new Error("expected failure");
+
+    const msg = formatZodError(result.error, "t");
+    expect(msg.split(";").length).toBeGreaterThanOrEqual(2);
+  });
+});
diff --git a/packages/appkit/src/plugins/agents/tools/define-tool.ts b/packages/appkit/src/plugins/agents/tools/define-tool.ts
new file mode 100644
index 00000000..dc269ba6
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tools/define-tool.ts
+import type { AgentToolDefinition, ToolAnnotations } from "shared";
+import type { z } from "zod";
+import { toToolJSONSchema } from "./json-schema";
+import { formatZodError } from "./tool";
+
+/**
+ * Single-tool entry for a plugin's internal tool registry.
+ *
+ * Plugins collect these into a `Record` keyed by the tool's
+ * public name and dispatch via `executeFromRegistry`.
+ */
+export interface ToolEntry<S extends z.ZodType = z.ZodType> {
+  description: string;
+  schema: S;
+  annotations?: ToolAnnotations;
+  /**
+   * Whether this tool is eligible for auto-inheritance into markdown or
+   * code-defined agents that enable `autoInheritTools`. Defaults to `false`
+   * (safe-by-default) — plugin authors must explicitly opt a tool in if they
+   * consider it safe enough to appear in every agent's tool record without an
+   * explicit `tools:` declaration. Destructive or privilege-sensitive tools
+   * should leave this unset so that they only reach agents that wire them
+   * explicitly (via `tools:`, `toolkits:`, or `fromPlugin({ only: [...] })`).
+   */
+  autoInheritable?: boolean;
+  handler: (
+    args: z.infer<S>,
+    signal?: AbortSignal,
+  ) => unknown | Promise<unknown>;
+}
+
+export type ToolRegistry = Record<string, ToolEntry>;
+
+/**
+ * Defines a single tool entry for a plugin's internal registry.
+ *
+ * The generic `S` flows from `schema` through to the `handler` callback so
+ * `args` is fully typed from the Zod schema. Names are assigned by the
+ * registry key, so they are not repeated inside the entry.
+ */
+export function defineTool<S extends z.ZodType>(
+  config: ToolEntry<S>,
+): ToolEntry<S> {
+  return config;
+}
+
+/**
+ * Validates tool-call arguments against the entry's schema and invokes its
+ * handler. On validation failure, returns an LLM-friendly error string
+ * (matching the behavior of `tool()`) rather than throwing, so the model
+ * can self-correct on its next turn.
+ */
+export async function executeFromRegistry(
+  registry: ToolRegistry,
+  name: string,
+  args: unknown,
+  signal?: AbortSignal,
+): Promise<unknown> {
+  const entry = registry[name];
+  if (!entry) {
+    throw new Error(`Unknown tool: ${name}`);
+  }
+  const parsed = entry.schema.safeParse(args);
+  if (!parsed.success) {
+    return formatZodError(parsed.error, name);
+  }
+  return entry.handler(parsed.data, signal);
+}
+
+/**
+ * Produces the `AgentToolDefinition[]` a ToolProvider exposes to the LLM,
+ * deriving `parameters` JSON Schema from each entry's Zod schema.
+ *
+ * Tool names come from registry keys (supports dotted names like
+ * `uploads.list` for dynamic plugins).
+ */
+export function toolsFromRegistry(
+  registry: ToolRegistry,
+): AgentToolDefinition[] {
+  return Object.entries(registry).map(([name, entry]) => {
+    const parameters = toToolJSONSchema(
+      entry.schema,
+    ) as unknown as AgentToolDefinition["parameters"];
+    const def: AgentToolDefinition = {
+      name,
+      description: entry.description,
+      parameters,
+    };
+    if (entry.annotations) {
+      def.annotations = entry.annotations;
+    }
+    return def;
+  });
+}
diff --git a/packages/appkit/src/plugins/agents/tools/function-tool.ts b/packages/appkit/src/plugins/agents/tools/function-tool.ts
new file mode 100644
index 00000000..8ce634e0
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tools/function-tool.ts
+import type { AgentToolDefinition } from "shared";
+
+export interface FunctionTool {
+  type: "function";
+  name: string;
+  description?: string | null;
+  parameters?: Record<string, unknown> | null;
+  strict?: boolean | null;
+  execute: (args: Record<string, unknown>) => Promise<string> | string;
+}
+
+export function isFunctionTool(value: unknown): value is FunctionTool {
+  if (typeof value !== "object" || value === null) return false;
+  const obj = value as Record<string, unknown>;
+  return (
+    obj.type === "function" &&
+    typeof obj.name === "string" &&
+    typeof obj.execute === "function"
+  );
+}
+
+export function functionToolToDefinition(
+  tool: FunctionTool,
+): AgentToolDefinition {
+  return {
+    name: tool.name,
+    description: tool.description ?? tool.name,
+    parameters: (tool.parameters as AgentToolDefinition["parameters"]) ?? {
+      type: "object",
+      properties: {},
+    },
+  };
+}
diff --git a/packages/appkit/src/plugins/agents/tools/hosted-tools.ts b/packages/appkit/src/plugins/agents/tools/hosted-tools.ts
new file mode 100644
index 00000000..bce70c4f
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tools/hosted-tools.ts
+export interface GenieTool {
+  type: "genie-space";
+  genie_space: { id: string };
+}
+
+export interface VectorSearchIndexTool {
+  type: "vector_search_index";
+  vector_search_index: { name: string };
+}
+
+export interface CustomMcpServerTool {
+  type: "custom_mcp_server";
+  custom_mcp_server: { app_name: string; app_url: string };
+}
+
+export interface ExternalMcpServerTool {
+  type: "external_mcp_server";
+  external_mcp_server: { connection_name: string };
+}
+
+export type HostedTool =
+  | GenieTool
+  | VectorSearchIndexTool
+  | CustomMcpServerTool
+  | ExternalMcpServerTool;
+
+const HOSTED_TOOL_TYPES = new Set([
+  "genie-space",
+  "vector_search_index",
+  "custom_mcp_server",
+  "external_mcp_server",
+]);
+
+export function isHostedTool(value: unknown): value is HostedTool {
+  if (typeof value !== "object" || value === null) return false;
+  const obj = value as Record<string, unknown>;
+  return typeof obj.type === "string" && HOSTED_TOOL_TYPES.has(obj.type);
+}
+
+export interface McpEndpointConfig {
+  name: string;
+  /** Absolute URL or path relative to workspace host */
+  url: string;
+}
+
+/**
+ * Resolves HostedTool configs into MCP endpoint configurations
+ * that the MCP client can connect to.
+ */
+function resolveHostedTool(tool: HostedTool): McpEndpointConfig {
+  switch (tool.type) {
+    case "genie-space":
+      return {
+        name: `genie-${tool.genie_space.id}`,
+        url: `/api/2.0/mcp/genie/${tool.genie_space.id}`,
+      };
+    case "vector_search_index": {
+      const parts = tool.vector_search_index.name.split(".");
+      if (parts.length !== 3) {
+        throw new Error(
+          `vector_search_index name must be 3-part dotted (catalog.schema.index), got: ${tool.vector_search_index.name}`,
+        );
+      }
+      return {
+        name: `vs-${parts.join("-")}`,
+        url: `/api/2.0/mcp/vector-search/${parts[0]}/${parts[1]}/${parts[2]}`,
+      };
+    }
+    case "custom_mcp_server":
+      return {
+        name: tool.custom_mcp_server.app_name,
+        url: tool.custom_mcp_server.app_url,
+      };
+    case "external_mcp_server":
+      return {
+        name: tool.external_mcp_server.connection_name,
+        url: `/api/2.0/mcp/external/${tool.external_mcp_server.connection_name}`,
+      };
+  }
+}
+
+export function resolveHostedTools(tools: HostedTool[]): McpEndpointConfig[] {
+  return tools.map(resolveHostedTool);
+}
+
+/**
+ * Factory for declaring a custom MCP server tool.
+ *
+ * Replaces the verbose `{ type: "custom_mcp_server", custom_mcp_server: { app_name, app_url } }`
+ * wrapper with a concise positional call.
+ *
+ * Example:
+ * ```ts
+ * mcpServer("my-app", "https://my-app.databricksapps.com/mcp")
+ * ```
+ */
+export function mcpServer(name: string, url: string): CustomMcpServerTool {
+  return {
+    type: "custom_mcp_server",
+    custom_mcp_server: { app_name: name, app_url: url },
+  };
+}
diff --git a/packages/appkit/src/plugins/agents/tools/index.ts b/packages/appkit/src/plugins/agents/tools/index.ts
new file mode 100644
index 00000000..7b779d1c
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tools/index.ts
+export {
+  defineTool,
+  executeFromRegistry,
+  type ToolEntry,
+  type ToolRegistry,
+  toolsFromRegistry,
+} from "./define-tool";
+export {
+  type FunctionTool,
+  functionToolToDefinition,
+  isFunctionTool,
+} from "./function-tool";
+export {
+  type HostedTool,
+  isHostedTool,
+  mcpServer,
+  resolveHostedTools,
+} from "./hosted-tools";
+export { AppKitMcpClient } from "./mcp-client";
+export { type ToolConfig, tool } from "./tool";
diff --git a/packages/appkit/src/plugins/agents/tools/json-schema.ts b/packages/appkit/src/plugins/agents/tools/json-schema.ts
new file mode 100644
index 00000000..c5c10dbf
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tools/json-schema.ts
+import { toJSONSchema, type z } from "zod";
+
+/**
+ * Converts a Zod schema to JSON Schema suitable for an LLM tool-call
+ * `parameters` field.
+ *
+ * Wraps `zod`'s `toJSONSchema()` and strips the top-level `$schema` annotation
+ * that Zod v4 emits by default (e.g. `"https://json-schema.org/draft/..."`).
+ * The Databricks Mosaic serving endpoint forwards tool schemas to Google's
+ * Gemini `function_declarations` format, which rejects any top-level key it
+ * doesn't explicitly recognize — including `$schema` — with a 400
+ * `Invalid JSON payload received. Unknown name "$schema"` error. Other LLM
+ * providers either ignore the field or also trip on it, so stripping here is
+ * safe across backends.
+ */
+export function toToolJSONSchema(schema: z.ZodType): Record<string, unknown> {
+  const raw = toJSONSchema(schema) as Record<string, unknown>;
+  const { $schema: _ignored, ...rest } = raw;
+  return rest;
+}
diff --git a/packages/appkit/src/plugins/agents/tools/mcp-client.ts b/packages/appkit/src/plugins/agents/tools/mcp-client.ts
new file mode 100644
index 00000000..377ff8a4
--- /dev/null
+++ b/packages/appkit/src/plugins/agents/tools/mcp-client.ts
+import type { AgentToolDefinition } from "shared";
+import { createLogger } from "../../../logging/logger";
+import type { McpEndpointConfig } from "./hosted-tools";
+import {
+  assertResolvedHostSafe,
+  checkMcpUrl,
+  type DnsLookup,
+  type McpHostPolicy,
+} from "./mcp-host-policy";
+
+const logger = createLogger("agent:mcp");
+
+interface JsonRpcRequest {
+  jsonrpc: "2.0";
+  id: number;
+  method: string;
+  params?: Record<string, unknown>;
+}
+
+interface JsonRpcResponse {
+  jsonrpc: "2.0";
+  id: number;
+  result?: unknown;
+  error?: { code: number; message: string; data?: unknown };
+}
+
+interface McpToolSchema {
+  name: string;
+  description?: string;
+  inputSchema?: Record<string, unknown>;
+}
+
+interface McpToolCallResult {
+  content: Array<{ type: string; text?: string }>;
+  isError?: boolean;
+}
+
+interface McpServerConnection {
+  config: McpEndpointConfig;
+  resolvedUrl: string;
+  /**
+   * Whether workspace auth (SP / OBO) may be forwarded to this endpoint's URL.
+   * Decided at `connect()` time via {@link McpHostPolicy} and cached for the
+   * lifetime of the connection.
+   */
+  forwardWorkspaceAuth: boolean;
+  tools: Map<string, McpToolSchema>;
+}
+
+/**
+ * Lightweight MCP client for Databricks-hosted MCP servers.
+ *
+ * Uses raw fetch() with JSON-RPC 2.0 over HTTP — no @modelcontextprotocol/sdk
+ * or LangChain dependency. Supports the Streamable HTTP transport (POST with
+ * JSON-RPC request, single JSON-RPC response).
+ *
+ * All outbound URLs are gated by an {@link McpHostPolicy}: unallowlisted hosts
+ * are rejected before the first byte is sent, and workspace credentials are
+ * only forwarded to the same-origin workspace. See `mcp-host-policy.ts`.
+ */
+export class AppKitMcpClient {
+  private connections = new Map<string, McpServerConnection>();
+  private sessionIds = new Map<string, string>();
+  private requestId = 0;
+  private closed = false;
+
+  constructor(
+    private workspaceHost: string,
+    private authenticate: () => Promise<Record<string, string>>,
+    private policy: McpHostPolicy,
+    private options: { dnsLookup?: DnsLookup; fetchImpl?: typeof fetch } = {},
+  ) {}
+
+  async connectAll(endpoints: McpEndpointConfig[]): Promise<void> {
+    const results = await Promise.allSettled(
+      endpoints.map((ep) => this.connect(ep)),
+    );
+    for (let i = 0; i < results.length; i++) {
+      if (results[i].status === "rejected") {
+        logger.error(
+          "Failed to connect MCP server %s: %O",
+          endpoints[i].name,
+          (results[i] as PromiseRejectedResult).reason,
+        );
+      }
+    }
+  }
+
+  private resolveUrl(endpoint: McpEndpointConfig): string {
+    if (
+      endpoint.url.startsWith("http://") ||
+      endpoint.url.startsWith("https://")
+    ) {
+      return endpoint.url;
+    }
+    return `${this.workspaceHost}${endpoint.url}`;
+  }
+
+  async connect(endpoint: McpEndpointConfig): Promise<void> {
+    const resolvedUrl = this.resolveUrl(endpoint);
+    const check = checkMcpUrl(resolvedUrl, this.policy);
+    if (!check.ok) {
+      throw new Error(
+        `MCP endpoint '${endpoint.name}' refused at connect: ${check.reason}`,
+      );
+    }
+    await assertResolvedHostSafe(
+      check.url.hostname,
+      this.policy,
+      this.options.dnsLookup,
+    );
+
+    logger.info(
+      "Connecting to MCP server: %s at %s (forwardWorkspaceAuth=%s)",
+      endpoint.name,
+      resolvedUrl,
+      check.forwardWorkspaceAuth,
+    );
+
+    const initResponse = await this.sendRpc(
+      resolvedUrl,
+      "initialize",
+      {
+        protocolVersion: "2025-03-26",
+        capabilities: {},
+        clientInfo: { name: "appkit-agent", version: "0.1.0" },
+      },
+      { forwardWorkspaceAuth: check.forwardWorkspaceAuth },
+    );
+
+    if (initResponse.sessionId) {
+      this.sessionIds.set(endpoint.name, initResponse.sessionId);
+    }
+    const sessionId = this.sessionIds.get(endpoint.name);
+
+    await this.sendNotification(resolvedUrl, "notifications/initialized", {
+      sessionId,
+      forwardWorkspaceAuth: check.forwardWorkspaceAuth,
+    });
+
+    const listResponse = await this.sendRpc(
+      resolvedUrl,
+      "tools/list",
+      {},
+      { sessionId, forwardWorkspaceAuth: check.forwardWorkspaceAuth },
+    );
+    const toolList =
+      (listResponse.result as { tools?: McpToolSchema[] })?.tools ?? [];
+
+    const tools = new Map<string, McpToolSchema>();
+    for (const tool of toolList) {
+      tools.set(tool.name, tool);
+    }
+
+    this.connections.set(endpoint.name, {
+      config: endpoint,
+      resolvedUrl,
+      forwardWorkspaceAuth: check.forwardWorkspaceAuth,
+      tools,
+    });
+    logger.info(
+      "Connected to MCP server %s: %d tools available",
+      endpoint.name,
+      tools.size,
+    );
+  }
+
+  getAllToolDefinitions(): AgentToolDefinition[] {
+    const defs: AgentToolDefinition[] = [];
+    for (const [serverName, conn] of this.connections) {
+      for (const [toolName, schema] of conn.tools) {
+        defs.push({
+          name: `mcp.${serverName}.${toolName}`,
+          description: schema.description ?? toolName,
+          parameters:
+            (schema.inputSchema as AgentToolDefinition["parameters"]) ?? {
+              type: "object",
+              properties: {},
+            },
+        });
+      }
+    }
+    return defs;
+  }
+
+  /**
+   * Whether the named MCP server may receive workspace-scoped auth headers
+   * (e.g., an OBO bearer token from an end-user request). Callers should gate
+   * auth-forwarding decisions on this to prevent credential exfiltration to
+   * non-workspace hosts.
+   */
+  canForwardWorkspaceAuth(serverName: string): boolean {
+    return this.connections.get(serverName)?.forwardWorkspaceAuth ?? false;
+  }
+
+  async callTool(
+    qualifiedName: string,
+    args: unknown,
+    authHeaders?: Record<string, string>,
+    callerSignal?: AbortSignal,
+  ): Promise<string> {
+    const parts = qualifiedName.split(".");
+    if (parts.length < 3 || parts[0] !== "mcp") {
+      throw new Error(`Invalid MCP tool name: ${qualifiedName}`);
+    }
+    const serverName = parts[1];
+    const toolName = parts.slice(2).join(".");
+
+    const conn = this.connections.get(serverName);
+    if (!conn) {
+      throw new Error(`MCP server not connected: ${serverName}`);
+    }
+
+    const sessionId = this.sessionIds.get(serverName);
+    // authHeaders are caller-supplied credentials (typically the OBO token).
+    // Only honor them if the destination URL was admitted with
+    // forwardWorkspaceAuth=true at connect time.
+    const scopedAuthOverride = conn.forwardWorkspaceAuth
+      ? authHeaders
+      : undefined;
+
+    const rpcResult = await this.sendRpc(
+      conn.resolvedUrl,
+      "tools/call",
+      { name: toolName, arguments: args },
+      {
+        authOverride: scopedAuthOverride,
+        sessionId,
+        forwardWorkspaceAuth: conn.forwardWorkspaceAuth,
+        callerSignal,
+      },
+    );
+    const result = rpcResult.result as McpToolCallResult;
+
+    if (result.isError) {
+      const errText = (result.content ?? [])
+        .filter((c) => c.type === "text")
+        .map((c) => c.text)
+        .join("\n");
+      throw new Error(errText || "MCP tool call failed");
+    }
+
+    return (result.content ?? [])
+      .filter((c) => c.type === "text")
+      .map((c) => c.text)
+      .join("\n");
+  }
+
+  async close(): Promise<void> {
+    this.closed = true;
+    this.connections.clear();
+    this.sessionIds.clear();
+  }
+
+  private async sendRpc(
+    url: string,
+    method: string,
+    params?: Record<string, unknown>,
+    options?: {
+      authOverride?: Record<string, string>;
+      sessionId?: string;
+      forwardWorkspaceAuth?: boolean;
+      /**
+       * Optional external abort signal (typically the agent's stream signal).
+       * Composed with the built-in 30 s timeout so `/cancel` or agent-run
+       * shutdown immediately propagates to the MCP fetch rather than waiting
+       * for the remote server to respond.
+       */
+      callerSignal?: AbortSignal;
+    },
+  ): Promise<{ result: unknown; sessionId?: string }> {
+    if (this.closed) throw new Error("MCP client is closed");
+
+    const request: JsonRpcRequest = {
+      jsonrpc: "2.0",
+      id: ++this.requestId,
+      method,
+      ...(params && { params }),
+    };
+
+    const authHeaders = await this.resolveAuthHeaders(options);
+    const headers: Record<string, string> = {
+      "Content-Type": "application/json",
+      Accept: "application/json, text/event-stream",
+      ...authHeaders,
+    };
+    if (options?.sessionId) {
+      headers["Mcp-Session-Id"] = options.sessionId;
+    }
+
+    const fetchImpl = this.options.fetchImpl ?? fetch;
+    const signals: AbortSignal[] = [AbortSignal.timeout(30_000)];
+    if (options?.callerSignal) signals.push(options.callerSignal);
+    const response = await fetchImpl(url, {
+      method: "POST",
+      headers,
+      body: JSON.stringify(request),
+      signal: signals.length > 1 ? AbortSignal.any(signals) : signals[0],
+    });
+
+    if (!response.ok) {
+      throw new Error(
+        `MCP request to ${method} failed: ${response.status} ${response.statusText}`,
+      );
+    }
+
+    const contentType = response.headers.get("content-type") ?? "";
+    let json: JsonRpcResponse;
+
+    if (contentType.includes("text/event-stream")) {
+      const text = await response.text();
+      const lastData = text
+        .split("\n")
+        .filter((line) => line.startsWith("data: "))
+        .map((line) => line.slice(6))
+        .pop();
+      if (!lastData) {
+        throw new Error(`MCP SSE response for ${method} contained no data`);
+      }
+      json = JSON.parse(lastData) as JsonRpcResponse;
+    } else {
+      json = (await response.json()) as JsonRpcResponse;
+    }
+
+    if (json.error) {
+      throw new Error(`MCP error (${json.error.code}): ${json.error.message}`);
+    }
+
+    const sid = response.headers.get("mcp-session-id") ??
undefined; + return { result: json.result, sessionId: sid }; + } + + private async sendNotification( + url: string, + method: string, + options?: { + sessionId?: string; + forwardWorkspaceAuth?: boolean; + }, + ): Promise { + if (this.closed) return; + + const authHeaders = await this.resolveAuthHeaders(options); + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + ...authHeaders, + }; + if (options?.sessionId) { + headers["Mcp-Session-Id"] = options.sessionId; + } + + const fetchImpl = this.options.fetchImpl ?? fetch; + await fetchImpl(url, { + method: "POST", + headers, + body: JSON.stringify({ jsonrpc: "2.0", method }), + signal: AbortSignal.timeout(30_000), + }); + } + + /** + * Return the auth headers to send on an outbound request. Workspace auth + * (SP or OBO) is only resolved when `forwardWorkspaceAuth` is true; for + * non-workspace hosts no bearer token is attached. + */ + private async resolveAuthHeaders(options?: { + authOverride?: Record; + forwardWorkspaceAuth?: boolean; + }): Promise> { + if (!options?.forwardWorkspaceAuth) return {}; + if (options.authOverride) return options.authOverride; + return this.authenticate(); + } +} diff --git a/packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts b/packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts new file mode 100644 index 00000000..d970c83a --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/mcp-host-policy.ts @@ -0,0 +1,299 @@ +import { lookup as defaultLookup } from "node:dns/promises"; +import { isIP, isIPv4 } from "node:net"; + +/** + * DNS lookup function compatible with `dns/promises.lookup(host, { all: true })`. + * Exposed as an injection point so callers (tests, custom DNS resolvers) can + * override the default resolver. 
+ */ +export type DnsLookup = ( + hostname: string, + options: { all: true }, +) => Promise>; + +/** + * Policy that decides whether a given MCP endpoint URL is allowed and whether + * Databricks workspace credentials (SP or OBO) may be forwarded to it. + * + * The default posture is zero-trust: only same-origin workspace URLs receive + * workspace credentials, and all other destinations must be explicitly + * allowlisted by the application developer. Private / link-local IP ranges + * are blocked outright to prevent SSRF into cloud metadata services. + */ +export interface McpHostPolicy { + /** Lowercased hostname of the Databricks workspace (same-origin target). */ + readonly workspaceHostname: string; + /** Additional allowlisted hostnames (lowercased). Workspace auth is NEVER forwarded to these. */ + readonly trustedHosts: ReadonlySet; + /** Permit `http://localhost`, `127.0.0.1`, `::1` URLs. Typically true only in development. */ + readonly allowLocalhost: boolean; +} + +/** + * Config shape accepted by {@link buildMcpHostPolicy}, matching the + * `mcp` field on `AgentsPluginConfig`. + */ +export interface McpHostPolicyConfig { + /** + * Additional hostnames that may host custom MCP servers beyond the same-origin + * workspace. Compared case-insensitively; bare hostnames only (no scheme or + * path). Workspace credentials (SP / OBO) are never forwarded to these hosts — + * they must handle authentication themselves. + */ + trustedHosts?: string[]; + /** + * Allow `http://localhost`, `127.0.0.1`, and `::1` MCP URLs for local + * development. Defaults to `true` when `NODE_ENV !== "production"`, + * otherwise `false`. Workspace credentials are never forwarded to localhost. + */ + allowLocalhost?: boolean; +} + +/** Build an {@link McpHostPolicy} from user config + the resolved workspace URL. 
*/ +export function buildMcpHostPolicy( + config: McpHostPolicyConfig | undefined, + workspaceHost: string, +): McpHostPolicy { + const workspaceHostname = safeHostname(workspaceHost); + if (!workspaceHostname) { + throw new Error( + `Invalid workspace host for MCP policy: ${JSON.stringify(workspaceHost)}`, + ); + } + const trustedHosts = new Set( + (config?.trustedHosts ?? []).map((h) => h.trim().toLowerCase()), + ); + const allowLocalhost = + config?.allowLocalhost ?? process.env.NODE_ENV !== "production"; + return { workspaceHostname, trustedHosts, allowLocalhost }; +} + +type McpUrlCheck = + | { + readonly ok: true; + /** Whether it is safe to forward workspace-scoped credentials (SP/OBO) to this URL. */ + readonly forwardWorkspaceAuth: boolean; + /** Parsed URL for reuse by the caller. */ + readonly url: URL; + } + | { readonly ok: false; readonly reason: string }; + +/** + * Synchronously decide whether an MCP URL is allowed under the given policy + * and whether workspace credentials may be forwarded to it. + * + * Hard rejections: + * - Non-`http(s)` schemes. + * - `http://` unless the host is localhost AND `allowLocalhost` is true. + * - Hosts that are neither same-origin workspace, localhost (if allowed), + * nor in `trustedHosts`. + */ +export function checkMcpUrl( + rawUrl: string, + policy: McpHostPolicy, +): McpUrlCheck { + let url: URL; + try { + url = new URL(rawUrl); + } catch { + return { + ok: false, + reason: `MCP URL is not a valid absolute URL: ${rawUrl}`, + }; + } + + if (url.protocol !== "http:" && url.protocol !== "https:") { + return { + ok: false, + reason: `MCP URL scheme '${url.protocol}' is not allowed (http(s) only): ${rawUrl}`, + }; + } + + const host = url.hostname.toLowerCase(); + const isLoopback = isLoopbackHost(host); + + if (url.protocol === "http:" && !(isLoopback && policy.allowLocalhost)) { + return { + ok: false, + reason: `MCP URL uses plaintext http:// which forwards bearer tokens in cleartext: ${rawUrl}. 
Use https:// or enable allowLocalhost for a localhost dev server.`, + }; + } + + if (host === policy.workspaceHostname) { + return { ok: true, forwardWorkspaceAuth: true, url }; + } + + if (isLoopback) { + if (!policy.allowLocalhost) { + return { + ok: false, + reason: `MCP URL points to localhost but allowLocalhost is disabled: ${rawUrl}`, + }; + } + return { ok: true, forwardWorkspaceAuth: false, url }; + } + + if (policy.trustedHosts.has(host)) { + return { ok: true, forwardWorkspaceAuth: false, url }; + } + + return { + ok: false, + reason: `MCP host '${host}' is not allowed. Either use a same-origin workspace URL (${policy.workspaceHostname}) or add it to agents({ mcp: { trustedHosts: ['${host}'] } }).`, + }; +} + +/** + * Resolve `hostname` via DNS and assert that none of its addresses fall in a + * blocked IP range (loopback, RFC1918, link-local, CGNAT, cloud metadata). + * + * Throws with a descriptive error if any resolved address is blocked. Pass + * `allowLocalhost: true` to permit `127.0.0.1` / `::1` specifically. + * + * Note: this only guards against hosts that statically resolve to private + * ranges. Full SSRF protection requires socket-level IP pinning after + * resolution (DNS rebinding defense), which is out of scope here. 
+ */ +export async function assertResolvedHostSafe( + hostname: string, + policy: McpHostPolicy, + lookup: DnsLookup = defaultLookup, +): Promise { + const lowered = hostname.toLowerCase(); + + if (isIP(lowered)) { + if (isBlockedIp(lowered, policy.allowLocalhost)) { + throw new Error(`MCP host ${lowered} is in a blocked IP range`); + } + return; + } + + if (lowered === "localhost") { + if (!policy.allowLocalhost) { + throw new Error( + `MCP host localhost is not allowed under the current policy`, + ); + } + return; + } + + let resolved: Array<{ address: string }>; + try { + resolved = await lookup(hostname, { all: true }); + } catch (cause) { + throw new Error( + `MCP host ${hostname} could not be resolved via DNS: ${cause instanceof Error ? cause.message : String(cause)}`, + ); + } + + if (resolved.length === 0) { + throw new Error(`MCP host ${hostname} returned no DNS addresses`); + } + + for (const { address } of resolved) { + if (isBlockedIp(address, policy.allowLocalhost)) { + throw new Error( + `MCP host ${hostname} resolved to blocked address ${address} (private / link-local ranges are not allowed)`, + ); + } + } +} + +/** Whether a raw hostname literal is one of the recognised loopback aliases. */ +export function isLoopbackHost(host: string): boolean { + const lowered = host.toLowerCase(); + return ( + lowered === "localhost" || + lowered === "127.0.0.1" || + lowered === "::1" || + lowered === "[::1]" || + lowered === "0:0:0:0:0:0:0:1" + ); +} + +/** + * Check whether a resolved IP address is in a range that should never receive + * workspace credentials. `allowLocalhost` carves out 127.0.0.0/8 and ::1. + */ +export function isBlockedIp(address: string, allowLocalhost: boolean): boolean { + if (isIPv4(address)) { + return isBlockedIpv4(address, allowLocalhost); + } + if (isIP(address) === 6) { + return isBlockedIpv6(address, allowLocalhost); + } + // Not a recognisable IP literal — fail-closed. 
+ return true; +} + +function isBlockedIpv4(addr: string, allowLocalhost: boolean): boolean { + const parts = addr.split(".").map((p) => Number.parseInt(p, 10)); + if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) { + return true; + } + const [a, b] = parts; + if (a === 0) return true; + if (a === 127) return !allowLocalhost; + if (a === 10) return true; + if (a === 172 && b >= 16 && b <= 31) return true; + if (a === 192 && b === 168) return true; + if (a === 169 && b === 254) return true; + if (a === 100 && b >= 64 && b <= 127) return true; + if (a >= 224) return true; + return false; +} + +function isBlockedIpv6(addr: string, allowLocalhost: boolean): boolean { + const lowered = addr.toLowerCase().replace(/^\[|\]$/g, ""); + + if (lowered === "::") return true; + if (lowered === "::1" || lowered === "0:0:0:0:0:0:0:1") + return !allowLocalhost; + + // IPv4-mapped IPv6: `::ffff:` may be written in dotted form + // (`::ffff:169.254.169.254`) or colon-hex form (`::ffff:a9fe:a9fe`). Both + // route to the same destination, so we must normalise before delegating + // to the IPv4 blocklist. + if (lowered.startsWith("::ffff:")) { + const tail = lowered.slice("::ffff:".length); + if (isIPv4(tail)) return isBlockedIpv4(tail, allowLocalhost); + const hexV4 = hexPairToDottedIpv4(tail); + if (hexV4) return isBlockedIpv4(hexV4, allowLocalhost); + } + + // Unique Local Addresses (fc00::/7) — `fc` and `fd` only. + if (/^f[cd][0-9a-f]{2}:/.test(lowered)) return true; + // Link-local fe80::/10 — the first 10 bits are 1111111010, i.e. the + // second hex nibble must be 8-b. Matches fe80:..–febf:.. + if (/^fe[89ab][0-9a-f]:/.test(lowered)) return true; + // Multicast ff00::/8. + if (lowered.startsWith("ff")) return true; + return false; +} + +/** + * Parse the trailing two hex groups of an IPv4-mapped IPv6 address written + * in colon-hex form (e.g. `a9fe:a9fe`) into the equivalent dotted-quad IPv4 + * representation (`169.254.169.254`). 
Returns null for anything else. + */ +function hexPairToDottedIpv4(tail: string): string | null { + const match = tail.match(/^([0-9a-f]{1,4}):([0-9a-f]{1,4})$/); + if (!match) return null; + const hi = Number.parseInt(match[1], 16); + const lo = Number.parseInt(match[2], 16); + if (!Number.isFinite(hi) || !Number.isFinite(lo)) return null; + if (hi < 0 || hi > 0xffff || lo < 0 || lo > 0xffff) return null; + const a = (hi >> 8) & 0xff; + const b = hi & 0xff; + const c = (lo >> 8) & 0xff; + const d = lo & 0xff; + return `${a}.${b}.${c}.${d}`; +} + +function safeHostname(rawUrl: string): string | null { + try { + return new URL(rawUrl).hostname.toLowerCase(); + } catch { + return null; + } +} diff --git a/packages/appkit/src/plugins/agents/tools/sql-policy.ts b/packages/appkit/src/plugins/agents/tools/sql-policy.ts new file mode 100644 index 00000000..6f889d44 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/sql-policy.ts @@ -0,0 +1,317 @@ +/** + * Conservative SQL classifier used by agent-facing query tools to enforce + * `readOnly: true` annotations at execution time. + * + * Why a hand-rolled tokenizer rather than `node-sql-parser` or `pgsql-parser`: + * + * - `node-sql-parser`'s Hive/Spark dialect coverage rejects common Databricks + * SQL patterns (three-part `catalog.schema.table` names, `SHOW TABLES IN`, + * `DESCRIBE EXTENDED`, `EXPLAIN`) that must be allowed by a read-only + * classifier. Its PostgreSQL grammar rejects `SHOW`/`DESCRIBE` too. + * - `pgsql-parser` (libpg_query) is a native binding and fails to install + * cleanly on every Databricks App runtime we care about. + * + * We don't need to fully parse SQL — we only need to decide whether every + * statement in the batch starts with a read-only keyword. A small tokenizer + * that correctly strips strings, identifiers, and comments is enough and + * costs no extra dependencies. 
+ * + * What this classifier guarantees (when it returns `readOnly: true`): + * + * - Every semicolon-separated statement outside a string, identifier, or + * comment begins with `SELECT`, `WITH`, `SHOW`, `EXPLAIN`, `DESCRIBE`, or + * `DESC`. + * - `SELECT 1; DROP TABLE x` is rejected (stacked write detected). + * - `SELECT 'value; DROP TABLE x'` passes (literal inside a string). + * - `-- DROP TABLE x\nSELECT 1` passes (comment stripped). + * - `SELECT 1 ` passes (comment stripped). + * + * What this classifier does NOT guarantee: + * + * - A `SELECT` statement may still have side effects via function calls + * (`SELECT pg_advisory_lock(...)`, `SELECT lo_import('/etc/passwd')`, CTEs + * with DML in Postgres 9.1+). Callers that need stronger guarantees should + * combine this check with a runtime mechanism: for PostgreSQL, execute the + * statement inside a dedicated client's `BEGIN READ ONLY … ROLLBACK` + * transaction (see `LakebasePlugin.runReadOnlyStatement`). A batched + * `pool.query("BEGIN READ ONLY; ; ROLLBACK")` cannot be used because + * the Postgres Extended Query protocol rejects multi-statement prepared + * queries, which silently breaks parameterized SQL. + */ + +const READ_ONLY_KEYWORDS = new Set([ + "SELECT", + "WITH", + "SHOW", + "EXPLAIN", + "DESCRIBE", + "DESC", +]); + +type SqlReadOnlyResult = + | { readOnly: true; statements: number } + | { readOnly: false; reason: string }; + +/** + * Classify a SQL string as read-only or not. See module docstring for the + * precise guarantee this offers. 
+ */ +export function classifyReadOnly(sql: string): SqlReadOnlyResult { + const strip = stripCommentsAndQuoted(sql); + if (strip.unterminated) { + return { + readOnly: false, + reason: `SQL has an unterminated ${strip.unterminated} literal`, + }; + } + const statements = splitStatements(strip.cleaned); + + if (statements.length === 0) { + return { + readOnly: false, + reason: "SQL is empty or contains only comments", + }; + } + + for (let i = 0; i < statements.length; i++) { + const stmt = statements[i]; + const firstWord = firstKeyword(stmt); + if (!firstWord) { + return { + readOnly: false, + reason: `statement ${i + 1} of ${statements.length} is empty`, + }; + } + if (!READ_ONLY_KEYWORDS.has(firstWord.toUpperCase())) { + return { + readOnly: false, + reason: `statement starts with '${firstWord}'; only SELECT, WITH, SHOW, EXPLAIN, DESCRIBE, DESC are allowed in read-only mode`, + }; + } + } + + return { readOnly: true, statements: statements.length }; +} + +/** + * Assert `sql` is read-only or throw {@link ReadOnlySqlViolation}. Suitable + * for calling from agent-tool handlers where the thrown string surfaces back + * to the LLM as the tool's error output. + */ +export function assertReadOnlySql(sql: string): void { + const result = classifyReadOnly(sql); + if (!result.readOnly) { + throw new ReadOnlySqlViolation(result.reason); + } +} + +export class ReadOnlySqlViolation extends Error { + constructor(reason: string) { + super(`SQL read-only policy violation: ${reason}`); + this.name = "ReadOnlySqlViolation"; + } +} + +// --------------------------------------------------------------------------- +// Tokenizer helpers +// --------------------------------------------------------------------------- + +/** + * Walk `sql` character-by-character and replace every string literal, + * identifier quote, and comment body with a single space of equivalent + * length. Leaves structural tokens (semicolons, whitespace, identifiers, + * operators) in place. 
+ * + * Handles: + * - `-- line comments` through end-of-line + * - SQL block comments (slash-star ... star-slash) with correct nesting (PostgreSQL) + * - `'single-quoted strings'` with `''` escape + * - `"double-quoted identifiers"` with `""` escape (ANSI) + * - `` `backtick identifiers` `` (Databricks) + * - `$tag$dollar quoted$tag$` strings (PostgreSQL) + * - `E'escape-style'` strings (PostgreSQL) + */ +type StripResult = { + cleaned: string; + /** Non-null if tokenization ended inside an unterminated literal or comment. */ + unterminated: + | null + | "string" + | "identifier" + | "block comment" + | "dollar-quoted string"; +}; + +function stripCommentsAndQuoted(sql: string): StripResult { + const out: string[] = []; + let i = 0; + const n = sql.length; + let unterminated: StripResult["unterminated"] = null; + + while (i < n) { + const ch = sql[i]; + const next = i + 1 < n ? sql[i + 1] : ""; + + if (ch === "-" && next === "-") { + out.push(" "); + i += 2; + while (i < n && sql[i] !== "\n") { + out.push(" "); + i++; + } + continue; + } + + if (ch === "/" && next === "*") { + out.push(" "); + i += 2; + let depth = 1; + while (i < n && depth > 0) { + if (sql[i] === "/" && sql[i + 1] === "*") { + out.push(" "); + i += 2; + depth++; + continue; + } + if (sql[i] === "*" && sql[i + 1] === "/") { + out.push(" "); + i += 2; + depth--; + continue; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (depth > 0) { + unterminated = "block comment"; + } + continue; + } + + if ( + ch === "'" || + (ch === "E" && next === "'") || + (ch === "e" && next === "'") + ) { + if (ch === "E" || ch === "e") { + out.push(" "); + i++; + } + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === "'" && sql[i + 1] === "'") { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "\\" && sql[i + 1]) { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "'") { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? 
"\n" : " "); + i++; + } + if (!closed) unterminated = "string"; + continue; + } + + if (ch === '"') { + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === '"' && sql[i + 1] === '"') { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === '"') { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (!closed) unterminated = "identifier"; + continue; + } + + if (ch === "`") { + out.push(" "); + i++; + let closed = false; + while (i < n) { + if (sql[i] === "`" && sql[i + 1] === "`") { + out.push(" "); + i += 2; + continue; + } + if (sql[i] === "`") { + out.push(" "); + i++; + closed = true; + break; + } + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + if (!closed) unterminated = "identifier"; + continue; + } + + if (ch === "$") { + const tagMatch = sql.slice(i).match(/^\$([A-Za-z_][A-Za-z0-9_]*)?\$/); + if (tagMatch) { + const tag = tagMatch[0]; + out.push(" ".repeat(tag.length)); + i += tag.length; + const closeIdx = sql.indexOf(tag, i); + if (closeIdx === -1) { + while (i < n) { + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + unterminated = "dollar-quoted string"; + } else { + while (i < closeIdx) { + out.push(sql[i] === "\n" ? "\n" : " "); + i++; + } + out.push(" ".repeat(tag.length)); + i += tag.length; + } + continue; + } + } + + out.push(ch); + i++; + } + + return { cleaned: out.join(""), unterminated }; +} + +/** Split on unquoted `;`, trim, drop empty segments. */ +function splitStatements(cleanedSql: string): string[] { + return cleanedSql + .split(";") + .map((s) => s.trim()) + .filter((s) => s.length > 0); +} + +/** Return the first bareword keyword of a statement, or null if empty. */ +function firstKeyword(stmt: string): string | null { + const match = stmt.match(/^\s*([A-Za-z_][A-Za-z0-9_]*)/); + return match ? 
match[1] : null; +} diff --git a/packages/appkit/src/plugins/agents/tools/tool.ts b/packages/appkit/src/plugins/agents/tools/tool.ts new file mode 100644 index 00000000..b5d4db65 --- /dev/null +++ b/packages/appkit/src/plugins/agents/tools/tool.ts @@ -0,0 +1,53 @@ +import type { z } from "zod"; +import type { FunctionTool } from "./function-tool"; +import { toToolJSONSchema } from "./json-schema"; + +export interface ToolConfig { + name: string; + description?: string; + schema: S; + execute: (args: z.infer) => Promise | string; +} + +/** + * Factory for defining function tools with Zod schemas. + * + * - Generates JSON Schema (for the LLM) from the Zod schema via `z.toJSONSchema()`. + * - Infers the `execute` argument type from the schema. + * - Validates tool call arguments at runtime. On validation failure, returns + * a formatted error string to the LLM instead of throwing, so the model + * can self-correct on its next turn. + */ +export function tool(config: ToolConfig): FunctionTool { + const parameters = toToolJSONSchema(config.schema) as unknown as Record< + string, + unknown + >; + + return { + type: "function", + name: config.name, + description: config.description ?? config.name, + parameters, + execute: async (args: Record) => { + const parsed = config.schema.safeParse(args); + if (!parsed.success) { + return formatZodError(parsed.error, config.name); + } + return config.execute(parsed.data as z.infer); + }, + }; +} + +/** + * Formats a Zod validation error into an LLM-friendly string. + * + * Example: `Invalid arguments for get_weather: city: Invalid input: expected string, received undefined` + */ +export function formatZodError(error: z.ZodError, toolName: string): string { + const parts = error.issues.map((issue) => { + const field = issue.path.length > 0 ? 
issue.path.join(".") : "(root)"; + return `${field}: ${issue.message}`; + }); + return `Invalid arguments for ${toolName}: ${parts.join("; ")}`; +} diff --git a/packages/appkit/src/plugins/agents/types.ts b/packages/appkit/src/plugins/agents/types.ts new file mode 100644 index 00000000..086a0426 --- /dev/null +++ b/packages/appkit/src/plugins/agents/types.ts @@ -0,0 +1,54 @@ +import type { AgentToolDefinition, ToolAnnotations } from "shared"; +import type { FunctionTool } from "./tools/function-tool"; +import type { HostedTool } from "./tools/hosted-tools"; + +/** + * A tool reference produced by a plugin's `.toolkit()` call. The agents plugin + * recognizes the `__toolkitRef` brand and dispatches tool invocations through + * `PluginContext.executeTool(req, pluginName, localName, ...)`, preserving + * OBO (asUser) and telemetry spans. + */ +export interface ToolkitEntry { + readonly __toolkitRef: true; + pluginName: string; + localName: string; + def: AgentToolDefinition; + annotations?: ToolAnnotations; + /** + * Whether this tool is eligible for `autoInheritTools` spreading. Mirrors + * {@link ToolEntry.autoInheritable} from the source registry so the agents + * plugin can filter auto-inherited tools without re-walking the provider's + * internal registry. + */ + autoInheritable?: boolean; +} + +/** + * Any tool an agent can invoke: inline function tools (`tool()`), hosted MCP + * tools (`mcpServer()` / raw hosted), or toolkit references from plugins + * (`analytics().toolkit()`). + */ +export type AgentTool = FunctionTool | HostedTool | ToolkitEntry; + +export interface ToolkitOptions { + /** Key prefix to prepend to each tool's local name. Defaults to `${pluginName}.`. */ + prefix?: string; + /** Only include tools whose local name matches one of these. */ + only?: string[]; + /** Exclude tools whose local name matches one of these. */ + except?: string[]; + /** Remap specific local names to different keys (applied after prefix). 
*/ + rename?: Record; +} + +/** + * Type guard for `ToolkitEntry` — used to differentiate toolkit references + * from inline tools in a mixed `tools` record. + */ +export function isToolkitEntry(value: unknown): value is ToolkitEntry { + return ( + typeof value === "object" && + value !== null && + (value as { __toolkitRef?: unknown }).__toolkitRef === true + ); +} diff --git a/packages/appkit/src/plugins/analytics/analytics.ts b/packages/appkit/src/plugins/analytics/analytics.ts index a9c688da..78d11d4d 100644 --- a/packages/appkit/src/plugins/analytics/analytics.ts +++ b/packages/appkit/src/plugins/analytics/analytics.ts @@ -1,16 +1,26 @@ import type { WorkspaceClient } from "@databricks/sdk-experimental"; import type express from "express"; import type { + AgentToolDefinition, IAppRouter, PluginExecuteConfig, SQLTypeMarker, StreamExecutionSettings, + ToolProvider, } from "shared"; +import { z } from "zod"; import { SQLWarehouseConnector } from "../../connectors"; import { getWarehouseId, getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; +import { assertReadOnlySql } from "../agents/tools/sql-policy"; import { queryDefaults } from "./defaults"; import manifest from "./manifest.json"; import { QueryProcessor } from "./query"; @@ -22,7 +32,7 @@ import type { const logger = createLogger("analytics"); -export class AnalyticsPlugin extends Plugin { +export class AnalyticsPlugin extends Plugin implements ToolProvider { /** Plugin manifest declaring metadata and resource requirements */ static manifest = manifest as PluginManifest<"analytics">; @@ -262,6 +272,52 @@ export class AnalyticsPlugin extends Plugin { this.streamManager.abortAll(); } + private 
tools = {
    // Local agent-tool registry, keyed by tool name. `defineTool` bundles the
    // Zod schema, LLM-facing description, annotations, and handler into one
    // entry that the agents plugin consumes via getAgentTools() /
    // executeAgentTool() / toolkit().
    query: defineTool({
      description:
        "Execute a read-only SQL query against the Databricks SQL warehouse. Only SELECT, WITH, SHOW, EXPLAIN, and DESCRIBE statements are accepted; writes are rejected. Returns the query results as JSON.",
      schema: z.object({
        query: z
          .string()
          .describe(
            "The SQL query to execute. Must be a SELECT, WITH, SHOW, EXPLAIN, or DESCRIBE statement.",
          ),
      }),
      annotations: {
        readOnly: true,
        requiresUserContext: true,
      },
      autoInheritable: true,
      handler: (args, signal) => {
        // Enforce the readOnly annotation at execution time — not just as
        // metadata the LLM could ignore — before the query reaches the
        // warehouse. Throws ReadOnlySqlViolation for non-read statements.
        assertReadOnlySql(args.query);
        return this.query(args.query, undefined, undefined, signal);
      },
    }),
  };

  // Advertise this plugin's tools to the agents plugin as flat definitions.
  getAgentTools(): AgentToolDefinition[] {
    return toolsFromRegistry(this.tools);
  }

  // Dispatch a tool call by local name with raw (unvalidated) arguments.
  // NOTE(review): argument validation presumably happens inside
  // executeFromRegistry via the entry's Zod schema — confirm in define-tool.ts.
  async executeAgentTool(
    name: string,
    args: unknown,
    signal?: AbortSignal,
  ): Promise<unknown> {
    return executeFromRegistry(this.tools, name, args, signal);
  }

  /**
   * Returns the plugin's tools as a keyed record of `ToolkitEntry` markers.
   * Called by the agents plugin (via `resolveToolkitFromProvider`) to spread
   * a filtered, renamed view of the plugin's tools into an agent's tool
   * index. Most callers should go through `fromPlugin(analytics, opts)` at
   * module scope instead of reaching for this directly.
   */
  toolkit(opts?: import("../agents/types").ToolkitOptions) {
    return buildToolkitEntries(this.name, this.tools, opts);
  }

  /**
   * Returns the public exports for the analytics plugin.
   * Note: `asUser()` is automatically added by AppKit.
diff --git a/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts b/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts
new file mode 100644
index 00000000..42c9b516
--- /dev/null
+++ b/packages/appkit/src/plugins/analytics/tests/analytics.readonly.test.ts
@@ -0,0 +1,133 @@
import { describe, expect, test, vi } from "vitest";

// Stub CacheManager so AnalyticsPlugin can be constructed without real cache
// infrastructure. NOTE(review): vitest hoists vi.mock above the import of
// AnalyticsPlugin below — keep this factory self-contained.
vi.mock("../../../cache", () => ({
  CacheManager: {
    getInstanceSync: vi.fn(() => ({
      get: vi.fn(),
      set: vi.fn(),
      delete: vi.fn(),
      // getOrExecute passes straight through to the executor so queries are
      // never served from (or written to) a cache during these tests.
      getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise<unknown>) =>
        fn(),
      ),
      generateKey: vi.fn(() => "test-key"),
    })),
  },
}));

import { AnalyticsPlugin } from "../analytics";

/**
 * Tests the read-only SQL enforcement on the analytics agent tool.
 *
 * The tool is annotated `{ readOnly: true, requiresUserContext: true }`; this
 * suite verifies that the annotation is enforced at execution time — not just
 * exposed as metadata to the LLM — by the `assertReadOnlySql` guard in the
 * tool's handler.
 */

// Fresh plugin per test so spies on `query` never leak between cases.
function makePlugin(): AnalyticsPlugin {
  return new AnalyticsPlugin({});
}

describe("AnalyticsPlugin.query agent tool — readOnly annotation", () => {
  test("is advertised with readOnly:true and requiresUserContext:true", () => {
    const plugin = makePlugin();
    const defs = plugin.getAgentTools();
    const query = defs.find((d) => d.name === "query");
    expect(query).toBeDefined();
    expect(query?.annotations).toEqual({
      readOnly: true,
      requiresUserContext: true,
    });
  });
});

describe("AnalyticsPlugin.query agent tool — runtime enforcement", () => {
  test("rejects a DROP statement before it reaches this.query", async () => {
    const plugin = makePlugin();
    // Spy on the underlying warehouse call: the guard must fire first, so the
    // spy should never be invoked for a rejected statement.
    const spy = vi
      .spyOn(plugin, "query")
      // biome-ignore lint/suspicious/noExplicitAny: mocked return
      .mockResolvedValue({ rows: [] } as any);
    await expect(
      plugin.executeAgentTool("query", { query: "DROP TABLE users" }),
    ).rejects.toThrow(/read-only policy violation/i);
    expect(spy).not.toHaveBeenCalled();
  });

  test("rejects UPDATE, DELETE, INSERT, TRUNCATE, GRANT", async () => {
    const plugin = makePlugin();
    const spy = vi
      .spyOn(plugin, "query")
      // biome-ignore lint/suspicious/noExplicitAny: mocked return
      .mockResolvedValue({ rows: [] } as any);
    for (const q of [
      "UPDATE users SET email='x'",
      "DELETE FROM orders",
      "INSERT INTO x VALUES (1)",
      "TRUNCATE TABLE orders",
      "GRANT SELECT ON t TO u",
    ]) {
      await expect(
        plugin.executeAgentTool("query", { query: q }),
      ).rejects.toThrow(/read-only policy violation/i);
    }
    expect(spy).not.toHaveBeenCalled();
  });

  test("rejects a stacked SELECT + DROP", async () => {
    const plugin = makePlugin();
    const spy = vi
      .spyOn(plugin, "query")
      // biome-ignore lint/suspicious/noExplicitAny: mocked return
      .mockResolvedValue({ rows: [] } as any);
    await expect(
      plugin.executeAgentTool("query", {
        query: "SELECT 1; DROP TABLE users",
      }),
    ).rejects.toThrow(/DROP/);
    expect(spy).not.toHaveBeenCalled();
  });

  test("passes a plain SELECT through to this.query", async () => {
    const plugin = makePlugin();
    const spy = vi
      .spyOn(plugin, "query")
      // biome-ignore lint/suspicious/noExplicitAny: mocked return
      .mockResolvedValue({ rows: [{ id: 1 }] } as any);
    const result = await plugin.executeAgentTool("query", {
      query: "SELECT * FROM main.sales.orders",
    });
    expect(result).toEqual({ rows: [{ id: 1 }] });
    // The handler forwards (sql, undefined, undefined, signal); no signal is
    // supplied here, hence the trailing undefined.
    expect(spy).toHaveBeenCalledWith(
      "SELECT * FROM main.sales.orders",
      undefined,
      undefined,
      undefined,
    );
  });

  test("passes WITH … SELECT through", async () => {
    const plugin = makePlugin();
    const spy = vi
      .spyOn(plugin, "query")
      // biome-ignore lint/suspicious/noExplicitAny: mocked return
      .mockResolvedValue({ rows: [] } as any);
    await plugin.executeAgentTool("query", {
      query: "WITH a AS (SELECT 1) SELECT * FROM a",
    });
    expect(spy).toHaveBeenCalledOnce();
  });

  test("passes SHOW TABLES through", async () => {
    const plugin = makePlugin();
    const spy = vi
      .spyOn(plugin, "query")
      // biome-ignore lint/suspicious/noExplicitAny: mocked return
      .mockResolvedValue({ rows: [] } as any);
    await plugin.executeAgentTool("query", {
      query: "SHOW TABLES IN main.sales",
    });
    expect(spy).toHaveBeenCalledOnce();
  });
});
diff --git a/packages/appkit/src/plugins/analytics/tests/analytics.test.ts b/packages/appkit/src/plugins/analytics/tests/analytics.test.ts
index 9a30440e..29157fff 100644
--- a/packages/appkit/src/plugins/analytics/tests/analytics.test.ts
+++ b/packages/appkit/src/plugins/analytics/tests/analytics.test.ts
@@ -608,4 +608,22 @@ describe("Analytics Plugin", () => {
     });
   });
 });
+
+  describe("toolkit()", () => {
+    test("produces ToolkitEntry records keyed by the plugin name", () => {
+      const plugin = new AnalyticsPlugin({ name: "analytics" });
+      const entries = plugin.toolkit();
+      expect(Object.keys(entries)).toContain("analytics.query");
+      const entry = 
entries["analytics.query"]; + expect(entry.__toolkitRef).toBe(true); + expect(entry.pluginName).toBe("analytics"); + expect(entry.localName).toBe("query"); + }); + + test("respects prefix and only options", () => { + const plugin = new AnalyticsPlugin({ name: "analytics" }); + const entries = plugin.toolkit({ prefix: "", only: ["query"] }); + expect(Object.keys(entries)).toEqual(["query"]); + }); + }); }); diff --git a/packages/appkit/src/plugins/files/plugin.ts b/packages/appkit/src/plugins/files/plugin.ts index 75f2e14d..fc768c5a 100644 --- a/packages/appkit/src/plugins/files/plugin.ts +++ b/packages/appkit/src/plugins/files/plugin.ts @@ -2,19 +2,37 @@ import { STATUS_CODES } from "node:http"; import { Readable } from "node:stream"; import { ApiError } from "@databricks/sdk-experimental"; import type express from "express"; -import type { IAppRouter, PluginExecutionSettings } from "shared"; +import type { + AgentToolDefinition, + IAppRouter, + PluginExecutionSettings, + ToolProvider, +} from "shared"; +import { z } from "zod"; import { contentTypeFromPath, FilesConnector, isSafeInlineContentType, validateCustomContentTypes, } from "../../connectors/files"; -import { getCurrentUserId, getWorkspaceClient } from "../../context"; +import { + getCurrentUserId, + getExecutionContext, + getWorkspaceClient, +} from "../../context"; +import { isUserContext } from "../../context/user-context"; import { AuthenticationError } from "../../errors"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest, ResourceRequirement } from "../../registry"; import { ResourceType } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; import { FILES_DOWNLOAD_DEFAULTS, FILES_MAX_UPLOAD_SIZE, @@ -41,7 +59,7 @@ import type { const logger = 
createLogger("files"); -export class FilesPlugin extends Plugin { +export class FilesPlugin extends Plugin implements ToolProvider { name = "files"; /** Plugin manifest declaring metadata and resource requirements. */ @@ -52,6 +70,7 @@ export class FilesPlugin extends Plugin { private volumeConnectors: Record = {}; private volumeConfigs: Record = {}; private volumeKeys: string[] = []; + private tools: ToolRegistry = {}; /** * Scans `process.env` for `DATABRICKS_VOLUME_*` keys and merges them with @@ -226,6 +245,10 @@ export class FilesPlugin extends Plugin { }); } + for (const volumeKey of this.volumeKeys) { + Object.assign(this.tools, this._defineVolumeTools(volumeKey)); + } + // Warn at startup for volumes without an explicit policy for (const key of this.volumeKeys) { if (!volumes[key].policy) { @@ -1019,6 +1042,91 @@ export class FilesPlugin extends Plugin { }; } + /** + * Builds the agent-tool registry entries for a single volume. One set of + * tools per configured volume, keyed by `${volumeKey}.${method}`. + * + * Each handler resolves the caller's identity from the current execution + * context (OBO user when the agent run is wrapped in `asUser(req)`, service + * principal otherwise in local dev) and dispatches through + * `createVolumeAPI(volumeKey, user)` so the volume's policy is enforced + * uniformly for agent and HTTP callers. + */ + private _defineVolumeTools(volumeKey: string): ToolRegistry { + const buildUser = (): FilePolicyUser => { + const ctx = getExecutionContext(); + return isUserContext(ctx) + ? 
{ id: ctx.userId } + : { id: ctx.serviceUserId, isServicePrincipal: true }; + }; + const api = () => this.createVolumeAPI(volumeKey, buildUser()); + return { + [`${volumeKey}.list`]: defineTool({ + description: `List files and directories in the "${volumeKey}" volume`, + schema: z.object({ + path: z + .string() + .optional() + .describe("Directory path to list (optional, defaults to root)"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().list(args.path), + }), + [`${volumeKey}.read`]: defineTool({ + description: `Read a text file from the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path to read"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().read(args.path), + }), + [`${volumeKey}.exists`]: defineTool({ + description: `Check if a file or directory exists in the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("Path to check"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().exists(args.path), + }), + [`${volumeKey}.metadata`]: defineTool({ + description: `Get metadata (size, type, last modified) for a file in the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => api().metadata(args.path), + }), + [`${volumeKey}.upload`]: defineTool({ + description: `Upload a text file to the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("Destination file path"), + contents: z.string().describe("File contents as a string"), + overwrite: z + .boolean() + .optional() + .describe("Whether to overwrite existing file"), + }), + annotations: { destructive: true, requiresUserContext: true }, + handler: (args) => + api().upload(args.path, 
args.contents, { + overwrite: args.overwrite, + }), + }), + [`${volumeKey}.delete`]: defineTool({ + description: `Delete a file from the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path to delete"), + }), + annotations: { destructive: true, requiresUserContext: true }, + handler: (args) => api().delete(args.path), + }), + }; + } + private inflightWrites = 0; private trackWrite(fn: () => Promise): Promise { @@ -1047,6 +1155,22 @@ export class FilesPlugin extends Plugin { this.streamManager.abortAll(); } + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + /** * Returns the programmatic API for the Files plugin. * Callable with a volume key to get a volume-scoped handle. 
diff --git a/packages/appkit/src/plugins/files/tests/plugin.test.ts b/packages/appkit/src/plugins/files/tests/plugin.test.ts index a4b9bea2..bbaa1b98 100644 --- a/packages/appkit/src/plugins/files/tests/plugin.test.ts +++ b/packages/appkit/src/plugins/files/tests/plugin.test.ts @@ -205,6 +205,62 @@ describe("FilesPlugin", () => { }); }); + describe("getAgentTools / executeAgentTool", () => { + test("produces independent tool entries per volume", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const tools = plugin.getAgentTools(); + const names = tools.map((t) => t.name); + + expect(names).toContain("uploads.list"); + expect(names).toContain("uploads.read"); + expect(names).toContain("uploads.exists"); + expect(names).toContain("uploads.metadata"); + expect(names).toContain("uploads.upload"); + expect(names).toContain("uploads.delete"); + + expect(names).toContain("exports.list"); + expect(names).toContain("exports.read"); + expect(names).toContain("exports.delete"); + + expect(tools).toHaveLength(12); + }); + + test("dispatches to the correct volume API based on the tool name", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const asyncIterable = (items: { path: string }[]) => ({ + [Symbol.asyncIterator]: async function* () { + for (const item of items) yield item; + }, + }); + mockClient.files.listDirectoryContents.mockReturnValueOnce( + asyncIterable([{ path: "uploads-file" }]), + ); + mockClient.files.listDirectoryContents.mockReturnValueOnce( + asyncIterable([{ path: "exports-file" }]), + ); + + const uploadsResult = (await plugin.executeAgentTool( + "uploads.list", + {}, + )) as { path: string }[]; + const exportsResult = (await plugin.executeAgentTool( + "exports.list", + {}, + )) as { path: string }[]; + + expect(uploadsResult[0].path).toBe("uploads-file"); + expect(exportsResult[0].path).toBe("exports-file"); + }); + + test("returns LLM-friendly error string for invalid tool args", async () => { + const plugin = new 
FilesPlugin(VOLUMES_CONFIG); + const result = await plugin.executeAgentTool("uploads.read", {}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for uploads.read"); + expect(result).toContain("path"); + }); + }); + describe("exports()", () => { test("returns a callable function with a .volume alias", () => { const plugin = new FilesPlugin(VOLUMES_CONFIG); diff --git a/packages/appkit/src/plugins/genie/genie.ts b/packages/appkit/src/plugins/genie/genie.ts index 712aadbf..3167794e 100644 --- a/packages/appkit/src/plugins/genie/genie.ts +++ b/packages/appkit/src/plugins/genie/genie.ts @@ -1,11 +1,24 @@ import { randomUUID } from "node:crypto"; import type express from "express"; -import type { IAppRouter, StreamExecutionSettings } from "shared"; +import type { + AgentToolDefinition, + IAppRouter, + StreamExecutionSettings, + ToolProvider, +} from "shared"; +import { z } from "zod"; import { GenieConnector } from "../../connectors"; import { getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; import { genieStreamDefaults } from "./defaults"; import manifest from "./manifest.json"; import type { @@ -17,7 +30,7 @@ import type { const logger = createLogger("genie"); -export class GeniePlugin extends Plugin { +export class GeniePlugin extends Plugin implements ToolProvider { static manifest = manifest as PluginManifest<"genie">; protected static description = @@ -25,6 +38,7 @@ export class GeniePlugin extends Plugin { protected declare config: IGenieConfig; private readonly genieConnector: GenieConnector; + private tools: ToolRegistry = {}; constructor(config: IGenieConfig) { super(config); @@ -36,6 +50,54 @@ 
export class GeniePlugin extends Plugin { timeout: this.config.timeout, maxMessages: 200, }); + + for (const alias of Object.keys(this.config.spaces ?? {})) { + Object.assign(this.tools, this._defineSpaceTools(alias)); + } + } + + /** + * Builds the registry entries for a single Genie space alias. + * One set of tools per configured space, keyed by `${alias}.${method}`. + */ + private _defineSpaceTools(alias: string): ToolRegistry { + return { + [`${alias}.sendMessage`]: defineTool({ + description: `Send a natural language question to the Genie space "${alias}" and get data analysis results`, + schema: z.object({ + content: z.string().describe("The natural language question to ask"), + conversationId: z + .string() + .optional() + .describe( + "Optional conversation ID to continue an existing conversation", + ), + }), + annotations: { requiresUserContext: true }, + handler: async (args) => { + const events: GenieStreamEvent[] = []; + for await (const event of this.sendMessage( + alias, + args.content, + args.conversationId, + )) { + events.push(event); + } + return events; + }, + }), + [`${alias}.getConversation`]: defineTool({ + description: `Retrieve the conversation history from the Genie space "${alias}"`, + schema: z.object({ + conversationId: z + .string() + .describe("The conversation ID to retrieve"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + autoInheritable: true, + handler: (args) => this.getConversation(alias, args.conversationId), + }), + }; } private defaultSpaces(): Record { @@ -287,6 +349,22 @@ export class GeniePlugin extends Plugin { this.streamManager.abortAll(); } + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, 
this.tools, opts); + } + exports() { return { sendMessage: this.sendMessage, diff --git a/packages/appkit/src/plugins/genie/tests/genie.test.ts b/packages/appkit/src/plugins/genie/tests/genie.test.ts index 3cf0784d..672e6242 100644 --- a/packages/appkit/src/plugins/genie/tests/genie.test.ts +++ b/packages/appkit/src/plugins/genie/tests/genie.test.ts @@ -187,6 +187,30 @@ describe("Genie Plugin", () => { }); }); + describe("getAgentTools / executeAgentTool", () => { + test("produces independent tool entries per configured space", () => { + const plugin = new GeniePlugin(config); + const names = plugin.getAgentTools().map((t) => t.name); + + expect(names).toContain("myspace.sendMessage"); + expect(names).toContain("myspace.getConversation"); + expect(names).toContain("salesbot.sendMessage"); + expect(names).toContain("salesbot.getConversation"); + expect(names).toHaveLength(4); + }); + + test("returns LLM-friendly error string for invalid tool args", async () => { + const plugin = new GeniePlugin(config); + const result = await plugin.executeAgentTool( + "myspace.getConversation", + {}, + ); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for myspace.getConversation"); + expect(result).toContain("conversationId"); + }); + }); + describe("space alias resolution", () => { test("should return 404 for unknown alias", async () => { const plugin = new GeniePlugin(config); diff --git a/packages/appkit/src/plugins/lakebase/lakebase.ts b/packages/appkit/src/plugins/lakebase/lakebase.ts index 3071d539..aaf61b51 100644 --- a/packages/appkit/src/plugins/lakebase/lakebase.ts +++ b/packages/appkit/src/plugins/lakebase/lakebase.ts @@ -1,4 +1,6 @@ import type { Pool, QueryResult, QueryResultRow } from "pg"; +import type { AgentToolDefinition, ToolProvider } from "shared"; +import { z } from "zod"; import { createLakebasePool, getLakebaseOrmConfig, @@ -8,6 +10,13 @@ import { import { createLogger } from "../../logging/logger"; import { Plugin, 
toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { buildToolkitEntries } from "../agents/build-toolkit"; +import { + defineTool, + executeFromRegistry, + toolsFromRegistry, +} from "../agents/tools/define-tool"; +import { assertReadOnlySql } from "../agents/tools/sql-policy"; import manifest from "./manifest.json"; import type { ILakebaseConfig } from "./types"; @@ -30,18 +39,13 @@ const logger = createLogger("lakebase"); * const result = await AppKit.lakebase.query("SELECT * FROM users WHERE id = $1", [userId]); * ``` */ -class LakebasePlugin extends Plugin { +export class LakebasePlugin extends Plugin implements ToolProvider { /** Plugin manifest declaring metadata and resource requirements */ static manifest = manifest as PluginManifest<"lakebase">; protected declare config: ILakebaseConfig; private pool: Pool | null = null; - constructor(config: ILakebaseConfig) { - super(config); - this.config = config; - } - /** * Initializes the Lakebase connection pool. * Called automatically by AppKit during the plugin setup phase. @@ -79,6 +83,39 @@ class LakebasePlugin extends Plugin { return this.pool!.query(text, values); } + /** + * Execute a single statement inside a `BEGIN READ ONLY … ROLLBACK` + * transaction on a dedicated client. + * + * The three commands MUST share a connection — a naive + * `pool.query("BEGIN READ ONLY; ; ROLLBACK")` batch cannot accept + * parameter values (PostgreSQL's Extended Query protocol rejects multi- + * statement prepared queries), which would silently break every + * parameterized query the agent tool issues. + * + * Returns the raw `rows` array for the user's statement. Side effects the + * statement may attempt (writes, writable-function side effects) are + * rejected by PostgreSQL under the read-only transaction posture. 
+ */ + private async runReadOnlyStatement( + text: string, + values?: unknown[], + ): Promise { + // biome-ignore lint/style/noNonNullAssertion: pool is guaranteed non-null after setup() + const client = await this.pool!.connect(); + try { + await client.query("BEGIN READ ONLY"); + const result = await client.query(text, values); + return result.rows; + } finally { + try { + await client.query("ROLLBACK"); + } finally { + client.release(); + } + } + } + /** * Gracefully drains and closes the connection pool. * Called automatically by AppKit during shutdown. @@ -102,6 +139,82 @@ class LakebasePlugin extends Plugin { * - `getOrmConfig()` — Returns a config object compatible with Drizzle, TypeORM, Sequelize, etc. * - `getPgConfig()` — Returns a `pg.PoolConfig` object for manual pool construction */ + + /** + * Agent tool registry. Empty by default — the Lakebase plugin does NOT + * expose its SQL connection to LLM agents unless the developer explicitly + * opts in via `config.exposeAsAgentTool`. See {@link buildQueryTool}. + */ + private tools: Record> = {}; + + constructor(config: ILakebaseConfig) { + super(config); + this.config = config; + if (config.exposeAsAgentTool) { + if (config.exposeAsAgentTool.iUnderstandRunsAsServicePrincipal !== true) { + throw new Error( + "lakebase.exposeAsAgentTool requires iUnderstandRunsAsServicePrincipal: true — this acknowledges that SQL statements authored by the LLM run with the application's service-principal credentials regardless of which end user initiated the request.", + ); + } + this.tools = { query: this.buildQueryTool(config.exposeAsAgentTool) }; + logger.warn( + "Lakebase agent tool is enabled (readOnly=%s). 
Every agent with access to this plugin can execute SQL against the Lakebase database as the service principal.", + config.exposeAsAgentTool.readOnly !== false, + ); + } + } + + private buildQueryTool( + opt: NonNullable, + ) { + const readOnly = opt.readOnly !== false; + return defineTool({ + description: readOnly + ? "Execute a read-only SQL query against the Lakebase PostgreSQL database. Only SELECT, WITH, SHOW, EXPLAIN, and DESCRIBE statements are accepted. Use $1, $2, etc. as placeholders and pass values separately. Runs as the application's service principal." + : "Execute a parameterized SQL statement against the Lakebase PostgreSQL database. Use $1, $2, etc. as placeholders and pass values separately. Runs as the application's service principal. This tool can modify data; every invocation requires explicit human approval.", + schema: z.object({ + text: z + .string() + .describe( + "SQL statement with $1, $2, ... placeholders for parameters", + ), + values: z + .array(z.unknown()) + .optional() + .describe("Parameter values corresponding to placeholders"), + }), + annotations: { + readOnly, + destructive: !readOnly, + idempotent: false, + }, + handler: async (args) => { + if (readOnly) { + assertReadOnlySql(args.text); + return this.runReadOnlyStatement(args.text, args.values); + } + const result = await this.query(args.text, args.values); + return result.rows; + }, + }); + } + + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + + toolkit(opts?: import("../agents/types").ToolkitOptions) { + return buildToolkitEntries(this.name, this.tools, opts); + } + exports() { return { // biome-ignore lint/style/noNonNullAssertion: pool is guaranteed non-null after setup(), which AppKit always awaits before exposing the plugin API diff --git 
a/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts b/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts new file mode 100644 index 00000000..8e59fb32 --- /dev/null +++ b/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts @@ -0,0 +1,238 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; + +/** + * Tests the agent-tool surface of the Lakebase plugin. + * + * The plugin defaults to **not** exposing an agent tool at all. Enabling the + * tool is an explicit opt-in (`exposeAsAgentTool` with an acknowledgement + * flag) because every invocation runs with the application's service- + * principal credentials regardless of which end user initiated the request. + */ + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise) => + fn(), + ), + generateKey: vi.fn(() => "test-key"), + })), + }, +})); + +// Client calls recorded by the read-only-statement test. The `connect()` +// mock returns a fresh client whose `query` pushes to this array so tests +// can assert the exact sequence of statements emitted on the dedicated +// connection. 
+const clientQueries: Array<{ text: string; values?: unknown[] }> = []; +const clientReleases: number[] = []; + +vi.mock("../../../connectors/lakebase", () => ({ + createLakebasePool: vi.fn(() => ({ + query: vi.fn(), + connect: vi.fn(async () => { + let releaseCalls = 0; + return { + query: vi.fn(async (text: string, values?: unknown[]) => { + clientQueries.push({ text, values }); + return { rows: [{ n: 1 }] }; + }), + release: vi.fn(() => { + releaseCalls += 1; + clientReleases.push(releaseCalls); + }), + }; + }), + end: vi.fn(), + })), + getLakebaseOrmConfig: vi.fn(() => ({})), + getLakebasePgConfig: vi.fn(() => ({})), + getUsernameWithApiLookup: vi.fn(async () => "test-user"), +})); + +import { LakebasePlugin } from "../lakebase"; + +function makePlugin( + config: ConstructorParameters[0], +): LakebasePlugin { + return new LakebasePlugin(config); +} + +describe("LakebasePlugin — agent tool opt-in", () => { + test("does not register an agent tool by default", () => { + const plugin = makePlugin({}); + expect(plugin.getAgentTools()).toEqual([]); + }); + + test("does not register a tool when `pool` is set but `exposeAsAgentTool` is absent", () => { + const plugin = makePlugin({ pool: {} }); + expect(plugin.getAgentTools()).toEqual([]); + }); + + test("throws when exposeAsAgentTool is set without the acknowledgement flag", () => { + expect(() => + makePlugin({ + exposeAsAgentTool: + // biome-ignore lint/suspicious/noExplicitAny: intentionally bypass the required flag for the negative case + {} as any, + }), + ).toThrow(/iUnderstandRunsAsServicePrincipal/); + }); + + test("registers a read-only tool when opted in with defaults", () => { + const plugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + const defs = plugin.getAgentTools(); + expect(defs).toHaveLength(1); + expect(defs[0].name).toBe("query"); + expect(defs[0].annotations).toEqual({ + readOnly: true, + destructive: false, + idempotent: false, + }); + }); + + 
test("registers a destructive tool when readOnly: false is explicit", () => { + const plugin = makePlugin({ + exposeAsAgentTool: { + iUnderstandRunsAsServicePrincipal: true, + readOnly: false, + }, + }); + const defs = plugin.getAgentTools(); + expect(defs[0].annotations).toEqual({ + readOnly: false, + destructive: true, + idempotent: false, + }); + }); +}); + +describe("LakebasePlugin — readOnly enforcement", () => { + let plugin: LakebasePlugin; + + beforeEach(async () => { + clientQueries.length = 0; + clientReleases.length = 0; + plugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + await plugin.setup(); + }); + + test("rejects DROP before acquiring a client", async () => { + await expect( + plugin.executeAgentTool("query", { text: "DROP TABLE users" }), + ).rejects.toThrow(/read-only policy violation/i); + expect(clientQueries).toHaveLength(0); + }); + + test("rejects UPDATE, DELETE, INSERT", async () => { + for (const text of [ + "UPDATE users SET email='x'", + "DELETE FROM orders", + "INSERT INTO x VALUES (1)", + ]) { + await expect(plugin.executeAgentTool("query", { text })).rejects.toThrow( + /read-only policy violation/i, + ); + } + expect(clientQueries).toHaveLength(0); + }); + + test("runs SELECT inside BEGIN READ ONLY / ROLLBACK on a dedicated client", async () => { + const rows = await plugin.executeAgentTool("query", { + text: "SELECT * FROM users", + }); + expect(rows).toEqual([{ n: 1 }]); + expect(clientQueries.map((c) => c.text)).toEqual([ + "BEGIN READ ONLY", + "SELECT * FROM users", + "ROLLBACK", + ]); + // Client must be released exactly once, regardless of outcome. 
+ expect(clientReleases).toHaveLength(1); + }); + + test("forwards parameter values to the user statement only (the regression fix)", async () => { + // Prior to the fix this would have failed with "cannot insert multiple + // commands into a prepared statement" because pg's Extended Query + // protocol rejects multi-statement batches when values are supplied. + await plugin.executeAgentTool("query", { + text: "SELECT * FROM users WHERE id = $1", + values: [42], + }); + expect(clientQueries).toEqual([ + { text: "BEGIN READ ONLY", values: undefined }, + { text: "SELECT * FROM users WHERE id = $1", values: [42] }, + { text: "ROLLBACK", values: undefined }, + ]); + }); + + test("releases the client even when the user statement throws", async () => { + // Poison the client so the middle query throws (simulates a Postgres + // error like "cannot execute UPDATE in a read-only transaction"). + const { createLakebasePool } = await import("../../../connectors/lakebase"); + const connect = vi.fn(async () => ({ + query: vi + .fn() + .mockResolvedValueOnce({ rows: [] }) + .mockRejectedValueOnce(new Error("read-only violation")) + .mockResolvedValueOnce({ rows: [] }), + release: vi.fn(() => { + clientReleases.push(clientReleases.length + 1); + }), + })); + // biome-ignore lint/suspicious/noExplicitAny: test override + ( + createLakebasePool as unknown as { mockReturnValueOnce: any } + ).mockReturnValueOnce({ query: vi.fn(), connect, end: vi.fn() }); + + clientQueries.length = 0; + clientReleases.length = 0; + const leakyPlugin = makePlugin({ + exposeAsAgentTool: { iUnderstandRunsAsServicePrincipal: true }, + }); + await leakyPlugin.setup(); + + await expect( + leakyPlugin.executeAgentTool("query", { + text: "SELECT * FROM users", + }), + ).rejects.toThrow(/read-only violation/); + expect(clientReleases).toHaveLength(1); + }); +}); + +describe("LakebasePlugin — destructive mode", () => { + test("does NOT wrap in read-only transaction when readOnly: false", async () => { + const 
queryMock = vi.fn((_text: string, _values?: unknown[]) => + Promise.resolve({ rows: [] }), + ); + const plugin = makePlugin({ + exposeAsAgentTool: { + iUnderstandRunsAsServicePrincipal: true, + readOnly: false, + }, + }); + await plugin.setup(); + vi.spyOn(plugin, "query").mockImplementation(async (text, values) => { + queryMock(text, values); + return { rows: [] } as never; + }); + + await plugin.executeAgentTool("query", { + text: "UPDATE t SET x=1 WHERE id=$1", + values: [42], + }); + + expect(queryMock).toHaveBeenCalledWith( + "UPDATE t SET x=1 WHERE id=$1", + [42], + ); + }); +}); diff --git a/packages/appkit/src/plugins/lakebase/types.ts b/packages/appkit/src/plugins/lakebase/types.ts index ac6997c6..6703e425 100644 --- a/packages/appkit/src/plugins/lakebase/types.ts +++ b/packages/appkit/src/plugins/lakebase/types.ts @@ -1,6 +1,42 @@ import type { BasePluginConfig } from "shared"; import type { LakebasePoolConfig } from "../../connectors/lakebase"; +/** + * Opt-in configuration for exposing Lakebase as an agent-callable SQL tool. + * + * This tool executes LLM-authored SQL against the Lakebase pool. The pool is + * **always bound to the application's service-principal credentials**, so any + * agent that can call this tool effectively has full SP access to the database + * regardless of which end user initiated the request. Exposing it is a + * deliberate decision the developer must make explicitly — hence the required + * acknowledgement flag. + * + * When `readOnly: true` (default when opted in), every statement is: + * 1. Classified by {@link @databricks/appkit's sql-policy classifier}; anything + * that isn't a pure `SELECT`/`WITH`/`SHOW`/`EXPLAIN`/`DESCRIBE` is rejected. + * 2. Executed inside a `BEGIN READ ONLY … ROLLBACK` transaction so the + * PostgreSQL server rejects writes that slip past the classifier (e.g., a + * `SELECT` over a function with side effects). 
+ * + * When `readOnly: false`, the tool is annotated `destructive: true` and the + * agents plugin will require human approval for every invocation (see + * `AgentsPluginConfig.approval`). + */ +export interface LakebaseExposeAsAgentTool { + /** + * Required acknowledgement that tool invocations run as the service principal + * and share that privilege across end users. Must be set to `true` to opt in. + */ + iUnderstandRunsAsServicePrincipal: true; + /** + * Enforce read-only execution. Defaults to `true`. Set to `false` to allow + * destructive statements — highly discouraged outside of tightly controlled + * single-user deployments. Combined with the `destructive: true` annotation, + * the agents plugin will require explicit human approval for each call. + */ + readOnly?: boolean; +} + /** * Configuration for the Lakebase plugin. * @@ -17,4 +53,11 @@ export interface ILakebaseConfig extends BasePluginConfig { * Common overrides: `max` (pool size), `connectionTimeoutMillis`, `idleTimeoutMillis`. */ pool?: Partial; + /** + * Opt-in to expose Lakebase as an agent-callable SQL tool. By default no + * agent tool is registered — the Lakebase plugin only exposes its API to + * application code. See {@link LakebaseExposeAsAgentTool} for the privilege + * implications of enabling this. + */ + exposeAsAgentTool?: LakebaseExposeAsAgentTool; }