diff --git a/.gitignore b/.gitignore index a1b83bc4f..58cff53ff 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,5 @@ dist/ # IDE .idea/ + +# ahammednibras8 \ No newline at end of file diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 8564212c1..b2026050b 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -7,16 +7,22 @@ import type { CompleteRequestPrompt, CompleteRequestResourceTemplate, CompleteResult, + CreateTaskRequestHandlerExtra, CreateTaskResult, + GetPromptRequest, GetPromptResult, Implementation, + ListPromptsRequest, ListPromptsResult, + ListResourcesRequest, ListResourcesResult, + ListToolsRequest, ListToolsResult, LoggingMessageNotification, Prompt, PromptArgument, PromptReference, + ReadResourceRequest, ReadResourceResult, RequestHandlerExtra, Resource, @@ -31,8 +37,8 @@ import type { ToolExecution, Transport, Variables, - ZodRawShapeCompat -} from '@modelcontextprotocol/core'; + ZodRawShapeCompat, +} from "@modelcontextprotocol/core"; import { assertCompleteRequestPrompt, assertCompleteRequestResourceTemplate, @@ -56,15 +62,47 @@ import { safeParseAsync, toJsonSchemaCompat, UriTemplate, - validateAndWarnToolName -} from '@modelcontextprotocol/core'; -import { ZodOptional } from 'zod'; + validateAndWarnToolName, +} from "@modelcontextprotocol/core"; +import { ZodOptional } from "zod"; -import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; -import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcp-server.js'; -import { getCompleter, isCompletable } from './completable.js'; -import type { ServerOptions } from './server.js'; -import { Server } from './server.js'; +import type { ToolTaskHandler } from "../experimental/tasks/interfaces.js"; +import { ExperimentalMcpServerTasks } from "../experimental/tasks/mcp-server.js"; +import { getCompleter, isCompletable } from "./completable.js"; +import type { ServerOptions } from "./server.js"; +import { Server } from "./server.js"; + +/** + * Context passed to MCP middleware functions. + */ +export interface McpMiddlewareContext { + /** + * The incoming JSON-RPC request. + * While technically mutable, middleware should generally treat this as read-only. + * Mutation is permitted only for specific cases like schema normalization or request enrichment. + */ + request: ServerRequest; + + /** + * Additional metadata passed from the transport or SDK. + */ + extra: RequestHandlerExtra; + + /** + * A generic key-value store for cross-middleware communication (e.g., attaching a user object after auth). + */ + state: Record; +} + +/** + * Middleware function for intercepting MCP requests. + * @param context The request context. + * @param next A function that calls the next middleware or the implementation handler. + */ +export type McpMiddleware = ( + context: McpMiddlewareContext, + next: () => Promise, +) => Promise; /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. @@ -83,6 +121,8 @@ export class McpServer { } = {}; private _registeredTools: { [name: string]: RegisteredTool } = {}; private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + private _middleware: McpMiddleware[] = []; + private _middlewareFrozen = false; private _experimental?: { tasks: ExperimentalMcpServerTasks }; constructor(serverInfo: Implementation, options?: ServerOptions) { @@ -99,18 +139,33 @@ export class McpServer { get experimental(): { tasks: ExperimentalMcpServerTasks } { if (!this._experimental) { this._experimental = { - tasks: new ExperimentalMcpServerTasks(this) + tasks: new ExperimentalMcpServerTasks(this), }; } return this._experimental; } + /** + * Registers a middleware function. + * @param middleware The middleware to register. + */ + public use(middleware: McpMiddleware) { + if (this._middlewareFrozen) { + throw new Error( + "Cannot register middleware after the server has started or processed requests.", + ); + } + this._middleware.push(middleware); + return this; + } + /** * Attaches to the given transport, starts it, and starts listening for messages. * * The `server` object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. */ async connect(transport: Transport): Promise { + this._middlewareFrozen = true; return await this.server.connect(transport); } @@ -128,110 +183,196 @@ export class McpServer { return; } - this.server.assertCanSetRequestHandler(getMethodValue(ListToolsRequestSchema)); - this.server.assertCanSetRequestHandler(getMethodValue(CallToolRequestSchema)); + this.server.assertCanSetRequestHandler( + getMethodValue(ListToolsRequestSchema), + ); + this.server.assertCanSetRequestHandler( + getMethodValue(CallToolRequestSchema), + ); this.server.registerCapabilities({ tools: { - listChanged: true - } + listChanged: true, + }, }); this.server.setRequestHandler( ListToolsRequestSchema, - (): ListToolsResult => ({ - tools: Object.entries(this._registeredTools) - .filter(([, tool]) => tool.enabled) - .map(([name, tool]): Tool => { - const toolDefinition: Tool = { - name, - title: tool.title, - description: tool.description, - inputSchema: (() => { - const obj = normalizeObjectSchema(tool.inputSchema); - return obj - ? (toJsonSchemaCompat(obj, { - strictUnions: true, - pipeStrategy: 'input' - }) as Tool['inputSchema']) - : EMPTY_OBJECT_JSON_SCHEMA; - })(), - annotations: tool.annotations, - execution: tool.execution, - _meta: tool._meta - }; - - if (tool.outputSchema) { - const obj = normalizeObjectSchema(tool.outputSchema); - if (obj) { - toolDefinition.outputSchema = toJsonSchemaCompat(obj, { - strictUnions: true, - pipeStrategy: 'output' - }) as Tool['outputSchema']; - } - } - - return toolDefinition; - }) - }) + ( + request: ListToolsRequest, + extra: RequestHandlerExtra< + ListToolsRequest, + ServerNotification + >, + ) => this._executeRequest< + ListToolsResult, + ListToolsRequest, + RequestHandlerExtra + >( + (): Promise => + Promise.resolve({ + tools: Object.entries(this._registeredTools) + .filter(([, tool]) => tool.enabled) + .map(([name, tool]): Tool => { + const toolDefinition: Tool = { + name, + title: tool.title, + description: tool.description, + inputSchema: (() => { + const obj = normalizeObjectSchema( + tool.inputSchema, + ); + return obj + ? (toJsonSchemaCompat(obj, { + strictUnions: true, + pipeStrategy: "input", + }) as Tool["inputSchema"]) + : EMPTY_OBJECT_JSON_SCHEMA; + })(), + annotations: tool.annotations, + execution: tool.execution, + _meta: tool._meta, + }; + + if (tool.outputSchema) { + const obj = normalizeObjectSchema( + tool.outputSchema, + ); + if (obj) { + toolDefinition.outputSchema = + toJsonSchemaCompat(obj, { + strictUnions: true, + pipeStrategy: "output", + }) as Tool["outputSchema"]; + } + } + + return toolDefinition; + }), + }), + request, + extra, + ), ); - this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { - try { - const tool = this._registeredTools[request.params.name]; - if (!tool) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); - } - if (!tool.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); - } - - const isTaskRequest = !!request.params.task; - const taskSupport = tool.execution?.taskSupport; - const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); - - // Validate task hint configuration - if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { - throw new McpError( - ErrorCode.InternalError, - `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` - ); - } + this.server.setRequestHandler( + CallToolRequestSchema, + ( + request: CallToolRequest, + extra: RequestHandlerExtra< + CallToolRequest, + ServerNotification + >, + ) => this._executeRequest< + CallToolResult | CreateTaskResult, + CallToolRequest, + RequestHandlerExtra + >( + async ( + request: CallToolRequest, + extra: RequestHandlerExtra< + CallToolRequest, + ServerNotification + >, + ): Promise => { + try { + const tool = this._registeredTools[request.params.name]; + if (!tool) { + throw new McpError( + ErrorCode.InvalidParams, + `Tool ${request.params.name} not found`, + ); + } + if (!tool.enabled) { + throw new McpError( + ErrorCode.InvalidParams, + `Tool ${request.params.name} disabled`, + ); + } - // Handle taskSupport 'required' without task augmentation - if (taskSupport === 'required' && !isTaskRequest) { - throw new McpError( - ErrorCode.MethodNotFound, - `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` - ); - } + const isTaskRequest = !!request.params.task; + const taskSupport = tool.execution?.taskSupport; + const isTaskHandler = "createTask" in + (tool.handler as AnyToolHandler< + ZodRawShapeCompat + >); + + // Validate task hint configuration + if ( + (taskSupport === "required" || + taskSupport === "optional") && + !isTaskHandler + ) { + throw new McpError( + ErrorCode.InternalError, + `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask`, + ); + } - // Handle taskSupport 'optional' without task augmentation - automatic polling - if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { - return await this.handleAutomaticTaskPolling(tool, request, extra); - } + // Handle taskSupport 'required' without task augmentation + if (taskSupport === "required" && !isTaskRequest) { + throw new McpError( + ErrorCode.MethodNotFound, + `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')`, + ); + } - // Normal execution path - const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const result = await this.executeToolHandler(tool, args, extra); + // Handle taskSupport 'optional' without task augmentation - automatic polling + if ( + taskSupport === "optional" && !isTaskRequest && + isTaskHandler + ) { + return await this.handleAutomaticTaskPolling( + tool, + request, + extra, + ); + } - // Return CreateTaskResult immediately for task requests - if (isTaskRequest) { - return result; - } + // Normal execution path + const args = await this.validateToolInput( + tool, + request.params.arguments, + request.params.name, + ); + const result = await this.executeToolHandler( + tool, + args, + extra, + ); + + // Return CreateTaskResult immediately for task requests + if (isTaskRequest) { + return result; + } - // Validate output schema for non-task requests - await this.validateToolOutput(tool, result, request.params.name); - return result; - } catch (error) { - if (error instanceof McpError) { - if (error.code === ErrorCode.UrlElicitationRequired) { - throw error; // Return the error to the caller without wrapping in CallToolResult + // Validate output schema for non-task requests + await this.validateToolOutput( + tool, + result, + request.params.name, + ); + return result; + } catch (error) { + if (error instanceof McpError) { + if ( + error.code === + ErrorCode.UrlElicitationRequired + ) { + throw error; + } + } + return this.createToolError( + error instanceof Error + ? error.message + : String(error), + ); } - } - return this.createToolError(error instanceof Error ? error.message : String(error)); - } - }); + }, + request, + extra, + ), + ); this._toolHandlersInitialized = true; } @@ -246,11 +387,11 @@ export class McpServer { return { content: [ { - type: 'text', - text: errorMessage - } + type: "text", + text: errorMessage, + }, ], - isError: true + isError: true, }; } @@ -259,11 +400,10 @@ export class McpServer { */ private async validateToolInput< Tool extends RegisteredTool, - Args extends Tool['inputSchema'] extends infer InputSchema - ? InputSchema extends AnySchema - ? SchemaOutput - : undefined + Args extends Tool["inputSchema"] extends infer InputSchema + ? InputSchema extends AnySchema ? SchemaOutput : undefined + : undefined, >(tool: Tool, args: Args, toolName: string): Promise { if (!tool.inputSchema) { return undefined as Args; @@ -275,9 +415,14 @@ export class McpServer { const schemaToParse = inputObj ?? (tool.inputSchema as AnySchema); const parseResult = await safeParseAsync(schemaToParse, args); if (!parseResult.success) { - const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const error = "error" in parseResult + ? parseResult.error + : "Unknown error"; const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Input validation error: Invalid arguments for tool ${toolName}: ${errorMessage}`); + throw new McpError( + ErrorCode.InvalidParams, + `Input validation error: Invalid arguments for tool ${toolName}: ${errorMessage}`, + ); } return parseResult.data as unknown as Args; @@ -286,13 +431,17 @@ export class McpServer { /** * Validates tool output against the tool's output schema. */ - private async validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string): Promise { + private async validateToolOutput( + tool: RegisteredTool, + result: CallToolResult | CreateTaskResult, + toolName: string, + ): Promise { if (!tool.outputSchema) { return; } // Only validate CallToolResult, not CreateTaskResult - if (!('content' in result)) { + if (!("content" in result)) { return; } @@ -303,19 +452,26 @@ export class McpServer { if (!result.structuredContent) { throw new McpError( ErrorCode.InvalidParams, - `Output validation error: Tool ${toolName} has an output schema but no structured content was provided` + `Output validation error: Tool ${toolName} has an output schema but no structured content was provided`, ); } // if the tool has an output schema, validate structured content - const outputObj = normalizeObjectSchema(tool.outputSchema) as AnyObjectSchema; - const parseResult = await safeParseAsync(outputObj, result.structuredContent); + const outputObj = normalizeObjectSchema( + tool.outputSchema, + ) as AnyObjectSchema; + const parseResult = await safeParseAsync( + outputObj, + result.structuredContent, + ); if (!parseResult.success) { - const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const error = "error" in parseResult + ? parseResult.error + : "Unknown error"; const errorMessage = getParseErrorMessage(error); throw new McpError( ErrorCode.InvalidParams, - `Output validation error: Invalid structured content for tool ${toolName}: ${errorMessage}` + `Output validation error: Invalid structured content for tool ${toolName}: ${errorMessage}`, ); } } @@ -323,28 +479,43 @@ export class McpServer { /** * Executes a tool handler (either regular or task-based). */ - private async executeToolHandler( + private async executeToolHandler< + ExtraT extends RequestHandlerExtra, + >( tool: RegisteredTool, args: unknown, - extra: RequestHandlerExtra + extra: ExtraT, ): Promise { - const handler = tool.handler as AnyToolHandler; - const isTaskHandler = 'createTask' in handler; + const handler = tool.handler as AnyToolHandler< + ZodRawShapeCompat | undefined + >; + const isTaskHandler = "createTask" in handler; if (isTaskHandler) { if (!extra.taskStore) { - throw new Error('No task store provided.'); + throw new Error("No task store provided."); } const taskExtra = { ...extra, taskStore: extra.taskStore }; if (tool.inputSchema) { - const typedHandler = handler as ToolTaskHandler; + const typedHandler = handler as ToolTaskHandler< + ZodRawShapeCompat + >; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); + return await Promise.resolve( + typedHandler.createTask( + args as any, + taskExtra as unknown as CreateTaskRequestHandlerExtra, + ), + ); } else { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler.createTask as any)(taskExtra)); + return await Promise.resolve( + (typedHandler.createTask as any)( + taskExtra as unknown as CreateTaskRequestHandlerExtra, + ), + ); } } @@ -362,35 +533,59 @@ export class McpServer { /** * Handles automatic task polling for tools with taskSupport 'optional'. */ - private async handleAutomaticTaskPolling( + private async handleAutomaticTaskPolling< + RequestT extends CallToolRequest, + ExtraT extends RequestHandlerExtra, + >( tool: RegisteredTool, request: RequestT, - extra: RequestHandlerExtra + extra: ExtraT, ): Promise { if (!extra.taskStore) { - throw new Error('No task store provided for task-capable tool.'); + throw new Error("No task store provided for task-capable tool."); } // Validate input and create task - const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const handler = tool.handler as ToolTaskHandler; + const args = await this.validateToolInput( + tool, + request.params.arguments, + request.params.name, + ); + const handler = tool.handler as ToolTaskHandler< + ZodRawShapeCompat | undefined + >; const taskExtra = { ...extra, taskStore: extra.taskStore }; const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined - ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) - : // eslint-disable-next-line @typescript-eslint/no-explicit-any - await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(taskExtra)); + ? await Promise.resolve( + (handler as ToolTaskHandler).createTask( + args, + taskExtra as unknown as CreateTaskRequestHandlerExtra, + ), + ) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + : await Promise.resolve( + ((handler as ToolTaskHandler).createTask as any)( + taskExtra as unknown as CreateTaskRequestHandlerExtra, + ), + ); // Poll until completion const taskId = createTaskResult.task.taskId; let task = createTaskResult.task; const pollInterval = task.pollInterval ?? 5000; - while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { - await new Promise(resolve => setTimeout(resolve, pollInterval)); + while ( + task.status !== "completed" && task.status !== "failed" && + task.status !== "cancelled" + ) { + await new Promise((resolve) => setTimeout(resolve, pollInterval)); const updatedTask = await extra.taskStore.getTask(taskId); if (!updatedTask) { - throw new McpError(ErrorCode.InternalError, `Task ${taskId} not found during polling`); + throw new McpError( + ErrorCode.InternalError, + `Task ${taskId} not found during polling`, + ); } task = updatedTask; } @@ -406,38 +601,61 @@ export class McpServer { return; } - this.server.assertCanSetRequestHandler(getMethodValue(CompleteRequestSchema)); + this.server.assertCanSetRequestHandler( + getMethodValue(CompleteRequestSchema), + ); this.server.registerCapabilities({ - completions: {} + completions: {}, }); - this.server.setRequestHandler(CompleteRequestSchema, async (request): Promise => { - switch (request.params.ref.type) { - case 'ref/prompt': - assertCompleteRequestPrompt(request); - return this.handlePromptCompletion(request, request.params.ref); - - case 'ref/resource': - assertCompleteRequestResourceTemplate(request); - return this.handleResourceCompletion(request, request.params.ref); - - default: - throw new McpError(ErrorCode.InvalidParams, `Invalid completion reference: ${request.params.ref}`); - } - }); + this.server.setRequestHandler( + CompleteRequestSchema, + async (request): Promise => { + switch (request.params.ref.type) { + case "ref/prompt": + assertCompleteRequestPrompt(request); + return this.handlePromptCompletion( + request, + request.params.ref, + ); + + case "ref/resource": + assertCompleteRequestResourceTemplate(request); + return this.handleResourceCompletion( + request, + request.params.ref, + ); + + default: + throw new McpError( + ErrorCode.InvalidParams, + `Invalid completion reference: ${request.params.ref}`, + ); + } + }, + ); this._completionHandlerInitialized = true; } - private async handlePromptCompletion(request: CompleteRequestPrompt, ref: PromptReference): Promise { + private async handlePromptCompletion( + request: CompleteRequestPrompt, + ref: PromptReference, + ): Promise { const prompt = this._registeredPrompts[ref.name]; if (!prompt) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} not found`); + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${ref.name} not found`, + ); } if (!prompt.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} disabled`); + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${ref.name} disabled`, + ); } if (!prompt.argsSchema) { @@ -454,15 +672,20 @@ export class McpServer { if (!completer) { return EMPTY_COMPLETION_RESULT; } - const suggestions = await completer(request.params.argument.value, request.params.context); + const suggestions = await completer( + request.params.argument.value, + request.params.context, + ); return createCompletionResult(suggestions); } private async handleResourceCompletion( request: CompleteRequestResourceTemplate, - ref: ResourceTemplateReference + ref: ResourceTemplateReference, ): Promise { - const template = Object.values(this._registeredResourceTemplates).find(t => t.resourceTemplate.uriTemplate.toString() === ref.uri); + const template = Object.values(this._registeredResourceTemplates).find( + (t) => t.resourceTemplate.uriTemplate.toString() === ref.uri, + ); if (!template) { if (this._registeredResources[ref.uri]) { @@ -470,15 +693,23 @@ export class McpServer { return EMPTY_COMPLETION_RESULT; } - throw new McpError(ErrorCode.InvalidParams, `Resource template ${request.params.ref.uri} not found`); + throw new McpError( + ErrorCode.InvalidParams, + `Resource template ${request.params.ref.uri} not found`, + ); } - const completer = template.resourceTemplate.completeCallback(request.params.argument.name); + const completer = template.resourceTemplate.completeCallback( + request.params.argument.name, + ); if (!completer) { return EMPTY_COMPLETION_RESULT; } - const suggestions = await completer(request.params.argument.value, request.params.context); + const suggestions = await completer( + request.params.argument.value, + request.params.context, + ); return createCompletionResult(suggestions); } @@ -489,76 +720,162 @@ export class McpServer { return; } - this.server.assertCanSetRequestHandler(getMethodValue(ListResourcesRequestSchema)); - this.server.assertCanSetRequestHandler(getMethodValue(ListResourceTemplatesRequestSchema)); - this.server.assertCanSetRequestHandler(getMethodValue(ReadResourceRequestSchema)); + this.server.assertCanSetRequestHandler( + getMethodValue(ListResourcesRequestSchema), + ); + this.server.assertCanSetRequestHandler( + getMethodValue(ListResourceTemplatesRequestSchema), + ); + this.server.assertCanSetRequestHandler( + getMethodValue(ReadResourceRequestSchema), + ); this.server.registerCapabilities({ resources: { - listChanged: true - } + listChanged: true, + }, }); - this.server.setRequestHandler(ListResourcesRequestSchema, async (request, extra) => { - const resources = Object.entries(this._registeredResources) - .filter(([_, resource]) => resource.enabled) - .map(([uri, resource]) => ({ - uri, - name: resource.name, - ...resource.metadata - })); - - const templateResources: Resource[] = []; - for (const template of Object.values(this._registeredResourceTemplates)) { - if (!template.resourceTemplate.listCallback) { - continue; - } - - const result = await template.resourceTemplate.listCallback(extra); - for (const resource of result.resources) { - templateResources.push({ - ...template.metadata, - // the defined resource metadata should override the template metadata if present - ...resource - }); - } - } + this.server.setRequestHandler( + ListResourcesRequestSchema, + ( + request: ListResourcesRequest, + extra: RequestHandlerExtra< + ListResourcesRequest, + ServerNotification + >, + ) => this._executeRequest< + ListResourcesResult, + ListResourcesRequest, + RequestHandlerExtra + >( + async ( + request: ListResourcesRequest, + extra: RequestHandlerExtra< + ListResourcesRequest, + ServerNotification + >, + ) => { + const resources = Object.entries( + this._registeredResources, + ) + .filter(([_, resource]) => resource.enabled) + .map(([uri, resource]) => ({ + uri, + name: resource.name, + ...resource.metadata, + })); + + const templateResources: Resource[] = []; + for ( + const template of Object.values( + this._registeredResourceTemplates, + ) + ) { + if (!template.resourceTemplate.listCallback) { + continue; + } - return { resources: [...resources, ...templateResources] }; - }); + const result = await template.resourceTemplate + .listCallback( + extra as any, + ); + for (const resource of result.resources) { + templateResources.push({ + ...template.metadata, + // the defined resource metadata should override the template metadata if present + ...resource, + }); + } + } - this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => { - const resourceTemplates = Object.entries(this._registeredResourceTemplates).map(([name, template]) => ({ - name, - uriTemplate: template.resourceTemplate.uriTemplate.toString(), - ...template.metadata - })); + return { + resources: [...resources, ...templateResources], + }; + }, + request, + extra, + ), + ); - return { resourceTemplates }; - }); + this.server.setRequestHandler( + ListResourceTemplatesRequestSchema, + async () => { + const resourceTemplates = Object.entries( + this._registeredResourceTemplates, + ).map(([name, template]) => ({ + name, + uriTemplate: template.resourceTemplate.uriTemplate + .toString(), + ...template.metadata, + })); - this.server.setRequestHandler(ReadResourceRequestSchema, async (request, extra) => { - const uri = new URL(request.params.uri); + return { resourceTemplates }; + }, + ); - // First check for exact resource match - const resource = this._registeredResources[uri.toString()]; - if (resource) { - if (!resource.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); - } - return resource.readCallback(uri, extra); - } + this.server.setRequestHandler( + ReadResourceRequestSchema, + ( + request: ReadResourceRequest, + extra: RequestHandlerExtra< + ReadResourceRequest, + ServerNotification + >, + ) => this._executeRequest< + ReadResourceResult, + ReadResourceRequest, + RequestHandlerExtra + >( + async ( + request: ReadResourceRequest, + extra: RequestHandlerExtra< + ReadResourceRequest, + ServerNotification + >, + ) => { + const uri = new URL(request.params.uri); + + // First check for exact resource match + const resource = this._registeredResources[uri.toString()]; + if (resource) { + if (!resource.enabled) { + throw new McpError( + ErrorCode.InvalidParams, + `Resource ${uri} disabled`, + ); + } + return resource.readCallback(uri, extra as any); + } - // Then check templates - for (const template of Object.values(this._registeredResourceTemplates)) { - const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); - if (variables) { - return template.readCallback(uri, variables, extra); - } - } + // Then check templates + for ( + const template of Object.values( + this._registeredResourceTemplates, + ) + ) { + const variables = template.resourceTemplate + .uriTemplate.match( + uri.toString(), + ); + if (variables) { + return template.readCallback( + uri, + variables, + extra as any, + ); + } + } - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); - }); + throw new McpError( + ErrorCode.InvalidParams, + `Resource ${uri} not found`, + ); + }, + request, + extra, + ), + ); this._resourceHandlersInitialized = true; } @@ -570,59 +887,137 @@ export class McpServer { return; } - this.server.assertCanSetRequestHandler(getMethodValue(ListPromptsRequestSchema)); - this.server.assertCanSetRequestHandler(getMethodValue(GetPromptRequestSchema)); + this.server.assertCanSetRequestHandler( + getMethodValue(ListPromptsRequestSchema), + ); + this.server.assertCanSetRequestHandler( + getMethodValue(GetPromptRequestSchema), + ); this.server.registerCapabilities({ prompts: { - listChanged: true - } + listChanged: true, + }, }); this.server.setRequestHandler( ListPromptsRequestSchema, - (): ListPromptsResult => ({ - prompts: Object.entries(this._registeredPrompts) - .filter(([, prompt]) => prompt.enabled) - .map(([name, prompt]): Prompt => { - return { - name, - title: prompt.title, - description: prompt.description, - arguments: prompt.argsSchema ? promptArgumentsFromSchema(prompt.argsSchema) : undefined - }; - }) - }) + ( + request: ListPromptsRequest, + extra: RequestHandlerExtra< + ListPromptsRequest, + ServerNotification + >, + ) => this._executeRequest< + ListPromptsResult, + ListPromptsRequest, + RequestHandlerExtra + >( + (): Promise => + Promise.resolve({ + prompts: Object.entries(this._registeredPrompts) + .filter(([, prompt]) => prompt.enabled) + .map(([name, prompt]): Prompt => { + return { + name, + title: prompt.title, + description: prompt.description, + arguments: prompt.argsSchema + ? promptArgumentsFromSchema( + prompt.argsSchema, + ) + : undefined, + }; + }), + }), + request, + extra, + ), ); - this.server.setRequestHandler(GetPromptRequestSchema, async (request, extra): Promise => { - const prompt = this._registeredPrompts[request.params.name]; - if (!prompt) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); - } + this.server.setRequestHandler( + GetPromptRequestSchema, + ( + request: GetPromptRequest, + extra: RequestHandlerExtra< + GetPromptRequest, + ServerNotification + >, + ) => this._executeRequest< + GetPromptResult, + GetPromptRequest, + RequestHandlerExtra + >( + async ( + request: GetPromptRequest, + extra: RequestHandlerExtra< + GetPromptRequest, + ServerNotification + >, + ): Promise => { + const prompt = this._registeredPrompts[request.params.name]; + if (!prompt) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.name} not found`, + ); + } - if (!prompt.enabled) { - throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); - } + if (!prompt.enabled) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.name} disabled`, + ); + } - if (prompt.argsSchema) { - const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; - const parseResult = await safeParseAsync(argsObj, request.params.arguments); - if (!parseResult.success) { - const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; - const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); - } + if (prompt.argsSchema) { + const argsObj = normalizeObjectSchema( + prompt.argsSchema, + ) as AnyObjectSchema; + const parseResult = await safeParseAsync( + argsObj, + request.params.arguments, + ); + if (!parseResult.success) { + const error = "error" in parseResult + ? parseResult.error + : "Unknown error"; + const errorMessage = getParseErrorMessage(error); + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`, + ); + } - const args = parseResult.data; - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, extra)); - } else { - const cb = prompt.callback as PromptCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((cb as any)(extra)); - } - }); + const args = parseResult.data; + const cb = prompt.callback as PromptCallback< + PromptArgsRawShape + >; + return await Promise.resolve( + cb( + args, + extra as unknown as RequestHandlerExtra< + ServerRequest, + ServerNotification + >, + ), + ); + } else { + const cb = prompt.callback as PromptCallback; + return await Promise.resolve( + cb( + extra as unknown as RequestHandlerExtra< + ServerRequest, + ServerNotification + >, + ), + ); + } + }, + request, + extra, + ), + ); this._promptHandlersInitialized = true; } @@ -631,19 +1026,32 @@ export class McpServer { * Registers a resource `name` at a fixed URI, which will use the given callback to respond to read requests. * @deprecated Use `registerResource` instead. */ - resource(name: string, uri: string, readCallback: ReadResourceCallback): RegisteredResource; + resource( + name: string, + uri: string, + readCallback: ReadResourceCallback, + ): RegisteredResource; /** * Registers a resource `name` at a fixed URI with metadata, which will use the given callback to respond to read requests. * @deprecated Use `registerResource` instead. */ - resource(name: string, uri: string, metadata: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + resource( + name: string, + uri: string, + metadata: ResourceMetadata, + readCallback: ReadResourceCallback, + ): RegisteredResource; /** * Registers a resource `name` with a template pattern, which will use the given callback to respond to read requests. * @deprecated Use `registerResource` instead. */ - resource(name: string, template: ResourceTemplate, readCallback: ReadResourceTemplateCallback): RegisteredResourceTemplate; + resource( + name: string, + template: ResourceTemplate, + readCallback: ReadResourceTemplateCallback, + ): RegisteredResourceTemplate; /** * Registers a resource `name` with a template pattern and metadata, which will use the given callback to respond to read requests. @@ -653,20 +1061,28 @@ export class McpServer { name: string, template: ResourceTemplate, metadata: ResourceMetadata, - readCallback: ReadResourceTemplateCallback + readCallback: ReadResourceTemplateCallback, ): RegisteredResourceTemplate; - resource(name: string, uriOrTemplate: string | ResourceTemplate, ...rest: unknown[]): RegisteredResource | RegisteredResourceTemplate { + resource( + name: string, + uriOrTemplate: string | ResourceTemplate, + ...rest: unknown[] + ): RegisteredResource | RegisteredResourceTemplate { let metadata: ResourceMetadata | undefined; - if (typeof rest[0] === 'object') { + if (typeof rest[0] === "object") { metadata = rest.shift() as ResourceMetadata; } - const readCallback = rest[0] as ReadResourceCallback | ReadResourceTemplateCallback; + const readCallback = rest[0] as + | ReadResourceCallback + | ReadResourceTemplateCallback; - if (typeof uriOrTemplate === 'string') { + if (typeof uriOrTemplate === "string") { if (this._registeredResources[uriOrTemplate]) { - throw new Error(`Resource ${uriOrTemplate} is already registered`); + throw new Error( + `Resource ${uriOrTemplate} is already registered`, + ); } const registeredResource = this._createRegisteredResource( @@ -674,7 +1090,7 @@ export class McpServer { undefined, uriOrTemplate, metadata, - readCallback as ReadResourceCallback + readCallback as ReadResourceCallback, ); this.setResourceRequestHandlers(); @@ -682,16 +1098,19 @@ export class McpServer { return registeredResource; } else { if (this._registeredResourceTemplates[name]) { - throw new Error(`Resource template ${name} is already registered`); + throw new Error( + `Resource template ${name} is already registered`, + ); } - const registeredResourceTemplate = this._createRegisteredResourceTemplate( - name, - undefined, - uriOrTemplate, - metadata, - readCallback as ReadResourceTemplateCallback - ); + const registeredResourceTemplate = this + ._createRegisteredResourceTemplate( + name, + undefined, + uriOrTemplate, + metadata, + readCallback as ReadResourceTemplateCallback, + ); this.setResourceRequestHandlers(); this.sendResourceListChanged(); @@ -703,22 +1122,29 @@ export class McpServer { * Registers a resource with a config object and callback. * For static resources, use a URI string. For dynamic resources, use a ResourceTemplate. */ - registerResource(name: string, uriOrTemplate: string, config: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + registerResource( + name: string, + uriOrTemplate: string, + config: ResourceMetadata, + readCallback: ReadResourceCallback, + ): RegisteredResource; registerResource( name: string, uriOrTemplate: ResourceTemplate, config: ResourceMetadata, - readCallback: ReadResourceTemplateCallback + readCallback: ReadResourceTemplateCallback, ): RegisteredResourceTemplate; registerResource( name: string, uriOrTemplate: string | ResourceTemplate, config: ResourceMetadata, - readCallback: ReadResourceCallback | ReadResourceTemplateCallback + readCallback: ReadResourceCallback | ReadResourceTemplateCallback, ): RegisteredResource | RegisteredResourceTemplate { - if (typeof uriOrTemplate === 'string') { + if (typeof uriOrTemplate === "string") { if (this._registeredResources[uriOrTemplate]) { - throw new Error(`Resource ${uriOrTemplate} is already registered`); + throw new Error( + `Resource ${uriOrTemplate} is already registered`, + ); } const registeredResource = this._createRegisteredResource( @@ -726,7 +1152,7 @@ export class McpServer { (config as BaseMetadata).title, uriOrTemplate, config, - readCallback as ReadResourceCallback + readCallback as ReadResourceCallback, ); this.setResourceRequestHandlers(); @@ -734,16 +1160,19 @@ export class McpServer { return registeredResource; } else { if (this._registeredResourceTemplates[name]) { - throw new Error(`Resource template ${name} is already registered`); + throw new Error( + `Resource template ${name} is already registered`, + ); } - const registeredResourceTemplate = this._createRegisteredResourceTemplate( - name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceTemplateCallback - ); + const registeredResourceTemplate = this + ._createRegisteredResourceTemplate( + name, + (config as BaseMetadata).title, + uriOrTemplate, + config, + readCallback as ReadResourceTemplateCallback, + ); this.setResourceRequestHandlers(); this.sendResourceListChanged(); @@ -756,7 +1185,7 @@ export class McpServer { title: string | undefined, uri: string, metadata: ResourceMetadata | undefined, - readCallback: ReadResourceCallback + readCallback: ReadResourceCallback, ): RegisteredResource { const registeredResource: RegisteredResource = { name, @@ -767,18 +1196,31 @@ export class McpServer { disable: () => registeredResource.update({ enabled: false }), enable: () => registeredResource.update({ enabled: true }), remove: () => registeredResource.update({ uri: null }), - update: updates => { - if (typeof updates.uri !== 'undefined' && updates.uri !== uri) { + update: (updates) => { + if (typeof updates.uri !== "undefined" && updates.uri !== uri) { delete this._registeredResources[uri]; - if (updates.uri) this._registeredResources[updates.uri] = registeredResource; + if (updates.uri) { + this._registeredResources[updates.uri] = + registeredResource; + } + } + if (typeof updates.name !== "undefined") { + registeredResource.name = updates.name; + } + if (typeof updates.title !== "undefined") { + registeredResource.title = updates.title; + } + if (typeof updates.metadata !== "undefined") { + registeredResource.metadata = updates.metadata; + } + if (typeof updates.callback !== "undefined") { + registeredResource.readCallback = updates.callback; + } + if (typeof updates.enabled !== "undefined") { + registeredResource.enabled = updates.enabled; } - if (typeof updates.name !== 'undefined') registeredResource.name = updates.name; - if (typeof updates.title !== 'undefined') registeredResource.title = updates.title; - if (typeof updates.metadata !== 'undefined') registeredResource.metadata = updates.metadata; - if (typeof updates.callback !== 'undefined') registeredResource.readCallback = updates.callback; - if (typeof updates.enabled !== 'undefined') registeredResource.enabled = updates.enabled; this.sendResourceListChanged(); - } + }, }; this._registeredResources[uri] = registeredResource; return registeredResource; @@ -789,7 +1231,7 @@ export class McpServer { title: string | undefined, template: ResourceTemplate, metadata: ResourceMetadata | undefined, - readCallback: ReadResourceTemplateCallback + readCallback: ReadResourceTemplateCallback, ): RegisteredResourceTemplate { const registeredResourceTemplate: RegisteredResourceTemplate = { resourceTemplate: template, @@ -797,27 +1239,45 @@ export class McpServer { metadata, readCallback, enabled: true, - disable: () => registeredResourceTemplate.update({ enabled: false }), + disable: () => + registeredResourceTemplate.update({ enabled: false }), enable: () => registeredResourceTemplate.update({ enabled: true }), remove: () => registeredResourceTemplate.update({ name: null }), - update: updates => { - if (typeof updates.name !== 'undefined' && updates.name !== name) { + update: (updates) => { + if ( + typeof updates.name !== "undefined" && updates.name !== name + ) { delete this._registeredResourceTemplates[name]; - if (updates.name) this._registeredResourceTemplates[updates.name] = registeredResourceTemplate; + if (updates.name) { + this._registeredResourceTemplates[updates.name] = + registeredResourceTemplate; + } + } + if (typeof updates.title !== "undefined") { + registeredResourceTemplate.title = updates.title; + } + if (typeof updates.template !== "undefined") { + registeredResourceTemplate.resourceTemplate = + updates.template; + } + if (typeof updates.metadata !== "undefined") { + registeredResourceTemplate.metadata = updates.metadata; + } + if (typeof updates.callback !== "undefined") { + registeredResourceTemplate.readCallback = updates.callback; + } + if (typeof updates.enabled !== "undefined") { + registeredResourceTemplate.enabled = updates.enabled; } - if (typeof updates.title !== 'undefined') registeredResourceTemplate.title = updates.title; - if (typeof updates.template !== 'undefined') registeredResourceTemplate.resourceTemplate = updates.template; - if (typeof updates.metadata !== 'undefined') registeredResourceTemplate.metadata = updates.metadata; - if (typeof updates.callback !== 'undefined') registeredResourceTemplate.readCallback = updates.callback; - if (typeof updates.enabled !== 'undefined') registeredResourceTemplate.enabled = updates.enabled; this.sendResourceListChanged(); - } + }, }; this._registeredResourceTemplates[name] = registeredResourceTemplate; // If the resource template has any completion callbacks, enable completions capability const variableNames = template.uriTemplate.variableNames; - const hasCompleter = Array.isArray(variableNames) && variableNames.some(v => !!template.completeCallback(v)); + const hasCompleter = Array.isArray(variableNames) && + variableNames.some((v) => !!template.completeCallback(v)); if (hasCompleter) { this.setCompletionRequestHandler(); } @@ -830,36 +1290,57 @@ export class McpServer { title: string | undefined, description: string | undefined, argsSchema: PromptArgsRawShape | undefined, - callback: PromptCallback + callback: PromptCallback, ): RegisteredPrompt { const registeredPrompt: RegisteredPrompt = { title, description, - argsSchema: argsSchema === undefined ? undefined : objectFromShape(argsSchema), + argsSchema: argsSchema === undefined + ? undefined + : objectFromShape(argsSchema), callback, enabled: true, disable: () => registeredPrompt.update({ enabled: false }), enable: () => registeredPrompt.update({ enabled: true }), remove: () => registeredPrompt.update({ name: null }), - update: updates => { - if (typeof updates.name !== 'undefined' && updates.name !== name) { + update: (updates) => { + if ( + typeof updates.name !== "undefined" && updates.name !== name + ) { delete this._registeredPrompts[name]; - if (updates.name) this._registeredPrompts[updates.name] = registeredPrompt; + if (updates.name) { + this._registeredPrompts[updates.name] = + registeredPrompt; + } + } + if (typeof updates.title !== "undefined") { + registeredPrompt.title = updates.title; + } + if (typeof updates.description !== "undefined") { + registeredPrompt.description = updates.description; + } + if (typeof updates.argsSchema !== "undefined") { + registeredPrompt.argsSchema = objectFromShape( + updates.argsSchema, + ); + } + if (typeof updates.callback !== "undefined") { + registeredPrompt.callback = updates.callback; + } + if (typeof updates.enabled !== "undefined") { + registeredPrompt.enabled = updates.enabled; } - if (typeof updates.title !== 'undefined') registeredPrompt.title = updates.title; - if (typeof updates.description !== 'undefined') registeredPrompt.description = updates.description; - if (typeof updates.argsSchema !== 'undefined') registeredPrompt.argsSchema = objectFromShape(updates.argsSchema); - if (typeof updates.callback !== 'undefined') registeredPrompt.callback = updates.callback; - if (typeof updates.enabled !== 'undefined') registeredPrompt.enabled = updates.enabled; this.sendPromptListChanged(); - } + }, }; this._registeredPrompts[name] = registeredPrompt; // If any argument uses a Completable schema, enable completions capability if (argsSchema) { - const hasCompletable = Object.values(argsSchema).some(field => { - const inner: unknown = field instanceof ZodOptional ? field._def?.innerType : field; + const hasCompletable = Object.values(argsSchema).some((field) => { + const inner: unknown = field instanceof ZodOptional + ? field._def?.innerType + : field; return isCompletable(inner); }); if (hasCompletable) { @@ -879,7 +1360,7 @@ export class McpServer { annotations: ToolAnnotations | undefined, execution: ToolExecution | undefined, _meta: Record | undefined, - handler: AnyToolHandler + handler: AnyToolHandler, ): RegisteredTool { // Validate tool name according to SEP specification validateAndWarnToolName(name); @@ -897,24 +1378,48 @@ export class McpServer { disable: () => registeredTool.update({ enabled: false }), enable: () => registeredTool.update({ enabled: true }), remove: () => registeredTool.update({ name: null }), - update: updates => { - if (typeof updates.name !== 'undefined' && updates.name !== name) { - if (typeof updates.name === 'string') { + update: (updates) => { + if ( + typeof updates.name !== "undefined" && updates.name !== name + ) { + if (typeof updates.name === "string") { validateAndWarnToolName(updates.name); } delete this._registeredTools[name]; - if (updates.name) this._registeredTools[updates.name] = registeredTool; + if (updates.name) { + this._registeredTools[updates.name] = registeredTool; + } + } + if (typeof updates.title !== "undefined") { + registeredTool.title = updates.title; + } + if (typeof updates.description !== "undefined") { + registeredTool.description = updates.description; + } + if (typeof updates.paramsSchema !== "undefined") { + registeredTool.inputSchema = objectFromShape( + updates.paramsSchema, + ); + } + if (typeof updates.outputSchema !== "undefined") { + registeredTool.outputSchema = objectFromShape( + updates.outputSchema, + ); + } + if (typeof updates.callback !== "undefined") { + registeredTool.handler = updates.callback; + } + if (typeof updates.annotations !== "undefined") { + registeredTool.annotations = updates.annotations; + } + if (typeof updates._meta !== "undefined") { + registeredTool._meta = updates._meta; + } + if (typeof updates.enabled !== "undefined") { + registeredTool.enabled = updates.enabled; } - if (typeof updates.title !== 'undefined') registeredTool.title = updates.title; - if (typeof updates.description !== 'undefined') registeredTool.description = updates.description; - if (typeof updates.paramsSchema !== 'undefined') registeredTool.inputSchema = objectFromShape(updates.paramsSchema); - if (typeof updates.outputSchema !== 'undefined') registeredTool.outputSchema = objectFromShape(updates.outputSchema); - if (typeof updates.callback !== 'undefined') registeredTool.handler = updates.callback; - if (typeof updates.annotations !== 'undefined') registeredTool.annotations = updates.annotations; - if (typeof updates._meta !== 'undefined') registeredTool._meta = updates._meta; - if (typeof updates.enabled !== 'undefined') registeredTool.enabled = updates.enabled; this.sendToolListChanged(); - } + }, }; this._registeredTools[name] = registeredTool; @@ -947,7 +1452,7 @@ export class McpServer { tool( name: string, paramsSchemaOrAnnotations: Args | ToolAnnotations, - cb: ToolCallback + cb: ToolCallback, ): RegisteredTool; /** @@ -963,7 +1468,7 @@ export class McpServer { name: string, description: string, paramsSchemaOrAnnotations: Args | ToolAnnotations, - cb: ToolCallback + cb: ToolCallback, ): RegisteredTool; /** @@ -974,7 +1479,7 @@ export class McpServer { name: string, paramsSchema: Args, annotations: ToolAnnotations, - cb: ToolCallback + cb: ToolCallback, ): RegisteredTool; /** @@ -986,7 +1491,7 @@ export class McpServer { description: string, paramsSchema: Args, annotations: ToolAnnotations, - cb: ToolCallback + cb: ToolCallback, ): RegisteredTool; /** @@ -1006,7 +1511,7 @@ export class McpServer { // Support for this style is frozen as of protocol version 2025-03-26. Future additions // to tool definition should *NOT* be added. - if (typeof rest[0] === 'string') { + if (typeof rest[0] === "string") { description = rest.shift() as string; } @@ -1020,12 +1525,15 @@ export class McpServer { inputSchema = rest.shift() as ZodRawShapeCompat; // Check if the next arg is potentially annotations - if (rest.length > 1 && typeof rest[0] === 'object' && rest[0] !== null && !isZodRawShapeCompat(rest[0])) { + if ( + rest.length > 1 && typeof rest[0] === "object" && + rest[0] !== null && !isZodRawShapeCompat(rest[0]) + ) { // Case: tool(name, paramsSchema, annotations, cb) // Or: tool(name, description, paramsSchema, annotations, cb) annotations = rest.shift() as ToolAnnotations; } - } else if (typeof firstArg === 'object' && firstArg !== null) { + } else if (typeof firstArg === "object" && firstArg !== null) { // Not a ZodRawShapeCompat, so must be annotations in this position // Case: tool(name, annotations, cb) // Or: tool(name, description, annotations, cb) @@ -1041,16 +1549,19 @@ export class McpServer { inputSchema, outputSchema, annotations, - { taskSupport: 'forbidden' }, + { taskSupport: "forbidden" }, undefined, - callback + callback, ); } /** * Registers a tool with a config object and callback. */ - registerTool( + registerTool< + OutputArgs extends ZodRawShapeCompat | AnySchema, + InputArgs extends undefined | ZodRawShapeCompat | AnySchema = undefined, + >( name: string, config: { title?: string; @@ -1060,13 +1571,20 @@ export class McpServer { annotations?: ToolAnnotations; _meta?: Record; }, - cb: ToolCallback + cb: ToolCallback, ): RegisteredTool { if (this._registeredTools[name]) { throw new Error(`Tool ${name} is already registered`); } - const { title, description, inputSchema, outputSchema, annotations, _meta } = config; + const { + title, + description, + inputSchema, + outputSchema, + annotations, + _meta, + } = config; return this._createRegisteredTool( name, @@ -1075,9 +1593,9 @@ export class McpServer { inputSchema, outputSchema, annotations, - { taskSupport: 'forbidden' }, + { taskSupport: "forbidden" }, _meta, - cb as ToolCallback + cb as ToolCallback, ); } @@ -1091,13 +1609,21 @@ export class McpServer { * Registers a zero-argument prompt `name` (with a description) which will run the given function when the client calls it. * @deprecated Use `registerPrompt` instead. */ - prompt(name: string, description: string, cb: PromptCallback): RegisteredPrompt; + prompt( + name: string, + description: string, + cb: PromptCallback, + ): RegisteredPrompt; /** * Registers a prompt `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. * @deprecated Use `registerPrompt` instead. */ - prompt(name: string, argsSchema: Args, cb: PromptCallback): RegisteredPrompt; + prompt( + name: string, + argsSchema: Args, + cb: PromptCallback, + ): RegisteredPrompt; /** * Registers a prompt `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. @@ -1107,7 +1633,7 @@ export class McpServer { name: string, description: string, argsSchema: Args, - cb: PromptCallback + cb: PromptCallback, ): RegisteredPrompt; prompt(name: string, ...rest: unknown[]): RegisteredPrompt { @@ -1116,7 +1642,7 @@ export class McpServer { } let description: string | undefined; - if (typeof rest[0] === 'string') { + if (typeof rest[0] === "string") { description = rest.shift() as string; } @@ -1126,7 +1652,13 @@ export class McpServer { } const cb = rest[0] as PromptCallback; - const registeredPrompt = this._createRegisteredPrompt(name, undefined, description, argsSchema, cb); + const registeredPrompt = this._createRegisteredPrompt( + name, + undefined, + description, + argsSchema, + cb, + ); this.setPromptRequestHandlers(); this.sendPromptListChanged(); @@ -1144,7 +1676,7 @@ export class McpServer { description?: string; argsSchema?: Args; }, - cb: PromptCallback + cb: PromptCallback, ): RegisteredPrompt { if (this._registeredPrompts[name]) { throw new Error(`Prompt ${name} is already registered`); @@ -1157,7 +1689,7 @@ export class McpServer { title, description, argsSchema, - cb as PromptCallback + cb as PromptCallback, ); this.setPromptRequestHandlers(); @@ -1181,7 +1713,10 @@ export class McpServer { * @param params * @param sessionId optional for stateless and backward compatibility */ - async sendLoggingMessage(params: LoggingMessageNotification['params'], sessionId?: string) { + async sendLoggingMessage( + params: LoggingMessageNotification["params"], + sessionId?: string, + ) { return this.server.sendLoggingMessage(params, sessionId); } /** @@ -1210,6 +1745,79 @@ export class McpServer { this.server.sendPromptListChanged(); } } + + private async _executeRequest< + ResultT, + RequestT, + ExtraT extends RequestHandlerExtra, + >( + handler: ( + request: RequestT, + extra: ExtraT, + ) => Promise, + request: RequestT, + extra: ExtraT, + ): Promise { + this._middlewareFrozen = true; + const middleware = this._middleware; + + // Optimized path: If there are no middleware, just run the handler + if (middleware.length === 0) { + return handler(request, extra); + } + + let result: ResultT | undefined; + let handlerError: unknown; + + // Wrap the handler as the final middleware + const leafMiddleware: McpMiddleware = async (_context, _next) => { + try { + result = await handler(request, extra); + } catch (e) { + handlerError = e; + } + }; + + const chain = [...middleware, leafMiddleware]; + + // Execute the chain + // Protect against creating a context with incorrect types by casting + const context: McpMiddlewareContext = { + request: request as unknown as ServerRequest, + extra: extra as unknown as RequestHandlerExtra< + ServerRequest, + ServerNotification + >, + state: {}, + }; + + const executeChain = async (i: number): Promise => { + if (i >= chain.length) { + return; + } + const fn = chain[i] as McpMiddleware; + + let nextCalled = false; + await fn(context, async () => { + if (nextCalled) { + throw new Error( + "next() called multiple times in middleware", + ); + } + nextCalled = true; + await executeChain(i + 1); + }); + }; + + await executeChain(0); + + if (handlerError) { + throw handlerError; + } + + // Return result, asserting it exists (handlers should generally return something) + return result as ResultT; + } } /** @@ -1219,7 +1827,7 @@ export type CompleteResourceTemplateCallback = ( value: string, context?: { arguments?: Record; - } + }, ) => string[] | Promise; /** @@ -1243,9 +1851,11 @@ export class ResourceTemplate { complete?: { [variable: string]: CompleteResourceTemplateCallback; }; - } + }, ) { - this._uriTemplate = typeof uriTemplate === 'string' ? new UriTemplate(uriTemplate) : uriTemplate; + this._uriTemplate = typeof uriTemplate === "string" + ? new UriTemplate(uriTemplate) + : uriTemplate; } /** @@ -1265,7 +1875,9 @@ export class ResourceTemplate { /** * Gets the callback for completing a specific URI template variable, if one was provided. */ - completeCallback(variable: string): CompleteResourceTemplateCallback | undefined { + completeCallback( + variable: string, + ): CompleteResourceTemplateCallback | undefined { return this._callbacks.complete?.[variable]; } } @@ -1273,12 +1885,16 @@ export class ResourceTemplate { export type BaseToolCallback< SendResultT extends Result, Extra extends RequestHandlerExtra, - Args extends undefined | ZodRawShapeCompat | AnySchema -> = Args extends ZodRawShapeCompat - ? (args: ShapeOutput, extra: Extra) => SendResultT | Promise - : Args extends AnySchema - ? (args: SchemaOutput, extra: Extra) => SendResultT | Promise - : (extra: Extra) => SendResultT | Promise; + Args extends undefined | ZodRawShapeCompat | AnySchema, +> = Args extends ZodRawShapeCompat ? ( + args: ShapeOutput, + extra: Extra, + ) => SendResultT | Promise + : Args extends AnySchema ? ( + args: SchemaOutput, + extra: Extra, + ) => SendResultT | Promise + : (extra: Extra) => SendResultT | Promise; /** * Callback for a tool handler registered with Server.tool(). @@ -1290,7 +1906,9 @@ export type BaseToolCallback< * - `content` if the tool does not have an outputSchema * - Both fields are optional but typically one should be provided */ -export type ToolCallback = BaseToolCallback< +export type ToolCallback< + Args extends undefined | ZodRawShapeCompat | AnySchema = undefined, +> = BaseToolCallback< CallToolResult, RequestHandlerExtra, Args @@ -1299,7 +1917,9 @@ export type ToolCallback = ToolCallback | ToolTaskHandler; +export type AnyToolHandler< + Args extends undefined | ZodRawShapeCompat | AnySchema = undefined, +> = ToolCallback | ToolTaskHandler; export type RegisteredTool = { title?: string; @@ -1313,7 +1933,10 @@ export type RegisteredTool = { enabled: boolean; enable(): void; disable(): void; - update(updates: { + update< + InputArgs extends ZodRawShapeCompat, + OutputArgs extends ZodRawShapeCompat, + >(updates: { name?: string | null; title?: string; description?: string; @@ -1328,8 +1951,8 @@ export type RegisteredTool = { }; const EMPTY_OBJECT_JSON_SCHEMA = { - type: 'object' as const, - properties: {} + type: "object" as const, + properties: {}, }; /** @@ -1338,11 +1961,11 @@ const EMPTY_OBJECT_JSON_SCHEMA = { function isZodTypeLike(value: unknown): value is AnySchema { return ( value !== null && - typeof value === 'object' && - 'parse' in value && - typeof value.parse === 'function' && - 'safeParse' in value && - typeof value.safeParse === 'function' + typeof value === "object" && + "parse" in value && + typeof value.parse === "function" && + "safeParse" in value && + typeof value.safeParse === "function" ); } @@ -1356,7 +1979,7 @@ function isZodTypeLike(value: unknown): value is AnySchema { * This includes transformed schemas like z.preprocess(), z.transform(), z.pipe(). */ function isZodSchemaInstance(obj: object): boolean { - return '_def' in obj || '_zod' in obj || isZodTypeLike(obj); + return "_def" in obj || "_zod" in obj || isZodTypeLike(obj); } /** @@ -1368,7 +1991,7 @@ function isZodSchemaInstance(obj: object): boolean { * which have internal properties that could be mistaken for schema values. */ function isZodRawShapeCompat(obj: unknown): obj is ZodRawShapeCompat { - if (typeof obj !== 'object' || obj === null) { + if (typeof obj !== "object" || obj === null) { return false; } @@ -1390,7 +2013,9 @@ function isZodRawShapeCompat(obj: unknown): obj is ZodRawShapeCompat { * Converts a provided Zod schema to a Zod object if it is a ZodRawShapeCompat, * otherwise returns the schema as is. */ -function getZodSchemaObject(schema: ZodRawShapeCompat | AnySchema | undefined): AnySchema | undefined { +function getZodSchemaObject( + schema: ZodRawShapeCompat | AnySchema | undefined, +): AnySchema | undefined { if (!schema) { return undefined; } @@ -1405,13 +2030,13 @@ function getZodSchemaObject(schema: ZodRawShapeCompat | AnySchema | undefined): /** * Additional, optional information for annotating a resource. */ -export type ResourceMetadata = Omit; +export type ResourceMetadata = Omit; /** * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra + extra: RequestHandlerExtra, ) => ListResourcesResult | Promise; /** @@ -1419,7 +2044,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - extra: RequestHandlerExtra + extra: RequestHandlerExtra, ) => ReadResourceResult | Promise; export type RegisteredResource = { @@ -1447,7 +2072,7 @@ export type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - extra: RequestHandlerExtra + extra: RequestHandlerExtra, ) => ReadResourceResult | Promise; export type RegisteredResourceTemplate = { @@ -1471,9 +2096,15 @@ export type RegisteredResourceTemplate = { type PromptArgsRawShape = ZodRawShapeCompat; -export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, extra: RequestHandlerExtra) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; +export type PromptCallback< + Args extends undefined | PromptArgsRawShape = undefined, +> = Args extends PromptArgsRawShape ? ( + args: ShapeOutput, + extra: RequestHandlerExtra, + ) => GetPromptResult | Promise + : ( + extra: RequestHandlerExtra, + ) => GetPromptResult | Promise; export type RegisteredPrompt = { title?: string; @@ -1505,7 +2136,7 @@ function promptArgumentsFromSchema(schema: AnyObjectSchema): PromptArgument[] { return { name, description, - required: !isOptional + required: !isOptional, }; }); } @@ -1514,16 +2145,16 @@ function getMethodValue(schema: AnyObjectSchema): string { const shape = getObjectShape(schema); const methodSchema = shape?.method as AnySchema | undefined; if (!methodSchema) { - throw new Error('Schema is missing a method literal'); + throw new Error("Schema is missing a method literal"); } // Extract literal value - works for both v3 and v4 const value = getLiteralValue(methodSchema); - if (typeof value === 'string') { + if (typeof value === "string") { return value; } - throw new Error('Schema method literal must be a string'); + throw new Error("Schema method literal must be a string"); } function createCompletionResult(suggestions: string[]): CompleteResult { @@ -1531,14 +2162,14 @@ function createCompletionResult(suggestions: string[]): CompleteResult { completion: { values: suggestions.slice(0, 100), total: suggestions.length, - hasMore: suggestions.length > 100 - } + hasMore: suggestions.length > 100, + }, }; } const EMPTY_COMPLETION_RESULT: CompleteResult = { completion: { values: [], - hasMore: false - } + hasMore: false, + }, }; diff --git a/packages/server/test/server/mcpServer.test.ts b/packages/server/test/server/mcpServer.test.ts new file mode 100644 index 000000000..36ed7f141 --- /dev/null +++ b/packages/server/test/server/mcpServer.test.ts @@ -0,0 +1,543 @@ +import { McpServer } from "../../src/server/mcp.js"; +import { JSONRPCMessage } from "@modelcontextprotocol/core"; +import { beforeEach, describe, expect, it, vi } from "vitest"; + +describe("McpServer Middleware", () => { + let server: McpServer; + + beforeEach(() => { + server = new McpServer({ + name: "test-server", + version: "1.0.0", + }); + }); + + // Helper to simulate a tool call and capture the response + async function simulateCallTool(toolName: string): Promise { + let serverOnMessage: (message: any) => Promise; + let capturedResponse: JSONRPCMessage | undefined; + let resolveSend: () => void; + const sendPromise = new Promise((resolve) => { + resolveSend = resolve; + }); + + const transport = { + start: vi.fn(), + send: vi.fn().mockImplementation(async (msg) => { + capturedResponse = msg as JSONRPCMessage; + resolveSend(); + }), + close: vi.fn(), + set onmessage(handler: any) { + serverOnMessage = handler; + }, + }; + + await server.connect(transport); + + const request = { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: toolName, + arguments: {}, + }, + }; + + if (!serverOnMessage!) { + throw new Error("Server did not attach onMessage listener"); + } + + // Trigger request + serverOnMessage(request); + + // Wait for response + await sendPromise; + + return capturedResponse!; + } + + it("should execute middleware in registration order (Onion model)", async () => { + const sequence: string[] = []; + + server.use(async (context, next) => { + sequence.push("mw1 start"); + await next(); + sequence.push("mw1 end"); + }); + + server.use(async (context, next) => { + sequence.push("mw2 start"); + await next(); + sequence.push("mw2 end"); + }); + + server.tool("test-tool", {}, async () => { + sequence.push("handler"); + return { content: [{ type: "text", text: "result" }] }; + }); + + await simulateCallTool("test-tool"); + + expect(sequence).toEqual([ + "mw1 start", + "mw2 start", + "handler", + "mw2 end", + "mw1 end", + ]); + }); + + it("should short-circuit if next() is not called", async () => { + const sequence: string[] = []; + + server.use(async (context, next) => { + sequence.push("mw1 start"); + // next() NOT called + sequence.push("mw1 end"); + }); + + server.use(async (context, next) => { + sequence.push("mw2 start"); + await next(); + }); + + server.tool("test-tool", {}, async () => { + sequence.push("handler"); + return { content: [{ type: "text", text: "result" }] }; + }); + + await simulateCallTool("test-tool"); + + // mw2 and handler should NOT run + expect(sequence).toEqual(["mw1 start", "mw1 end"]); + }); + + it("should allow middleware to communicate via ctx.state", async () => { + const server = new McpServer({ name: "test", version: "1.0" }); + server.use(async (ctx, next) => { + ctx.state.value = 1; + await next(); + }); + server.use(async (ctx, next) => { + ctx.state.value = (ctx.state.value as number) + 1; + await next(); + }); + + // Use a tool list request to trigger the chain + server.tool( + "test-tool", + {}, + async () => ({ content: [{ type: "text", text: "ok" }] }), + ); + + let capturedState: any; + server.use(async (ctx, next) => { + capturedState = ctx.state; + await next(); + }); + + let resolveSend: () => void; + const sendPromise = new Promise((resolve) => { + resolveSend = resolve; + }); + + const transport = { + start: vi.fn(), + send: vi.fn().mockImplementation(async () => { + resolveSend(); + }), + close: vi.fn(), + }; + await server.connect(transport as any); + // @ts-ignore + const onMsg = (server.server.transport as any).onmessage; + onMsg({ + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { name: "test-tool", arguments: {} }, + }); + + await sendPromise; + + expect(capturedState).toBeDefined(); + expect(capturedState.value).toBe(2); + }); + + it("should execute middleware for other methods (e.g. tools/list)", async () => { + // For this check, we need to simulate tools/list. + // We can adapt our helper or just copy-paste a simplified version here for variety. + const sequence: string[] = []; + server.use(async (context, next) => { + sequence.push("mw"); + await next(); + }); + + // Register a dummy tool to ensure tools/list handler is set up + server.tool("dummy", {}, async () => ({ content: [] })); + + let serverOnMessage: any; + let resolveSend: any; + const p = new Promise((r) => resolveSend = r); + const transport = { + start: vi.fn(), + send: vi.fn().mockImplementation(() => resolveSend()), + close: vi.fn(), + set onmessage(h: any) { + serverOnMessage = h; + }, + }; + await server.connect(transport); + + serverOnMessage({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: {}, + }); + await p; + + expect(sequence).toEqual(["mw"]); + }); + + it("should allow middleware to catch errors from downstream", async () => { + server.use(async (context, next) => { + try { + await next(); + } catch (e) { + // Suppress error + } + }); + + server.tool("error-tool", {}, async () => { + throw new Error("Boom"); + }); + + const response = await simulateCallTool("error-tool"); + + // Since middleware swallowed the error, the handler returns undefined (or whatever executed). + // Actually, if handler throws and middleware catches, `result` in `_executeRequest` will be undefined. + // The server transport might expect a result. + // Typescript core SDK might throw if result is missing maybe? + // Or it sends a success response with "undefined"? + + // Let's check what response we got. If error was swallowed, it shouldn't be an error response. + expect((response as any).error).toBeUndefined(); + }); + + it("should propagate errors if middleware throws", async () => { + server.use(async (context, next) => { + throw new Error("Middleware Error"); + }); + + server.tool("test-tool", {}, async () => ({ content: [] })); + + const response = await simulateCallTool("test-tool"); + + // Standard JSON-RPC error response + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain("Middleware Error"); + }); + + it("should throw an error if next() is called multiple times", async () => { + server.use(async (context, next) => { + await next(); + await next(); // Second call should throw + }); + + server.tool("test-tool", {}, async () => ({ content: [] })); + + const response = await simulateCallTool("test-tool"); + + // Expect an error response due to double-call + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain( + "next() called multiple times", + ); + }); + + it("should respect async timing (middleware can await)", async () => { + const sequence: string[] = []; + const delay = (ms: number) => + new Promise((resolve) => setTimeout(resolve, ms)); + + server.use(async (context, next) => { + sequence.push("mw1 start"); + await delay(10); // Wait 10ms + sequence.push("mw1 after delay"); + await next(); + sequence.push("mw1 end"); + }); + + server.use(async (context, next) => { + sequence.push("mw2 start"); + await next(); + }); + + server.tool("test-tool", {}, async () => { + sequence.push("handler"); + return { content: [] }; + }); + + await simulateCallTool("test-tool"); + + expect(sequence).toEqual([ + "mw1 start", + "mw1 after delay", + "mw2 start", + "handler", + "mw1 end", + ]); + }); + + it("should throw an error if use() is called after connect()", async () => { + const transport = { + start: vi.fn(), + send: vi.fn(), + close: vi.fn(), + set onmessage(_handler: any) {}, + }; + + await server.connect(transport); + + // Trying to register middleware after connect should throw + expect(() => { + server.use(async (context, next) => { + await next(); + }); + }).toThrow("Cannot register middleware after the server has started"); + }); + + // ============================================================ + // Real World Use Case Integration Tests + // ============================================================ + + describe("Real World Use Cases", () => { + it("Logging: should observe request method and capture response timing", async () => { + const logs: { method: string; durationMs: number }[] = []; + + server.use(async (context, next) => { + const start = Date.now(); + const method = (context.request as any).method || "unknown"; + + await next(); + + const durationMs = Date.now() - start; + logs.push({ method, durationMs }); + }); + + server.tool("fast-tool", {}, async () => { + return { content: [{ type: "text", text: "done" }] }; + }); + + await simulateCallTool("fast-tool"); + + expect(logs).toHaveLength(1); + expect(logs[0]!.method).toBe("tools/call"); + expect(logs[0]!.durationMs).toBeGreaterThanOrEqual(0); + }); + + it("Auth: should short-circuit unauthorized requests", async () => { + const VALID_TOKEN = "secret-token"; + + server.use(async (context, next) => { + // Simulate checking for an auth token in extra/authInfo + const authInfo = (context.extra as any)?.authInfo; + + // In real usage, authInfo would come from the transport. + // For this test, we simulate by checking a header-like property. + // Since we can't inject authInfo easily, we'll check a custom property. + const token = (context.request as any).params?._authToken; + + if (token !== VALID_TOKEN) { + // Short-circuit: don't call next(), effectively blocking the request + // In a real scenario, you might throw an error or set a response + throw new Error("Unauthorized"); + } + + await next(); + }); + + server.tool("protected-tool", {}, async () => { + return { content: [{ type: "text", text: "secret data" }] }; + }); + + // Simulate unauthorized request (no token) + const response = await simulateCallTool("protected-tool"); + + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain("Unauthorized"); + }); + + it("Activity Aggregation: should intercept tools/list and count discoveries", async () => { + let toolListCount = 0; + let toolCallCount = 0; + + server.use(async (context, next) => { + const method = (context.request as any).method; + + if (method === "tools/list") { + toolListCount++; + } else if (method === "tools/call") { + toolCallCount++; + } + + await next(); + }); + + server.tool("my-tool", {}, async () => ({ content: [] })); + + // Simulate tools/list + let serverOnMessage: any; + let resolveSend: any; + const p = new Promise((r) => (resolveSend = r)); + const transport = { + start: vi.fn(), + send: vi.fn().mockImplementation(() => resolveSend()), + close: vi.fn(), + set onmessage(h: any) { + serverOnMessage = h; + }, + }; + await server.connect(transport); + + // First: tools/list + serverOnMessage({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: {}, + }); + await p; + + // Second: tools/call (need new promise) + let resolveSend2: any; + const p2 = new Promise((r) => (resolveSend2 = r)); + transport.send.mockImplementation(() => resolveSend2()); + + serverOnMessage({ + jsonrpc: "2.0", + id: 2, + method: "tools/call", + params: { name: "my-tool", arguments: {} }, + }); + await p2; + + expect(toolListCount).toBe(1); + expect(toolCallCount).toBe(1); + }); + }); + + // ============================================================ + // Failure Mode Verification Tests + // ============================================================ + + describe("Failure Mode Verification", () => { + it("Pre-next: error thrown before next() maps to JSON-RPC error", async () => { + server.use(async (context, next) => { + // Error thrown BEFORE calling next() + throw new Error("Pre-next failure"); + }); + + server.tool("test-tool", {}, async () => ({ content: [] })); + + const response = await simulateCallTool("test-tool"); + + // Should be a proper JSON-RPC error response + expect((response as any).jsonrpc).toBe("2.0"); + expect((response as any).id).toBe(1); + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain( + "Pre-next failure", + ); + // Server should not crash - we got a response + }); + + it("Post-next: error thrown after next() maps to JSON-RPC error", async () => { + server.use(async (context, next) => { + await next(); + // Error thrown AFTER calling next() + throw new Error("Post-next failure"); + }); + + server.tool("test-tool", {}, async () => ({ content: [] })); + + const response = await simulateCallTool("test-tool"); + + // Should be a proper JSON-RPC error response + expect((response as any).jsonrpc).toBe("2.0"); + expect((response as any).id).toBe(1); + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain( + "Post-next failure", + ); + }); + + it("Handler: error thrown in tool handler returns error result (SDK behavior)", async () => { + // No middleware - test pure handler error + server.tool("failing-tool", {}, async () => { + throw new Error("Handler failure"); + }); + + const response = await simulateCallTool("failing-tool"); + + // MCP SDK converts handler errors to result with isError: true + // (not JSON-RPC error - this is intentional SDK behavior) + expect((response as any).jsonrpc).toBe("2.0"); + expect((response as any).id).toBe(1); + expect((response as any).result).toBeDefined(); + expect((response as any).result.isError).toBe(true); + expect((response as any).result.content[0]!.text).toContain( + "Handler failure", + ); + }); + + it("Multiple middleware: error in second middleware propagates correctly", async () => { + const sequence: string[] = []; + + server.use(async (context, next) => { + sequence.push("mw1 start"); + try { + await next(); + } catch (e) { + sequence.push("mw1 caught"); + throw e; // Re-throw to propagate + } + sequence.push("mw1 end"); + }); + + server.use(async (context, next) => { + sequence.push("mw2 start"); + throw new Error("mw2 failure"); + }); + + server.tool("test-tool", {}, async () => ({ content: [] })); + + const response = await simulateCallTool("test-tool"); + + expect((response as any).error).toBeDefined(); + expect((response as any).error.message).toContain("mw2 failure"); + // Verify mw1 caught the error + expect(sequence).toContain("mw1 caught"); + // mw1 end should NOT be in sequence since error was re-thrown + expect(sequence).not.toContain("mw1 end"); + }); + + it("Error contains proper JSON-RPC error code", async () => { + server.use(async (context, next) => { + throw new Error("Generic middleware error"); + }); + + server.tool("test-tool", {}, async () => ({ content: [] })); + + const response = await simulateCallTool("test-tool"); + + expect((response as any).error).toBeDefined(); + // JSON-RPC internal error code is -32603 + expect((response as any).error.code).toBeDefined(); + expect(typeof (response as any).error.code).toBe("number"); + }); + }); +});