diff --git a/docs/concepts/elicitation/elicitation.md b/docs/concepts/elicitation/elicitation.md index 94597fa5f..78782bfbb 100644 --- a/docs/concepts/elicitation/elicitation.md +++ b/docs/concepts/elicitation/elicitation.md @@ -170,6 +170,61 @@ Here's an example implementation of how a console application might handle elici [!code-csharp[](samples/client/Program.cs?name=snippet_ElicitationHandler)] +### Multi Round-Trip Requests (MRTR) + +[MRTR](xref:mrtr) is the SEP-2322 mechanism for server-driven input requests, finalized in protocol revision `DRAFT-2026-v1`. Under the draft protocol, the server-to-client `elicitation/create` request method is removed; the recommended way to ask the user for input from a server handler is to throw and let the SDK emit an on the wire. + +> [!IMPORTANT] +> `ElicitAsync` throws `InvalidOperationException("Elicitation is not supported in stateless mode.")` whenever the server is running stateless — which includes every Streamable HTTP server under `DRAFT-2026-v1` once that revision is forced to stateless-only in a future PR. Stdio servers and current-protocol stateful Streamable HTTP servers continue to work via the legacy server-to-client `elicitation/create` request flow. For code that needs to run on stateless servers — including all `DRAFT-2026-v1` Streamable HTTP servers going forward — throw `InputRequiredException` from your handler instead. It works under both protocols and both session modes. + +For example: + +```csharp +[McpServerTool, Description("Tool that elicits via MRTR")] +public static string ElicitWithMrtr( + McpServer server, + RequestContext context) +{ + // On retry, process the client's elicitation response + if (context.Params!.InputResponses?.TryGetValue("user_input", out var response) is true) + { + var elicitResult = response.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + return elicitResult?.Action == "accept" + ? $"User accepted: {elicitResult.Content?.FirstOrDefault().Value}" + : "User declined."; + } + + if (!server.IsMrtrSupported) + { + return "This tool requires MRTR support (DRAFT-2026-v1, or a stateful current-protocol session)."; + } + + // First call — request user input + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["user_input"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Please confirm the action", + RequestedSchema = new() + { + Properties = new Dictionary + { + ["confirm"] = new ElicitRequestParams.BooleanSchema + { + Description = "Confirm the action" + } + } + } + }) + }, + requestState: "awaiting-confirmation"); +} +``` + +> [!TIP] +> See [Multi Round-Trip Requests (MRTR)](xref:mrtr) for the full protocol details, including multiple round trips, concurrent input requests, and the compatibility matrix. + ### URL Elicitation Required Error When a tool cannot proceed without first completing a URL-mode elicitation (for example, when third-party OAuth authorization is needed), and calling `ElicitAsync` is not practical (for example in [stateless](xref:stateless) mode where server-to-client requests are disabled), the server may throw a . This is a specialized error (JSON-RPC error code `-32042`) that signals to the client that one or more URL-mode elicitations must be completed before the original request can be retried. diff --git a/docs/concepts/index.md b/docs/concepts/index.md index 6393d9997..9e5a90f25 100644 --- a/docs/concepts/index.md +++ b/docs/concepts/index.md @@ -18,6 +18,7 @@ Install the SDK and build your first MCP client and server. | [Progress tracking](progress/progress.md) | Learn how to track progress for long-running operations through notification messages. | | [Cancellation](cancellation/cancellation.md) | Learn how to cancel in-flight MCP requests using cancellation tokens and notifications. | | [Tasks](tasks/tasks.md) | Learn how to use task-based execution for long-running operations that can be polled for status and results. | +| [Multi Round-Trip Requests (MRTR)](mrtr/mrtr.md) | Learn how servers request client input during tool execution using input-required results and retries. | ### Client Features diff --git a/docs/concepts/mrtr/mrtr.md b/docs/concepts/mrtr/mrtr.md new file mode 100644 index 000000000..1d1ebce32 --- /dev/null +++ b/docs/concepts/mrtr/mrtr.md @@ -0,0 +1,291 @@ +--- +title: Multi Round-Trip Requests (MRTR) +author: halter73 +description: How servers request client input during tool execution using Multi Round-Trip Requests. +uid: mrtr +--- + +# Multi Round-Trip Requests (MRTR) + + +> [!WARNING] +> MRTR is part of the **`DRAFT-2026-v1`** revision of the MCP specification ([SEP-2322](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2322)). The wire format and API surface may change before the revision is ratified. See the [Experimental APIs](../../experimental.md) documentation for details on working with experimental APIs. + +Multi Round-Trip Requests (MRTR) let a server tool request input from the client — such as [elicitation](xref:elicitation), [sampling](xref:sampling), or [roots](xref:roots) — as part of a single tool call, without requiring a separate server-to-client JSON-RPC request for each interaction. Instead of returning a final result, the server returns an **incomplete result** containing one or more input requests. The client fulfills those requests and retries the original tool call with the responses attached. + +## Overview + +MRTR is useful when: + +- A tool needs user confirmation before proceeding (elicitation). +- A tool needs LLM reasoning from the client (sampling). +- A tool needs an updated list of client roots. +- A tool needs to perform multiple rounds of interaction in a single logical operation. +- A stateless server needs to orchestrate multi-step flows without keeping handler state in memory between rounds. + +## How MRTR works + +1. The client calls a tool on the server via `tools/call`. +2. The server tool determines it needs client input and returns an `InputRequiredResult` containing `inputRequests` and/or `requestState`. +3. The client resolves each input request (for example by prompting the user for elicitation, calling an LLM for sampling, or listing its roots). +4. The client retries the original `tools/call` with `inputResponses` (keyed to the input requests) and `requestState` echoed back. +5. The server processes the responses and either returns a final result or another `InputRequiredResult` for additional rounds. + +## Opting in + +MRTR activates when both peers negotiate protocol revision **`DRAFT-2026-v1`** during `initialize`. The C# SDK opts in by listing `DRAFT-2026-v1` as a supported protocol version on the client; servers automatically accept it when offered. No experimental flags are required. + +```csharp +// Client +var clientOptions = new McpClientOptions +{ + ProtocolVersion = "DRAFT-2026-v1", + Handlers = new McpClientHandlers + { + ElicitationHandler = HandleElicitationAsync, + SamplingHandler = HandleSamplingAsync, + } +}; +``` + +Under `DRAFT-2026-v1`, MRTR is the recommended way to obtain client input from a server handler. The spec removes the legacy server-to-client `elicitation/create`, `sampling/createMessage`, and `roots/list` request methods, so any code that needs to work on a `DRAFT-2026-v1` Streamable HTTP server (which will be stateless-only in a future revision) must use `InputRequiredException` rather than , , or . The legacy methods still work on stateful sessions — that's how stdio servers keep working under draft today — but they throw `InvalidOperationException("X is not supported in stateless mode.")` on any stateless session, current or draft. + +Under the current protocol revision (`2025-06-18` and earlier), `InputRequiredException` is still supported in stateful sessions via a backward-compatibility resolver — see [Compatibility](#compatibility) below. + +## Authoring an MRTR tool + +A tool participates in MRTR by throwing with an describing what it needs. On retry, the client's responses arrive on the request parameters and the tool inspects them to decide what to do next. + +### Checking MRTR support + +Tools should check before throwing `InputRequiredException`. It returns `true` when either: + +- The negotiated protocol revision is `DRAFT-2026-v1` (MRTR is native), or +- The session is stateful under the current protocol (the SDK can resolve input requests via legacy JSON-RPC and retry the handler). + +```csharp +[McpServerTool, Description("A tool that uses MRTR")] +public static string MyTool( + McpServer server, + RequestContext context) +{ + if (!server.IsMrtrSupported) + { + return "This tool requires a client that negotiates DRAFT-2026-v1, " + + "or a stateful current-protocol session."; + } + + // ... MRTR logic +} +``` + +### Returning an incomplete result + +Throw to return an incomplete result. The exception carries an containing `inputRequests` and/or `requestState`: + +```csharp +[McpServerTool, Description("Tool managing its own MRTR flow")] +public static string AnswerTool( + McpServer server, + RequestContext context, + [Description("The user's question")] string question) +{ + var requestState = context.Params!.RequestState; + var inputResponses = context.Params!.InputResponses; + + // On retry, process the client's responses + if (requestState is not null && inputResponses is not null) + { + var elicitResult = inputResponses["user_answer"].Deserialize(InputResponse.ElicitResultJsonTypeInfo); + return $"You answered: {elicitResult?.Content?.FirstOrDefault().Value}"; + } + + if (!server.IsMrtrSupported) + { + return "MRTR is not supported by this client."; + } + + // First call — request user input + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["user_answer"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = $"Please answer: {question}", + RequestedSchema = new() + { + Properties = new Dictionary + { + ["answer"] = new ElicitRequestParams.StringSchema + { + Description = "Your answer" + } + } + } + }) + }, + requestState: "awaiting-answer"); +} +``` + +### Accessing retry data + +When the client retries a tool call, the retry data is available on the request parameters: + +- — a dictionary of client responses keyed by the same keys used in `inputRequests`. +- — the opaque state string echoed back by the client. + +Use with the `JsonTypeInfo` matching the response type. The expected type follows from the matching in the original `inputRequests` map — there is no on-the-wire discriminator. + +- Elicitation — `response.Deserialize(InputResponse.ElicitResultJsonTypeInfo)` +- Sampling — `response.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)` +- Roots list — `response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)` + +### Load shedding with requestState-only responses + +A server can return a `requestState`-only incomplete result (without any `inputRequests`) to defer processing. This is useful for load shedding or breaking up long-running work across multiple requests: + +```csharp +[McpServerTool, Description("Tool that defers work using requestState")] +public static string DeferredTool( + McpServer server, + RequestContext context) +{ + var requestState = context.Params!.RequestState; + + if (requestState is not null) + { + // Resume deferred work + var state = JsonSerializer.Deserialize( + Convert.FromBase64String(requestState)); + return $"Completed step {state!.Step}"; + } + + if (!server.IsMrtrSupported) + { + return "MRTR is not supported by this client."; + } + + // Defer work to a later retry + var initialState = new MyState { Step = 1 }; + throw new InputRequiredException( + requestState: Convert.ToBase64String( + JsonSerializer.SerializeToUtf8Bytes(initialState))); +} +``` + +The client automatically retries `requestState`-only incomplete results, echoing the state back without needing to resolve any input requests. + +### Multiple round trips + +A tool can perform multiple rounds of interaction by throwing `InputRequiredException` multiple times across retries. Use `requestState` to track which round you're on: + +```csharp +[McpServerTool, Description("Multi-step wizard")] +public static string WizardTool( + McpServer server, + RequestContext context) +{ + var requestState = context.Params!.RequestState; + var inputResponses = context.Params!.InputResponses; + + if (requestState == "step-2" && inputResponses is not null) + { + var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + var age = inputResponses["age"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + return $"Welcome, {name}! You are {age} years old."; + } + + if (requestState == "step-1" && inputResponses is not null) + { + var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + + // Second round — ask for age + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["age"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = $"Hi {name}! How old are you?", + RequestedSchema = new() + { + Properties = new Dictionary + { + ["age"] = new ElicitRequestParams.NumberSchema + { + Description = "Your age" + } + } + } + }) + }, + requestState: "step-2"); + } + + if (!server.IsMrtrSupported) + { + return "MRTR is not supported. Please use a compatible client."; + } + + // First round — ask for name + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What's your name?", + RequestedSchema = new() + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema + { + Description = "Your name" + } + } + } + }) + }, + requestState: "step-1"); +} +``` + +### Providing custom error messages + +When MRTR is not supported, you can provide domain-specific guidance: + +```csharp +if (!server.IsMrtrSupported) +{ + return "This tool requires interactive input. To use it:\n" + + "1. Connect with a client that negotiates MCP protocol revision DRAFT-2026-v1, or\n" + + "2. Use a stateful current-protocol session so the server can resolve the input requests for you.\n" + + "\nStateless current-protocol sessions cannot resolve MRTR input requests."; +} +``` + +## Compatibility + +The SDK supports `InputRequiredException` across two protocol revisions and two session modes: + +| Negotiated protocol | Session mode | Behavior | +|---|---|---| +| `DRAFT-2026-v1` | Stateful | Native MRTR — `InputRequiredResult` is serialized directly to the wire. | +| `DRAFT-2026-v1` | Stateless | Native MRTR — `InputRequiredResult` is serialized directly to the wire. No server-side handler state needed. | +| Current (`2025-06-18` and earlier) | Stateful | Backward-compatibility resolver — the SDK sends standard `elicitation/create` / `sampling/createMessage` / `roots/list` JSON-RPC requests to the client, collects the responses, and retries the handler with `inputResponses` populated. Up to 10 retry rounds. | +| Current (`2025-06-18` and earlier) | Stateless | **Not supported** — `InputRequiredException` raises an `McpException`. The client doesn't speak MRTR, and the server can't resolve input requests via JSON-RPC without a persistent session. | + +> [!NOTE] +> The backcompat resolver is intentionally limited to 10 retry rounds. Tools that need more rounds should require `DRAFT-2026-v1` (check `IsMrtrSupported`). + +### Why `ElicitAsync` / `SampleAsync` / `RequestRootsAsync` throw on stateless servers + +`ElicitAsync` / `SampleAsync` / `RequestRootsAsync` issue a JSON-RPC request to the client and wait for the response on the same session. Stateless servers don't have a persistent session to wait on, so the SDK fails fast with `InvalidOperationException("X is not supported in stateless mode.")` (the check is `McpServer.ClientCapabilities is null`, which is the SDK's proxy for stateless). + +Under the current protocol revision (`2025-06-18` and earlier), stdio and stateful Streamable HTTP keep `ClientCapabilities` populated, so the legacy methods work normally and remain the recommended way to do one-shot client interactions. Under `DRAFT-2026-v1`, the spec removes those request methods from Streamable HTTP entirely; the SDK still allows the legacy methods on draft stdio sessions because stdio is implicitly single-process / stateful and the client handler is wired up regardless of negotiated revision. `InputRequiredException` is the way to write tools that work on every supported configuration. + +### Future direction + +The `DRAFT-2026-v1` revision is moving toward a stateless-only model: `Mcp-Session-Id` is being removed, and Streamable HTTP servers will run statelessly by default under the draft revision. When that lands, the `Stateful` row for `DRAFT-2026-v1` in the compatibility matrix above collapses into the `Stateless` row (Streamable HTTP under draft becomes stateless-only), and `InputRequiredException` becomes uniformly required for non-stdio servers. The current-protocol resolver path will remain for backward compatibility with older clients and stateful servers. + +This work is a follow-up to the present PR. diff --git a/docs/concepts/roots/roots.md b/docs/concepts/roots/roots.md index 7c09e53ad..213d317c0 100644 --- a/docs/concepts/roots/roots.md +++ b/docs/concepts/roots/roots.md @@ -103,3 +103,43 @@ server.RegisterNotificationHandler( Console.WriteLine($"Roots updated. {result.Roots.Count} roots available."); }); ``` + +### Multi Round-Trip Requests (MRTR) + +[MRTR](xref:mrtr) is the SEP-2322 mechanism for server-driven input requests, finalized in protocol revision `DRAFT-2026-v1`. Under the draft protocol, the server-to-client `roots/list` request method is removed; the recommended way to ask the client for its roots from a server handler is to throw and let the SDK emit an on the wire. + +> [!IMPORTANT] +> `RequestRootsAsync` throws `InvalidOperationException("Roots are not supported in stateless mode.")` whenever the server is running stateless — which includes every Streamable HTTP server under `DRAFT-2026-v1` once that revision is forced to stateless-only in a future PR. Stdio servers and current-protocol stateful Streamable HTTP servers continue to work via the legacy server-to-client `roots/list` request flow. For code that needs to run on stateless servers — including all `DRAFT-2026-v1` Streamable HTTP servers going forward — throw `InputRequiredException` from your handler instead. It works under both protocols and both session modes. + +For example: + +```csharp +[McpServerTool, Description("Tool that requests roots via MRTR")] +public static string ListRootsWithMrtr( + McpServer server, + RequestContext context) +{ + // On retry, process the client's roots response + if (context.Params!.InputResponses?.TryGetValue("get_roots", out var response) is true) + { + var roots = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots ?? []; + return $"Found {roots.Count} roots: {string.Join(", ", roots.Select(r => r.Uri))}"; + } + + if (!server.IsMrtrSupported) + { + return "This tool requires MRTR support (DRAFT-2026-v1, or a stateful current-protocol session)."; + } + + // First call — request the client's root list + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["get_roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "awaiting-roots"); +} +``` + +> [!TIP] +> See [Multi Round-Trip Requests (MRTR)](xref:mrtr) for the full protocol details, including load shedding, multiple round trips, and the compatibility matrix. diff --git a/docs/concepts/sampling/sampling.md b/docs/concepts/sampling/sampling.md index 4f14a4ee0..bac6ed5ab 100644 --- a/docs/concepts/sampling/sampling.md +++ b/docs/concepts/sampling/sampling.md @@ -120,3 +120,55 @@ McpClientOptions options = new() ### Capability negotiation Sampling requires the client to advertise the `sampling` capability. This is handled automatically — when a is set, the client includes the sampling capability during initialization. The server can check whether the client supports sampling before calling ; if sampling is not supported, the method throws . + +### Multi Round-Trip Requests (MRTR) + +[MRTR](xref:mrtr) is the SEP-2322 mechanism for server-driven input requests, finalized in protocol revision `DRAFT-2026-v1`. Under the draft protocol, the server-to-client `sampling/createMessage` request method is removed; the recommended way to ask the client to sample from a server handler is to throw and let the SDK emit an on the wire. + +> [!IMPORTANT] +> `SampleAsync` and `AsSamplingChatClient` throw `InvalidOperationException("Sampling is not supported in stateless mode.")` whenever the server is running stateless — which includes every Streamable HTTP server under `DRAFT-2026-v1` once that revision is forced to stateless-only in a future PR. Stdio servers and current-protocol stateful Streamable HTTP servers continue to work via the legacy server-to-client `sampling/createMessage` request flow. For code that needs to run on stateless servers — including all `DRAFT-2026-v1` Streamable HTTP servers going forward — throw `InputRequiredException` from your handler instead. It works under both protocols and both session modes. + +For example: + +```csharp +[McpServerTool, Description("Tool that samples via MRTR")] +public static string SampleWithMrtr( + McpServer server, + RequestContext context) +{ + // On retry, process the client's sampling response + if (context.Params!.InputResponses?.TryGetValue("llm_call", out var response) is true) + { + var text = response.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)?.Content + .OfType().FirstOrDefault()?.Text; + return $"LLM said: {text}"; + } + + if (!server.IsMrtrSupported) + { + return "This tool requires MRTR support (DRAFT-2026-v1, or a stateful current-protocol session)."; + } + + // First call — request LLM completion from the client + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["llm_call"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "Summarize the data" }] + } + ], + MaxTokens = 256 + }) + }, + requestState: "awaiting-sample"); +} +``` + +> [!TIP] +> See [Multi Round-Trip Requests (MRTR)](xref:mrtr) for the full protocol details, including load shedding, multiple round trips, and the compatibility matrix. diff --git a/docs/concepts/toc.yml b/docs/concepts/toc.yml index bd5474338..e0708cb75 100644 --- a/docs/concepts/toc.yml +++ b/docs/concepts/toc.yml @@ -19,6 +19,8 @@ items: uid: cancellation - name: Tasks uid: tasks + - name: Multi Round-Trip Requests (MRTR) + uid: mrtr - name: Client Features items: - name: Sampling diff --git a/src/Common/Experimentals.cs b/src/Common/Experimentals.cs index 7e7e969bb..e356480ed 100644 --- a/src/Common/Experimentals.cs +++ b/src/Common/Experimentals.cs @@ -110,4 +110,23 @@ internal static class Experimentals /// URL for the experimental RunSessionHandler API. /// public const string RunSessionHandler_Url = "https://github.com/modelcontextprotocol/csharp-sdk/blob/main/docs/list-of-diagnostics.md#mcpexp002"; + + /// + /// Diagnostic ID for the experimental Multi Round-Trip Requests (MRTR) feature. + /// + /// + /// This uses the same diagnostic ID as because MRTR + /// is an experimental feature in the MCP specification (SEP-2322). + /// + public const string Mrtr_DiagnosticId = "MCPEXP001"; + + /// + /// Message for the experimental MRTR feature. + /// + public const string Mrtr_Message = "The Multi Round-Trip Requests (MRTR) feature is experimental per the MCP specification (SEP-2322) and is subject to change."; + + /// + /// URL for the experimental MRTR feature. + /// + public const string Mrtr_Url = "https://github.com/modelcontextprotocol/csharp-sdk/blob/main/docs/list-of-diagnostics.md#mcpexp001"; } diff --git a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj index 762091667..44bac1bc9 100644 --- a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj +++ b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj @@ -10,6 +10,7 @@ ASP.NET Core extensions for the C# Model Context Protocol (MCP) SDK. README.md true + $(NoWarn);MCPEXP001 diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 9763a6cf5..a489cd9e6 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -508,12 +508,25 @@ internal static string MakeNewSessionId() // Implementation for reading a JSON-RPC message from the request body var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted); - if (context.User?.Identity?.IsAuthenticated == true && message is not null) + if (message is not null) { - message.Context = new() + var protocolVersion = context.Request.Headers[McpProtocolVersionHeaderName].ToString(); + var isAuthenticated = context.User?.Identity?.IsAuthenticated == true; + + if (isAuthenticated || !string.IsNullOrEmpty(protocolVersion)) { - User = context.User, - }; + message.Context ??= new(); + + if (isAuthenticated) + { + message.Context.User = context.User; + } + + if (!string.IsNullOrEmpty(protocolVersion)) + { + message.Context.ProtocolVersion = protocolVersion; + } + } } return message; @@ -812,7 +825,7 @@ private static bool ValuesMatch(string? actual, string? expected, System.Text.Js // JSON Schema defines two numeric types: "number" (any numeric value including // decimals like 3.14) and "integer" (whole numbers only like 42). Both produce // JsonValueKind.Number in the JSON body and are sent as numeric strings in headers. - // We check for both because different SDKs may serialize them differently — + // We check for both because different SDKs may serialize them differently - // e.g., a client might send header "42.0" for an "integer" body value of 42, // or header "42" for a "number" body value of 42.0. Without handling both types, // valid cross-SDK requests would be incorrectly rejected. diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs index 66410b272..e6ab3aae4 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Protocol; using System.Collections.Concurrent; using System.Text.Json; +using System.Text.Json.Nodes; namespace ModelContextProtocol.Client; @@ -142,6 +143,8 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not RequestMethods.SamplingCreateMessage, async (request, jsonRpcRequest, cancellationToken) => { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.SamplingCreateMessage); + // Check if this is a task-augmented request if (request?.Task is { } taskMetadata) { @@ -176,10 +179,14 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not { requestHandlers.Set( RequestMethods.SamplingCreateMessage, - (request, _, cancellationToken) => samplingHandler( - request, - request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), + (request, _, cancellationToken) => + { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.SamplingCreateMessage); + return samplingHandler( + request, + request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken); + }, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, McpJsonUtilities.JsonContext.Default.CreateMessageResult); } @@ -192,7 +199,11 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not { requestHandlers.Set( RequestMethods.RootsList, - (request, _, cancellationToken) => rootsHandler(request, cancellationToken), + (request, _, cancellationToken) => + { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.RootsList); + return rootsHandler(request, cancellationToken); + }, McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, McpJsonUtilities.JsonContext.Default.ListRootsResult); @@ -209,6 +220,8 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not RequestMethods.ElicitationCreate, async (request, jsonRpcRequest, cancellationToken) => { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.ElicitationCreate); + // Check if this is a task-augmented request if (request?.Task is { } taskMetadata) { @@ -241,6 +254,7 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not RequestMethods.ElicitationCreate, async (request, _, cancellationToken) => { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.ElicitationCreate); var result = await elicitationHandler(request, cancellationToken).ConfigureAwait(false); return ElicitResult.WithDefaults(request, result); }, @@ -547,6 +561,98 @@ private void RegisterTaskHandlers(RequestHandlers requestHandlers, IMcpTaskStore /// public override Task Completion => _sessionHandler.CompletionTask; + /// + private async ValueTask> ResolveInputRequestsAsync( + IDictionary inputRequests, + CancellationToken cancellationToken) + { + // Resolve all input requests concurrently. If any fails, cancel the rest so user-facing + // handlers (sampling/elicitation prompts) don't keep running for a request whose caller + // has already given up, and ensure exceptions from late-completing tasks are observed. + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + var keyed = new (string Key, Task Task)[inputRequests.Count]; + int i = 0; + foreach (var kvp in inputRequests) + { + keyed[i++] = (kvp.Key, ResolveInputRequestAsync(kvp.Value, linkedCts.Token)); + } + + try + { + await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false); + } + catch + { + linkedCts.Cancel(); + try + { + await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false); + } + catch + { + // Observed; the original exception is the one we want to surface. + } + throw; + } + + var responses = new Dictionary(keyed.Length); + foreach (var (key, task) in keyed) + { + responses[key] = task.Result; + } + return responses; + } + + private async Task ResolveInputRequestAsync(InputRequest inputRequest, CancellationToken cancellationToken) + { + switch (inputRequest.Method) + { + case RequestMethods.SamplingCreateMessage: + if (_options.Handlers.SamplingHandler is { } samplingHandler) + { + var samplingParams = inputRequest.SamplingParams + ?? throw new McpException($"Failed to deserialize sampling parameters from MRTR input request."); + var result = await samplingHandler( + samplingParams, + samplingParams.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken).ConfigureAwait(false); + return InputResponse.FromSamplingResult(result); + } + + throw new InvalidOperationException( + $"Server sent a sampling input request, but no {nameof(McpClientHandlers.SamplingHandler)} is registered."); + + case RequestMethods.ElicitationCreate: + if (_options.Handlers.ElicitationHandler is { } elicitationHandler) + { + var elicitParams = inputRequest.ElicitationParams + ?? throw new McpException($"Failed to deserialize elicitation parameters from MRTR input request."); + var result = await elicitationHandler(elicitParams, cancellationToken).ConfigureAwait(false); + result = ElicitResult.WithDefaults(elicitParams, result); + return InputResponse.FromElicitResult(result); + } + + throw new InvalidOperationException( + $"Server sent an elicitation input request, but no {nameof(McpClientHandlers.ElicitationHandler)} is registered."); + + case RequestMethods.RootsList: + if (_options.Handlers.RootsHandler is { } rootsHandler) + { + // ListRootsRequest params are optional per the spec, so fall back to an empty params instance. + var rootsParams = inputRequest.RootsParams ?? new ListRootsRequestParams(); + var result = await rootsHandler(rootsParams, cancellationToken).ConfigureAwait(false); + return InputResponse.FromRootsResult(result); + } + + throw new InvalidOperationException( + $"Server sent a roots list input request, but no {nameof(McpClientHandlers.RootsHandler)} is registered."); + + default: + throw new NotSupportedException($"Unsupported input request method: '{inputRequest.Method}'."); + } + } + /// /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. /// @@ -718,13 +824,13 @@ public override void ClearKnownTools() } /// - public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + public override async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { // For tools/call requests, attach the cached tool definition to the message context // so the transport can add custom Mcp-Param-* headers based on x-mcp-header schema annotations. if (request.Method == RequestMethods.ToolsCall && - request.Params is System.Text.Json.Nodes.JsonObject paramsObj && - paramsObj.TryGetPropertyValue("name", out var nameNode) && + request.Params is System.Text.Json.Nodes.JsonObject paramsObjForHeaders && + paramsObjForHeaders.TryGetPropertyValue("name", out var nameNode) && nameNode?.GetValue() is { } toolName) { if (_toolCache.TryGetValue(toolName, out var tool)) @@ -739,7 +845,67 @@ request.Params is System.Text.Json.Nodes.JsonObject paramsObj && } } - return _sessionHandler.SendRequestAsync(request, cancellationToken); + const int maxRetries = 10; + + for (int attempt = 0; attempt <= maxRetries; attempt++) + { + JsonRpcResponse response = await _sessionHandler.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); + + // Check if the result is an InputRequiredResult by looking at result_type. + if (response.Result is JsonObject resultObj && + resultObj.TryGetPropertyValue("resultType", out var resultTypeNode) && + resultTypeNode?.GetValue() is "input_required") + { + WarnIfInputRequiredResultOnNonMrtrSession(request.Method); + + var inputRequiredResult = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.JsonContext.Default.InputRequiredResult) + ?? throw new JsonException("Failed to deserialize InputRequiredResult."); + + if (inputRequiredResult.InputRequests is { Count: > 0 } inputRequests) + { + IDictionary inputResponses = + await ResolveInputRequestsAsync(inputRequests, cancellationToken).ConfigureAwait(false); + + // Clone the original request params and add inputResponses + requestState for the retry. + var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject(); + + paramsObj["inputResponses"] = JsonSerializer.SerializeToNode( + inputResponses, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse); + + if (inputRequiredResult.RequestState is { } requestState) + { + paramsObj["requestState"] = requestState; + } + else + { + // Strip any stale requestState carried over from the previous round's clone so + // the server doesn't see a continuation token the current round is not using. + paramsObj.Remove("requestState"); + } + + request = new JsonRpcRequest { Method = request.Method, Params = paramsObj, Context = request.Context }; + } + else if (inputRequiredResult.RequestState is not null) + { + // No input requests but has requestState (e.g., load shedding) - just retry with state. + var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject(); + paramsObj["requestState"] = inputRequiredResult.RequestState; + paramsObj.Remove("inputResponses"); + + request = new JsonRpcRequest { Method = request.Method, Params = paramsObj, Context = request.Context }; + } + else + { + throw new McpException("Server returned an InputRequiredResult without inputRequests or requestState."); + } + + continue; // retry with the updated request + } + + return response; + } + + throw new McpException($"Server returned InputRequiredResult more than {maxRetries} times."); } /// @@ -775,6 +941,30 @@ public override async ValueTask DisposeAsync() await Completion.ConfigureAwait(false); } + /// Logs a warning if the session negotiated MRTR but the server sent a legacy JSON-RPC request. + private void WarnIfLegacyRequestOnMrtrSession(string method) + { + if (_negotiatedProtocolVersion == McpSessionHandler.DraftProtocolVersion) + { + LogLegacyRequestOnMrtrSession(_endpointName, method); + } + } + + /// Logs a warning if the session did not negotiate MRTR but the server sent an InputRequiredResult. + private void WarnIfInputRequiredResultOnNonMrtrSession(string method) + { + if (_negotiatedProtocolVersion != McpSessionHandler.DraftProtocolVersion) + { + LogInputRequiredResultOnNonMrtrSession(_endpointName, method, _negotiatedProtocolVersion); + } + } + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received legacy '{Method}' JSON-RPC request on session that negotiated MRTR. The server should use InputRequiredResult instead of sending direct requests.")] + private partial void LogLegacyRequestOnMrtrSession(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received InputRequiredResult for '{Method}' on session that did not negotiate MRTR (protocol version '{ProtocolVersion}'). The server may not be spec-compliant.")] + private partial void LogInputRequiredResultOnNonMrtrSession(string endpointName, string method, string? protocolVersion); + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); @@ -798,5 +988,4 @@ public override async ValueTask DisposeAsync() [LoggerMessage(Level = LogLevel.Warning, Message = "Tool '{ToolName}' excluded from tools/list: {Reason}")] private partial void LogToolRejected(string toolName, string reason); - } diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index abb6d29df..b4613d9f2 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; using System.Diagnostics.CodeAnalysis; @@ -144,6 +144,13 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(SubscribeRequestParams))] [JsonSerializable(typeof(UnsubscribeRequestParams))] + // MCP MRTR (Multi Round-Trip Requests) + [JsonSerializable(typeof(InputRequiredResult))] + [JsonSerializable(typeof(InputRequest))] + [JsonSerializable(typeof(InputResponse))] + [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(IDictionary))] + // MCP Task Request Params / Results [JsonSerializable(typeof(McpTask))] [JsonSerializable(typeof(McpTaskStatus))] diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index 4201f9833..73d99da71 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -68,7 +68,7 @@ public abstract partial class McpSession : IAsyncDisposable /// /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous send operation. - /// The transport is not connected. + /// The transport is not connected, or is a . Use for requests. /// is . /// /// diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index 77c18b8be..e874e6724 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -31,6 +31,13 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable /// The latest version of the protocol supported by this implementation. internal const string LatestProtocolVersion = "2025-11-25"; + /// + /// The draft protocol version that enables MRTR (Multi Round-Trip Requests) per SEP-2322. + /// Clients and servers opt in by setting + /// or to this value. + /// + internal const string DraftProtocolVersion = "DRAFT-2026-v1"; + /// /// All protocol versions supported by this implementation. /// Keep in sync with s_supportedProtocolVersions in StreamableHttpHandler. @@ -41,7 +48,7 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable "2025-03-26", "2025-06-18", LatestProtocolVersion, - "DRAFT-2026-v1", + DraftProtocolVersion, ]; /// @@ -642,6 +649,13 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can { Throw.IfNull(message); + if (message is JsonRpcRequest request) + { + throw new InvalidOperationException( + $"Cannot send '{request.Method}' request via {nameof(SendMessageAsync)}. " + + $"Use {nameof(SendRequestAsync)} instead to get a correlated response."); + } + cancellationToken.ThrowIfCancellationRequested(); Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; diff --git a/src/ModelContextProtocol.Core/Protocol/InputRequest.cs b/src/ModelContextProtocol.Core/Protocol/InputRequest.cs new file mode 100644 index 000000000..bd9161423 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/InputRequest.cs @@ -0,0 +1,197 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents a server-initiated request that the client must fulfill as part of an MRTR +/// (Multi Round-Trip Request) flow. +/// +/// +/// +/// An wraps a server-to-client request such as +/// , , +/// or . It is included in an +/// when the server needs additional input before it can complete a client-initiated request. +/// +/// +/// The property identifies the type of request, and the corresponding +/// parameters can be accessed via the typed accessor properties. +/// +/// +[Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] +[JsonConverter(typeof(Converter))] +public sealed class InputRequest +{ + /// + /// Gets or sets the method name identifying the type of this input request. + /// + /// + /// Standard values include: + /// + /// A sampling request. + /// An elicitation request. + /// A roots list request. + /// + /// + [JsonPropertyName("method")] + public required string Method { get; set; } + + /// + /// Gets or sets the raw JSON parameters for this input request. + /// + /// + /// Use the typed accessor properties (, , + /// ) for convenient strongly-typed access. + /// + [JsonPropertyName("params")] + public JsonElement? Params { get; set; } + + /// + /// Gets the parameters as when + /// is . + /// + /// The deserialized sampling parameters, or if the method does not match or params are absent. + [JsonIgnore] + public CreateMessageRequestParams? SamplingParams => + string.Equals(Method, RequestMethods.SamplingCreateMessage, StringComparison.Ordinal) && Params is { } p + ? JsonSerializer.Deserialize(p, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams) + : null; + + /// + /// Gets the parameters as when + /// is . + /// + /// The deserialized elicitation parameters, or if the method does not match or params are absent. + [JsonIgnore] + public ElicitRequestParams? ElicitationParams => + string.Equals(Method, RequestMethods.ElicitationCreate, StringComparison.Ordinal) && Params is { } p + ? JsonSerializer.Deserialize(p, McpJsonUtilities.JsonContext.Default.ElicitRequestParams) + : null; + + /// + /// Gets the parameters as when + /// is . + /// + /// The deserialized roots list parameters, or if the method does not match or params are absent. + [JsonIgnore] + public ListRootsRequestParams? RootsParams => + string.Equals(Method, RequestMethods.RootsList, StringComparison.Ordinal) && Params is { } p + ? JsonSerializer.Deserialize(p, McpJsonUtilities.JsonContext.Default.ListRootsRequestParams) + : null; + + /// + /// Creates an for a sampling request. + /// + /// The sampling request parameters. + /// A new instance. + public static InputRequest ForSampling(CreateMessageRequestParams requestParams) + { + Throw.IfNull(requestParams); + return new() + { + Method = RequestMethods.SamplingCreateMessage, + Params = JsonSerializer.SerializeToElement(requestParams, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams), + }; + } + + /// + /// Creates an for an elicitation request. + /// + /// The elicitation request parameters. + /// A new instance. + public static InputRequest ForElicitation(ElicitRequestParams requestParams) + { + Throw.IfNull(requestParams); + return new() + { + Method = RequestMethods.ElicitationCreate, + Params = JsonSerializer.SerializeToElement(requestParams, McpJsonUtilities.JsonContext.Default.ElicitRequestParams), + }; + } + + /// + /// Creates an for a roots list request. + /// + /// The roots list request parameters. + /// A new instance. + public static InputRequest ForRootsList(ListRootsRequestParams requestParams) + { + Throw.IfNull(requestParams); + return new() + { + Method = RequestMethods.RootsList, + Params = JsonSerializer.SerializeToElement(requestParams, McpJsonUtilities.JsonContext.Default.ListRootsRequestParams), + }; + } + + /// Provides JSON serialization support for . + public sealed class Converter : JsonConverter + { + /// + public override InputRequest? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected StartObject token."); + } + + string? method = null; + JsonElement? parameters = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected PropertyName token."); + } + + string propertyName = reader.GetString()!; + reader.Read(); + + switch (propertyName) + { + case "method": + method = reader.GetString(); + break; + case "params": + parameters = JsonElement.ParseValue(ref reader); + break; + default: + reader.Skip(); + break; + } + } + + if (method is null) + { + throw new JsonException("InputRequest must have a 'method' property."); + } + + return new InputRequest + { + Method = method, + Params = parameters, + }; + } + + /// + public override void Write(Utf8JsonWriter writer, InputRequest value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + writer.WriteString("method", value.Method); + if (value.Params is { } p) + { + writer.WritePropertyName("params"); + p.WriteTo(writer); + } + writer.WriteEndObject(); + } + } +} diff --git a/src/ModelContextProtocol.Core/Protocol/InputRequiredException.cs b/src/ModelContextProtocol.Core/Protocol/InputRequiredException.cs new file mode 100644 index 000000000..4f39b17a5 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/InputRequiredException.cs @@ -0,0 +1,109 @@ +using System.Diagnostics.CodeAnalysis; + +namespace ModelContextProtocol.Protocol; + +/// +/// The exception that is thrown by a server handler to return an +/// to the client, signaling that additional input is needed before the request can be completed. +/// +/// +/// +/// This exception is part of the Multi Round-Trip Requests (MRTR) API. Tool handlers +/// throw this exception to directly control the input-required result payload, including +/// and . +/// +/// +/// For stateless servers, this enables multi-round-trip flows without requiring the handler to stay +/// alive between round trips. The server encodes its state in +/// and receives it back on retry via . +/// +/// +/// To return a requestState-only response (e.g., for load shedding), omit +/// and set only . +/// The client will retry the request with the state echoed back. +/// +/// +/// This exception can only be used when MRTR is supported by the client. Check +/// before throwing. If thrown when MRTR is not +/// supported, the exception will propagate as a JSON-RPC internal error. +/// +/// +/// +/// +/// [McpServerTool, Description("A stateless tool using MRTR")] +/// public static string MyTool(McpServer server, RequestContext<CallToolRequestParams> context) +/// { +/// if (context.Params.RequestState is { } state) +/// { +/// // Retry: process accumulated state and input responses +/// var responses = context.Params.InputResponses; +/// return "Final result"; +/// } +/// +/// if (!server.IsMrtrSupported) +/// { +/// return "This tool requires MRTR support."; +/// } +/// +/// throw new InputRequiredException( +/// inputRequests: new Dictionary<string, InputRequest> +/// { +/// ["user_input"] = InputRequest.ForElicitation(new ElicitRequestParams { ... }) +/// }, +/// requestState: "encoded-state"); +/// } +/// +/// +[Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] +public class InputRequiredException : Exception +{ + /// + /// Initializes a new instance of the class + /// with the specified . + /// + /// The input-required result to return to the client. + public InputRequiredException(InputRequiredResult result) + : base("The server returned an input-required result requiring additional client input.") + { + Throw.IfNull(result); + Result = result; + } + + /// + /// Initializes a new instance of the class + /// with the specified input requests and/or request state. + /// + /// + /// Server-initiated requests that the client must fulfill before retrying. + /// Keys are server-assigned identifiers. + /// + /// + /// Opaque state to be echoed back by the client when retrying. The client must + /// treat this as an opaque blob and must not inspect or modify it. + /// + /// + /// Both and are . + /// At least one must be provided. + /// + public InputRequiredException( + IDictionary? inputRequests = null, + string? requestState = null) + : base("The server returned an input-required result requiring additional client input.") + { + if (inputRequests is null && requestState is null) + { + throw new ArgumentException("At least one of inputRequests or requestState must be provided."); + } + + Result = new InputRequiredResult + { + InputRequests = inputRequests, + RequestState = requestState, + }; + } + + /// + /// Gets the input-required result to return to the client. + /// + public InputRequiredResult Result { get; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/InputRequiredResult.cs b/src/ModelContextProtocol.Core/Protocol/InputRequiredResult.cs new file mode 100644 index 000000000..b02bfa9b2 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/InputRequiredResult.cs @@ -0,0 +1,65 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents an input-required result sent by the server to indicate that additional input is needed +/// before the request can be completed. +/// +/// +/// +/// An is returned in response to a client-initiated request when +/// the server needs the client to fulfill one or more server-initiated requests before it can produce +/// a final result. Per SEP-2322 the wire format is valid for , +/// , and resources/read; this SDK wires the MRTR +/// interceptor into all three methods. +/// +/// +/// At least one of or must be present. +/// +/// +/// This type is part of the Multi Round-Trip Requests (MRTR) mechanism defined in SEP-2322. +/// +/// +[Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] +public sealed class InputRequiredResult : Result +{ + /// + /// Initializes a new instance of the class. + /// + public InputRequiredResult() + { + ResultType = "input_required"; + } + + /// + /// Gets or sets the server-initiated requests that the client must fulfill before retrying the original request. + /// + /// + /// + /// The keys are server-assigned identifiers. The client must include a response for each key in the + /// map when retrying the original request. + /// + /// + [JsonPropertyName("inputRequests")] + public IDictionary? InputRequests { get; set; } + + /// + /// Gets or sets opaque state to be echoed back by the client when retrying the original request. + /// + /// + /// + /// The client must treat this as an opaque blob and must not inspect, parse, modify, or make + /// any assumptions about the contents. If present, the client must include this value in the + /// property when retrying the original request. + /// + /// + /// Servers may encode request state in any format (e.g., plain JSON, base64-encoded JSON, + /// encrypted JWT, serialized binary). If the state contains sensitive data, servers should + /// encrypt it to ensure confidentiality and integrity. + /// + /// + [JsonPropertyName("requestState")] + public string? RequestState { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/InputResponse.cs b/src/ModelContextProtocol.Core/Protocol/InputResponse.cs new file mode 100644 index 000000000..465ea3235 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/InputResponse.cs @@ -0,0 +1,128 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents a client's response to a server-initiated as part of an MRTR +/// (Multi Round-Trip Request) flow. +/// +/// +/// +/// An wraps the result of a server-to-client request such as +/// , , or . +/// The type of the inner response corresponds to the of the +/// associated input request. +/// +/// +/// The input response does not carry its own type discriminator in JSON. The type is determined by +/// the corresponding key in the map. +/// +/// +[Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] +[JsonConverter(typeof(Converter))] +public sealed class InputResponse +{ + /// + /// Gets or sets the raw JSON element representing the response. + /// + /// + /// Use with the JsonTypeInfo<T> matching the + /// associated - for elicitation, sampling, or roots see + /// , , and + /// . + /// + [JsonIgnore] + public JsonElement RawValue { get; set; } + + /// + /// Deserializes the raw value to the specified result type. + /// + /// The type to deserialize to (e.g., , ). + /// The JSON type information for . + /// The deserialized result, or if deserialization fails. + public T? Deserialize(System.Text.Json.Serialization.Metadata.JsonTypeInfo typeInfo) => + JsonSerializer.Deserialize(RawValue, typeInfo); + + /// + /// Gets the for , suitable for use with + /// when the corresponding is + /// . + /// + public static JsonTypeInfo ElicitResultJsonTypeInfo => McpJsonUtilities.JsonContext.Default.ElicitResult; + + /// + /// Gets the for , suitable for use with + /// when the corresponding is + /// . + /// + public static JsonTypeInfo CreateMessageResultJsonTypeInfo => McpJsonUtilities.JsonContext.Default.CreateMessageResult; + + /// + /// Gets the for , suitable for use with + /// when the corresponding is + /// . + /// + public static JsonTypeInfo ListRootsResultJsonTypeInfo => McpJsonUtilities.JsonContext.Default.ListRootsResult; + + /// + /// Creates an from a . + /// + /// The sampling result. + /// A new instance. + public static InputResponse FromSamplingResult(CreateMessageResult result) + { + Throw.IfNull(result); + return new() + { + RawValue = JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CreateMessageResult), + }; + } + + /// + /// Creates an from an . + /// + /// The elicitation result. + /// A new instance. + public static InputResponse FromElicitResult(ElicitResult result) + { + Throw.IfNull(result); + return new() + { + RawValue = JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.ElicitResult), + }; + } + + /// + /// Creates an from a . + /// + /// The roots list result. + /// A new instance. + public static InputResponse FromRootsResult(ListRootsResult result) + { + Throw.IfNull(result); + return new() + { + RawValue = JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.ListRootsResult), + }; + } + + /// Provides JSON serialization support for . + public sealed class Converter : JsonConverter + { + /// + public override InputResponse? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + return new InputResponse { RawValue = element }; + } + + /// + public override void Write(Utf8JsonWriter writer, InputResponse value, JsonSerializerOptions options) + { + value.RawValue.WriteTo(writer); + } + } +} diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs index 2fa9839f0..e5c0f3931 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -74,4 +74,15 @@ public sealed class JsonRpcMessageContext /// /// public IDictionary? Items { get; set; } + + /// + /// Gets or sets the protocol version from the transport-level header (e.g. Mcp-Protocol-Version) + /// that accompanied this JSON-RPC message. + /// + /// + /// In stateless Streamable HTTP mode, the protocol version cannot be negotiated via the initialize + /// handshake because each request creates a new server instance. This property allows the transport layer + /// to flow the protocol version header so the server can determine client capabilities. + /// + public string? ProtocolVersion { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/RequestParams.cs b/src/ModelContextProtocol.Core/Protocol/RequestParams.cs index 0a0586a71..004f1711f 100644 --- a/src/ModelContextProtocol.Core/Protocol/RequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/RequestParams.cs @@ -1,3 +1,4 @@ +using System.Diagnostics.CodeAnalysis; using System.Text.Json.Nodes; using System.Text.Json.Serialization; @@ -25,6 +26,52 @@ private protected RequestParams() [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + /// + /// Gets or sets the responses to server-initiated input requests from a previous . + /// + /// + /// + /// This property is populated when retrying a request after receiving an . + /// Each key corresponds to a key from the map, and + /// the value is the client's response to that input request. + /// + /// + [Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] + [JsonIgnore] + public IDictionary? InputResponses + { + get => InputResponsesCore; + set => InputResponsesCore = value; + } + + // See ExperimentalInternalPropertyTests.cs before modifying this property. + [JsonInclude] + [JsonPropertyName("inputResponses")] + internal IDictionary? InputResponsesCore { get; set; } + + /// + /// Gets or sets opaque request state echoed back from a previous . + /// + /// + /// + /// This property is populated when retrying a request after receiving an + /// that included a value. The client must echo back the + /// exact value without modification. + /// + /// + [Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] + [JsonIgnore] + public string? RequestState + { + get => RequestStateCore; + set => RequestStateCore = value; + } + + // See ExperimentalInternalPropertyTests.cs before modifying this property. + [JsonInclude] + [JsonPropertyName("requestState")] + internal string? RequestStateCore { get; set; } + /// /// Gets the opaque token that will be attached to any subsequent progress notifications. /// diff --git a/src/ModelContextProtocol.Core/Protocol/Result.cs b/src/ModelContextProtocol.Core/Protocol/Result.cs index 58b076ddb..6e43249a1 100644 --- a/src/ModelContextProtocol.Core/Protocol/Result.cs +++ b/src/ModelContextProtocol.Core/Protocol/Result.cs @@ -21,4 +21,18 @@ private protected Result() /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the type of the result, which allows the client to determine how to parse the result object. + /// + /// + /// + /// When absent or set to "complete", the result is a normal completed response. + /// When set to "input_required", the result is an indicating + /// that additional input is needed before the request can be completed. + /// + /// + /// Defaults to , which is equivalent to "complete". + [JsonPropertyName("resultType")] + public string? ResultType { get; set; } } diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index 957f58a51..bf87980e5 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -1,5 +1,6 @@ -using ModelContextProtocol.Protocol; -using System.Diagnostics; +using ModelContextProtocol.Protocol; +using System.Text.Json; +using System.Text.Json.Nodes; namespace ModelContextProtocol.Server; @@ -15,6 +16,14 @@ internal sealed class DestinationBoundMcpServer(McpServerImpl server, ITransport public override IServiceProvider? Services => server.Services; public override LoggingLevel? LoggingLevel => server.LoggingLevel; + /// + /// Gets or sets the MRTR context for the current request, if any. + /// Set by when an MRTR-aware handler invocation is in progress. + /// + internal MrtrContext? ActiveMrtrContext { get; set; } + + public override bool IsMrtrSupported => server.IsMrtrSupported; + public override ValueTask DisposeAsync() => server.DisposeAsync(); public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); @@ -39,6 +48,16 @@ public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { + // When an MRTR context is active, intercept server-to-client requests (sampling, elicitation, roots) + // and route them through the MRTR mechanism instead of sending them over the wire. + // Task-augmented requests (SampleAsTaskAsync/ElicitAsTaskAsync) have a "task" property on their params + // and expect a CreateTaskResult response, so they must bypass MRTR and go over the wire. + if (ActiveMrtrContext is { } mrtrContext && + !(request.Params is JsonObject paramsObj && paramsObj.ContainsKey("task"))) + { + return SendRequestViaMrtrAsync(mrtrContext, request, cancellationToken); + } + if (request.Context is not null) { throw new ArgumentException("Only transports can provide a JsonRpcMessageContext."); @@ -51,4 +70,23 @@ public override Task SendRequestAsync(JsonRpcRequest request, C return server.SendRequestAsync(request, cancellationToken); } + + private static async Task SendRequestViaMrtrAsync( + MrtrContext mrtrContext, JsonRpcRequest request, CancellationToken cancellationToken) + { + var inputRequest = new InputRequest + { + Method = request.Method, + Params = request.Params is { } paramsNode + ? JsonSerializer.Deserialize(paramsNode, McpJsonUtilities.JsonContext.Default.JsonElement) + : null, + }; + var inputResponse = await mrtrContext.RequestInputAsync(inputRequest, cancellationToken).ConfigureAwait(false); + + return new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(inputResponse.RawValue, McpJsonUtilities.JsonContext.Default.JsonElement), + }; + } } diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index b8b41bdc3..444365361 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -64,6 +64,25 @@ protected McpServer() /// Gets the last logging level set by the client, or if it's never been set. public abstract LoggingLevel? LoggingLevel { get; } + /// + /// Gets a value indicating whether the connected client supports Multi Round-Trip Requests (MRTR). + /// + /// + /// + /// When this property returns , tool handlers can throw + /// to return an + /// with and/or + /// to the client. + /// + /// + /// When this property returns , tool handlers should provide a fallback + /// experience (for example, returning a text message explaining that the client does not support + /// the required feature) instead of throwing . + /// + /// + [Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] + public virtual bool IsMrtrSupported => false; + /// /// Runs the server, listening for and handling client requests. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs index 203856814..c48ee2da0 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -2,8 +2,10 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Server; @@ -27,6 +29,13 @@ internal sealed partial class McpServerImpl : McpServer private readonly McpSessionHandler _sessionHandler; private readonly SemaphoreSlim _disposeLock = new(1, 1); private readonly McpTaskCancellationTokenProvider? _taskCancellationTokenProvider; + private readonly ConcurrentDictionary _mrtrContinuations = new(); + private readonly ConcurrentDictionary _mrtrContextsByRequestId = new(); + + // Track MRTR handler tasks using the same inFlightCount + TCS pattern as + // McpSessionHandler.ProcessMessagesCoreAsync. Starts at 1 for DisposeAsync itself. + private int _mrtrInFlightCount = 1; + private readonly TaskCompletionSource _allMrtrHandlersCompleted = new(TaskCreationOptions.RunContinuationsAsynchronously); private ClientCapabilities? _clientCapabilities; private Implementation? _clientInfo; @@ -91,6 +100,7 @@ public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFact ConfigureLogging(options); ConfigureCompletion(options); ConfigureExperimentalAndExtensions(options); + ConfigureMrtr(); // Register any notification handlers that were provided. if (options.Handlers.NotificationHandlers is { } notificationHandlers) @@ -210,9 +220,35 @@ public override async ValueTask DisposeAsync() _disposed = true; + // Dispose the session handler first - cancels message processing and waits for all + // in-flight request handlers (including retries in AwaitMrtrHandlerAsync) to complete. + // After this returns, no new requests can be processed and no new MRTR continuations + // can be created, so _mrtrContinuations is effectively frozen. _taskCancellationTokenProvider?.Dispose(); _disposables.ForEach(d => d()); await _sessionHandler.DisposeAsync().ConfigureAwait(false); + + // Cancel all orphaned MRTR handlers still suspended in continuations (waiting for + // retries that will never arrive now that the session handler is disposed). + int cancelledCount = _mrtrContinuations.Count; + foreach (var continuation in _mrtrContinuations.Values) + { + continuation.CancelHandler(); + } + + if (cancelledCount > 0) + { + MrtrContinuationsCancelled(cancelledCount); + } + + // Wait for all MRTR handler tasks to complete using the same inFlightCount + TCS + // pattern as McpSessionHandler.ProcessMessagesCoreAsync. The count started at 1 + // (for DisposeAsync itself); decrementing it here triggers the drain if handlers + // are still in flight. ObserveHandlerCompletionAsync decrements for each handler. + if (Interlocked.Decrement(ref _mrtrInFlightCount) != 0) + { + await _allMrtrHandlersCompleted.Task.ConfigureAwait(false); + } } private void ConfigureInitialize(McpServerOptions options) @@ -231,7 +267,8 @@ private void ConfigureInitialize(McpServerOptions options) // Otherwise, try to use whatever the client requested as long as it's supported. // If it's not supported, fall back to the latest supported version. string? protocolVersion = options.ProtocolVersion; - protocolVersion ??= request?.ProtocolVersion is string clientProtocolVersion && McpSessionHandler.SupportedProtocolVersions.Contains(clientProtocolVersion) ? + protocolVersion ??= request?.ProtocolVersion is string clientProtocolVersion && + McpSessionHandler.SupportedProtocolVersions.Contains(clientProtocolVersion) ? clientProtocolVersion : McpSessionHandler.LatestProtocolVersion; @@ -725,7 +762,7 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) McpErrorCode.InvalidParams); } - // Task augmentation requested - return CreateTaskResult + // Task augmentation requested with immediate creation return await ExecuteToolAsTaskAsync(tool, request, taskMetadata, taskStore, sendNotifications, cancellationToken).ConfigureAwait(false); } @@ -774,9 +811,18 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) } catch (Exception e) { - ToolCallError(request.Params?.Name ?? string.Empty, e); + // Skip logging for OperationCanceledException when the cancellation token + // is signaled - tool handler cancellation is an expected lifecycle event + // (client request cancellation, session shutdown, MRTR teardown), not a + // tool error. + // Skip logging for InputRequiredException - it's normal MRTR control flow, + // not an error (tools throw it to signal an InputRequiredResult). + if (!(e is OperationCanceledException && cancellationToken.IsCancellationRequested) && e is not InputRequiredException) + { + ToolCallError(request.Params?.Name ?? string.Empty, e); + } - if ((e is OperationCanceledException && cancellationToken.IsCancellationRequested) || e is McpProtocolException) + if ((e is OperationCanceledException && cancellationToken.IsCancellationRequested) || e is McpProtocolException || e is InputRequiredException) { throw; } @@ -990,7 +1036,7 @@ private ValueTask InvokeHandlerAsync( { return _servicesScopePerRequest ? InvokeScopedAsync(handler, args, jsonRpcRequest, cancellationToken) : - handler(new(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest, args), cancellationToken); + handler(new(CreateDestinationBoundServer(jsonRpcRequest), jsonRpcRequest, args), cancellationToken); async ValueTask InvokeScopedAsync( McpRequestHandler handler, @@ -1002,7 +1048,7 @@ async ValueTask InvokeScopedAsync( try { return await handler( - new RequestContext(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest, args) + new RequestContext(CreateDestinationBoundServer(jsonRpcRequest), jsonRpcRequest, args) { Services = scope?.ServiceProvider ?? Services, }, @@ -1018,6 +1064,18 @@ async ValueTask InvokeScopedAsync( } } + private DestinationBoundMcpServer CreateDestinationBoundServer(JsonRpcRequest jsonRpcRequest) + { + var server = new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport); + + if (_mrtrContextsByRequestId.TryRemove(jsonRpcRequest.Id, out var mrtrContext)) + { + server.ActiveMrtrContext = mrtrContext; + } + + return server; + } + private void SetHandler( string method, McpRequestHandler handler, @@ -1106,6 +1164,436 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) => _ => Protocol.LoggingLevel.Emergency, }; + /// + /// Checks whether the negotiated protocol version enables MRTR per SEP-2322 (DRAFT-2026-v1). + /// + internal bool ClientSupportsMrtr() => + _negotiatedProtocolVersion == McpSessionHandler.DraftProtocolVersion; + + /// + /// Returns when the session is stateful - the same server instance handles + /// subsequent requests on the same session. The legacy backcompat resolver in + /// needs a stateful session so it can send + /// elicitation/create / sampling/createMessage / roots/list to the client and + /// retry the handler with the responses. + /// + internal bool IsStatefulSession() => + _sessionTransport is not StreamableHttpServerTransport { Stateless: true }; + + /// + public override bool IsMrtrSupported => ClientSupportsMrtr() || IsStatefulSession(); + + /// + /// Invokes a handler and catches to convert it to an + /// JSON response. When MRTR is negotiated or the server is stateless, + /// the result is serialized directly. Otherwise, input requests are resolved via standard JSON-RPC + /// calls (elicitation, sampling, roots) and the handler is retried with the responses - allowing + /// MRTR-native tools to work transparently with clients that don't support MRTR. + /// + private async Task InvokeWithInputRequiredResultHandlingAsync( + Func> handler, + JsonRpcRequest request, + CancellationToken cancellationToken) + { + const int MaxRetries = 10; + + // In stateless mode, pick up the negotiated draft protocol version from the + // transport-provided request context because there is no long-lived initialize handshake state. + if (_negotiatedProtocolVersion is null && + request.Context?.ProtocolVersion is { } headerProtocolVersion) + { + _negotiatedProtocolVersion = headerProtocolVersion; + } + + for (int retry = 0; ; retry++) + { + try + { + return await handler(request, cancellationToken).ConfigureAwait(false); + } + catch (InputRequiredException ex) + { + // If the client natively supports MRTR, serialize and return directly - + // the client will drive the retry loop. + if (ClientSupportsMrtr()) + { + return SerializeInputRequiredResult(ex.Result); + } + + // In stateless mode without MRTR, the server can't resolve input requests via + // JSON-RPC (no persistent session for server-to-client requests), and the client + // won't recognize the InputRequiredResult. This is the one unsupported configuration. + // TODO(stateless-draft): When DRAFT-2026-v1 becomes stateless-only, the IsStatefulSession() gate collapses - the stateful path will only matter for legacy clients on the current protocol. + if (!IsStatefulSession()) + { + throw new McpException( + "A tool handler returned an incomplete result, but the server is stateless and the client does not support MRTR. " + + "MRTR-native tools require either an MRTR-capable client or a stateful server for backward-compatible resolution.", ex); + } + + // Backcompat: resolve input requests via standard JSON-RPC calls and retry the handler. + if (ex.Result.InputRequests is not { Count: > 0 } inputRequests) + { + throw new McpException( + "A tool handler returned an incomplete result without input requests, and the client does not support MRTR.", ex); + } + + if (retry >= MaxRetries) + { + throw new McpException( + $"MRTR-native tool exceeded {MaxRetries} retry rounds without completing.", ex); + } + + // Resolve each input request by sending the corresponding JSON-RPC call to the client. + // Route the outgoing requests via the same DestinationBoundMcpServer used for normal tool + // handlers, so they go through the POST's response stream (RelatedTransport) rather than + // the session-level transport. Without this, the messages can race with the client's GET + // stream startup and be silently dropped by StreamableHttpServerTransport.SendMessageAsync + // when no GET request has arrived yet. + var destinationServer = CreateDestinationBoundServer(request); + var inputResponses = await ResolveInputRequestsAsync(destinationServer, inputRequests, cancellationToken).ConfigureAwait(false); + + // Reconstruct request params with inputResponses and requestState for the retry. + var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject(); + paramsObj["inputResponses"] = JsonSerializer.SerializeToNode( + (IDictionary)inputResponses, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse); + + if (ex.Result.RequestState is { } requestState) + { + paramsObj["requestState"] = requestState; + } + else + { + // Strip any stale requestState carried over from the previous round's clone so + // the next tool invocation doesn't see a continuation token the current round is not using. + paramsObj.Remove("requestState"); + } + + request = new JsonRpcRequest + { + Id = request.Id, + Method = request.Method, + Params = paramsObj, + Context = request.Context, + }; + } + } + } + + /// + /// Resolves a batch of MRTR input requests concurrently by dispatching each as a standard + /// JSON-RPC request to the client. The requests are routed via + /// so they go out through the POST's response stream (matching the behavior of tool-initiated + /// server-to-client requests like server.SampleAsync) and avoid racing with the client's + /// GET stream startup. On the first failure all remaining handlers are cancelled so user-facing + /// flows (sampling/elicitation prompts) don't keep running once the caller has given up, and + /// exceptions from late-completing tasks are observed before the original exception is rethrown. + /// + private static async Task> ResolveInputRequestsAsync( + McpServer destinationServer, + IDictionary inputRequests, + CancellationToken cancellationToken) + { + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + var keyed = new (string Key, Task Task)[inputRequests.Count]; + int i = 0; + foreach (var kvp in inputRequests) + { + keyed[i++] = (kvp.Key, ResolveInputRequestAsync(destinationServer, kvp.Value, linkedCts.Token)); + } + + try + { + await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false); + } + catch + { + linkedCts.Cancel(); + try + { + await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false); + } + catch + { + // Observed; the original exception is the one we want to surface. + } + throw; + } + + var responses = new Dictionary(keyed.Length); + foreach (var (key, task) in keyed) + { + responses[key] = task.Result; + } + return responses; + } + + /// + /// Resolves a single MRTR by dispatching it as a standard JSON-RPC + /// request to the client via . This is the server-side mirror + /// of the client's input resolution logic, used for backward compatibility when the client doesn't + /// support MRTR. + /// + private static async Task ResolveInputRequestAsync(McpServer destinationServer, InputRequest inputRequest, CancellationToken cancellationToken) + { + switch (inputRequest.Method) + { + case RequestMethods.ElicitationCreate: + var elicitParams = inputRequest.ElicitationParams + ?? throw new McpException("Failed to deserialize elicitation parameters from MRTR input request."); + var elicitResult = await destinationServer.ElicitAsync(elicitParams, cancellationToken).ConfigureAwait(false); + return InputResponse.FromElicitResult(elicitResult); + + case RequestMethods.SamplingCreateMessage: + var samplingParams = inputRequest.SamplingParams + ?? throw new McpException("Failed to deserialize sampling parameters from MRTR input request."); + var samplingResult = await destinationServer.SampleAsync(samplingParams, cancellationToken).ConfigureAwait(false); + return InputResponse.FromSamplingResult(samplingResult); + + case RequestMethods.RootsList: + var rootsParams = inputRequest.RootsParams ?? new ListRootsRequestParams(); + var rootsResult = await destinationServer.RequestRootsAsync(rootsParams, cancellationToken).ConfigureAwait(false); + return InputResponse.FromRootsResult(rootsResult); + + default: + throw new McpException($"Unsupported input request method: '{inputRequest.Method}'."); + } + } + + private static JsonNode? SerializeInputRequiredResult(InputRequiredResult inputRequiredResult) => + JsonSerializer.SerializeToNode(inputRequiredResult, McpJsonUtilities.JsonContext.Default.InputRequiredResult); + + /// + /// Wraps MRTR-eligible request handlers so that when a handler calls ElicitAsync/SampleAsync/RequestRootsAsync, + /// an is returned early and the handler is suspended until the retry arrives. + /// + private void ConfigureMrtr() + { + // Wrap all methods that may trigger MRTR (server calling ElicitAsync/SampleAsync/RequestRootsAsync + // during handler execution). These methods may produce InputRequiredResult if the handler needs input. + WrapHandlerWithMrtr(RequestMethods.ToolsCall); + WrapHandlerWithMrtr(RequestMethods.PromptsGet); + WrapHandlerWithMrtr(RequestMethods.ResourcesRead); + } + + /// + /// Replaces an existing request handler entry with an MRTR-aware wrapper that supports + /// handler suspension and responses. + /// + private void WrapHandlerWithMrtr(string method) + { + if (!_requestHandlers.TryGetValue(method, out var originalHandler)) + { + return; + } + + _requestHandlers[method] = async (request, cancellationToken) => + { + // In stateless mode, each request creates a new server instance that never saw the + // initialize handshake, so _negotiatedProtocolVersion is null. Pick it up from the + // Mcp-Protocol-Version header that the transport layer flowed via JsonRpcMessageContext. + if (_negotiatedProtocolVersion is null && + request.Context?.ProtocolVersion is { } headerProtocolVersion) + { + _negotiatedProtocolVersion = headerProtocolVersion; + } + + // Check for MRTR retry: if requestState is present, look up the continuation. + if (request.Params is JsonObject paramsObj && + paramsObj.TryGetPropertyValue("requestState", out var requestStateNode) && + requestStateNode?.GetValueKind() == JsonValueKind.String && + requestStateNode.GetValue() is { } requestState) + { + if (_mrtrContinuations.TryRemove(requestState, out var existingContinuation)) + { + // Implicit MRTR retry: resume the suspended handler with client responses. + IDictionary? inputResponses = null; + if (paramsObj.TryGetPropertyValue("inputResponses", out var responsesNode) && responsesNode is not null) + { + inputResponses = JsonSerializer.Deserialize(responsesNode, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse); + } + + var exchange = existingContinuation.PendingExchange!; + var nextExchangeTask = existingContinuation.MrtrContext.ResetForNextExchange(exchange); + + if (inputResponses is not null && + inputResponses.TryGetValue(exchange.Key, out var response)) + { + if (!exchange.ResponseTcs.TrySetResult(response)) + { + throw new McpProtocolException( + $"MRTR exchange '{exchange.Key}' was already completed (possibly cancelled).", + McpErrorCode.InternalError); + } + } + else + { + if (!exchange.ResponseTcs.TrySetException( + new McpProtocolException($"Missing input response for key '{exchange.Key}'.", McpErrorCode.InvalidParams))) + { + throw new McpProtocolException( + $"MRTR exchange '{exchange.Key}' was already completed (possibly cancelled).", + McpErrorCode.InternalError); + } + } + + return await AwaitMrtrHandlerAsync( + existingContinuation.HandlerTask, existingContinuation, nextExchangeTask, cancellationToken).ConfigureAwait(false); + } + + // Explicit MRTR retry or invalid requestState: no continuation found. + // Fall through to the standard MRTR-aware invocation path below. The retry data + // (inputResponses, requestState) is already in the deserialized request params + // for low-level handlers to access, and the MrtrContext will be set up for + // high-level handlers that call ElicitAsync/SampleAsync. + } + + // Implicit MRTR (handler suspension across ElicitAsync/SampleAsync) emits + // InputRequiredResult on the wire, which only DRAFT-2026-v1 clients understand, + // and requires the same server instance to handle the retry (stateful session). + // For all other cases - legacy clients, stateless sessions - fall through to the + // exception-based path, which transparently resolves InputRequiredException via + // legacy JSON-RPC requests when the client doesn't speak MRTR. + if (!ClientSupportsMrtr() || !IsStatefulSession()) + { + return await InvokeWithInputRequiredResultHandlingAsync(originalHandler, request, cancellationToken).ConfigureAwait(false); + } + + // Start a new MRTR-aware handler invocation. + var mrtrContext = new MrtrContext(); + + // Create a long-lived CTS for the handler that survives across retries. + // The original request's combinedCts will be disposed when this lambda returns, + // breaking the cancellation chain. This CTS keeps the handler cancellable. + // Like Kestrel's HttpContext.RequestAborted, the CTS is never disposed - Cancel() + // is thread-safe with itself, and not disposing avoids deadlock risks from + // calling Cancel/Dispose inside locks or Interlocked guards. + var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + // Store the MrtrContext so CreateDestinationBoundServer can pick it up and set it + // on the per-request DestinationBoundMcpServer. This is picked up synchronously + // before any await, so the finally cleanup is safe. + _mrtrContextsByRequestId[request.Id] = mrtrContext; + Task handlerTask; + try + { + handlerTask = originalHandler(request, handlerCts.Token); + } + finally + { + _mrtrContextsByRequestId.TryRemove(request.Id, out _); + } + + // Wrap handler state into a continuation for lifecycle management across retries. + var continuation = new MrtrContinuation(handlerCts, handlerTask, mrtrContext); + + // Track the handler task for lifecycle management. The observer logs unhandled + // exceptions and decrements _mrtrInFlightCount when the handler completes, + // mirroring how McpSessionHandler tracks in-flight handlers. + Interlocked.Increment(ref _mrtrInFlightCount); + _ = ObserveHandlerCompletionAsync(handlerTask); + + return await AwaitMrtrHandlerAsync( + handlerTask, continuation, mrtrContext.InitialExchangeTask, cancellationToken).ConfigureAwait(false); + }; + } + + /// + /// Awaits the outcome of an MRTR-enabled handler invocation. + /// If the handler completes, returns its result. If an exchange arrives (handler needs input), + /// builds and returns an and stores the continuation for future retries. + /// If the handler throws , the result is returned directly + /// without storing a continuation (explicit MRTR path). + /// + private async Task AwaitMrtrHandlerAsync( + Task handlerTask, + MrtrContinuation continuation, + Task exchangeTask, + CancellationToken cancellationToken) + { + // Link the current request's cancellation to the handler's long-lived CTS. + // On the initial call this is redundant (handlerCts is already linked to cancellationToken) + // but on retries this is critical: the retry's combinedCts cancellation must flow to the handler. + // This is how notifications/cancelled for the retry's request ID reaches the handler. + using var registration = cancellationToken.Register( + static state => ((MrtrContinuation)state!).CancelHandler(), continuation); + + // Race handler against MRTR exchange. + var completedTask = await Task.WhenAny(handlerTask, exchangeTask).ConfigureAwait(false); + + if (completedTask == handlerTask) + { + // Handler completed - return its result, propagate its exception, or handle InputRequiredException. + return await AwaitHandlerWithInputRequiredResultHandlingAsync(handlerTask).ConfigureAwait(false); + } + + // Exchange arrived - handler needs input from the client (implicit MRTR path). + var exchange = await exchangeTask.ConfigureAwait(false); + + var correlationId = Guid.NewGuid().ToString("N"); + var inputRequiredResult = new InputRequiredResult + { + InputRequests = new Dictionary { [exchange.Key] = exchange.InputRequest }, + RequestState = correlationId, + }; + + // Store the continuation so the retry can resume the handler. + continuation.PendingExchange = exchange; + _mrtrContinuations[correlationId] = continuation; + + return SerializeInputRequiredResult(inputRequiredResult); + } + + /// + /// Fire-and-forget observer for an MRTR handler task. Logs unhandled exceptions at Debug + /// level (the same exception still propagates to the request pipeline, so Debug avoids + /// double-reporting at Error) and decrements when the + /// handler completes, following the same in-flight tracking pattern as . + /// + private async Task ObserveHandlerCompletionAsync(Task handlerTask) + { + try + { + await handlerTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Handler cancelled - expected lifecycle event (disposal, client cancel, session shutdown). + } + catch (InputRequiredException) + { + // Explicit MRTR: handler explicitly signaling an InputRequiredResult. Not an error. + } + catch (Exception ex) + { + MrtrHandlerError(ex); + } + finally + { + if (Interlocked.Decrement(ref _mrtrInFlightCount) == 0) + { + _allMrtrHandlersCompleted.TrySetResult(true); + } + } + } + + /// + /// Awaits a handler task, catching to convert it to an + /// JSON response without storing a continuation. + /// + private static async Task AwaitHandlerWithInputRequiredResultHandlingAsync(Task handlerTask) + { + try + { + return await handlerTask.ConfigureAwait(false); + } + catch (InputRequiredException ex) + { + return SerializeInputRequiredResult(ex.Result); + } + } + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] private partial void ToolCallError(string toolName, Exception exception); @@ -1124,6 +1612,12 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) => [LoggerMessage(Level = LogLevel.Information, Message = "ReadResource \"{ResourceUri}\" completed.")] private partial void ReadResourceCompleted(string resourceUri); + [LoggerMessage(Level = LogLevel.Debug, Message = "Cancelled {Count} pending MRTR continuation(s) during session disposal.")] + private partial void MrtrContinuationsCancelled(int count); + + [LoggerMessage(Level = LogLevel.Debug, Message = "An MRTR handler threw an unhandled exception.")] + private partial void MrtrHandlerError(Exception exception); + /// /// Executes a tool call as a task and returns a CallToolTaskResult immediately. /// diff --git a/src/ModelContextProtocol.Core/Server/MrtrContext.cs b/src/ModelContextProtocol.Core/Server/MrtrContext.cs new file mode 100644 index 000000000..e849cf4eb --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/MrtrContext.cs @@ -0,0 +1,78 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Manages the MRTR (Multi Round-Trip Request) coordination between a handler and the pipeline. +/// When a handler calls or +/// , +/// the handler sets the exchange TCS and suspends on a response TCS. The pipeline detects the exchange +/// via or the task returned by , +/// sends an , and later completes the response TCS when the retry arrives. +/// +internal sealed class MrtrContext +{ + private TaskCompletionSource _exchangeTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + private int _nextInputRequestId; + + /// + /// Gets the task for the initial MRTR exchange. Set once in the constructor and never changes. + /// For subsequent exchanges after a retry, use the task returned by . + /// + public Task InitialExchangeTask { get; } + + public MrtrContext() + { + InitialExchangeTask = _exchangeTcs.Task; + } + + /// + /// Prepares the context for the next round of exchange after a retry arrives. + /// Uses to atomically validate that + /// still references the TCS that produced , + /// ensuring concurrent calls reliably fail. + /// + /// The exchange from the previous round whose + /// response has been (or is about to be) completed. + /// A task that completes when the handler requests input via + /// . + /// The context state was modified concurrently. + public Task ResetForNextExchange(MrtrExchange previousExchange) + { + var newTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + if (Interlocked.CompareExchange(ref _exchangeTcs, newTcs, previousExchange.SourceTcs) != previousExchange.SourceTcs) + { + throw new InvalidOperationException("MrtrContext was modified concurrently."); + } + + return newTcs.Task; + } + + /// + /// Called by + /// or + /// to request input from the client via the MRTR mechanism. + /// + /// The input request describing what the server needs. + /// A token to cancel the wait for input. + /// The client's response to the input request. + /// A concurrent server-to-client request is already pending. + public async Task RequestInputAsync(InputRequest inputRequest, CancellationToken cancellationToken) + { + var key = $"input_{Interlocked.Increment(ref _nextInputRequestId)}"; + var tcs = _exchangeTcs; + var exchange = new MrtrExchange(key, inputRequest, tcs); + + // TrySetResult is the sole atomicity gate. If it returns false, + // the TCS was already completed by a prior call - concurrent exchanges + // are not supported. + if (!tcs.TrySetResult(exchange)) + { + throw new InvalidOperationException( + "Concurrent server-to-client requests are not supported. " + + "Await each ElicitAsync, SampleAsync, or RequestRootsAsync call before making another."); + } + + return await exchange.ResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } +} diff --git a/src/ModelContextProtocol.Core/Server/MrtrContinuation.cs b/src/ModelContextProtocol.Core/Server/MrtrContinuation.cs new file mode 100644 index 000000000..f2cc65e3f --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/MrtrContinuation.cs @@ -0,0 +1,50 @@ +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Server; + +/// +/// Represents the lifecycle state for an MRTR handler invocation across retries. +/// Created when the handler starts and stored in _mrtrContinuations when +/// the handler suspends waiting for client input. +/// +internal sealed class MrtrContinuation +{ + private readonly CancellationTokenSource _handlerCts; + + public MrtrContinuation(CancellationTokenSource handlerCts, Task handlerTask, MrtrContext mrtrContext) + { + _handlerCts = handlerCts; + HandlerTask = handlerTask; + MrtrContext = mrtrContext; + } + + /// + /// Gets a token that cancels when the handler should be aborted. + /// Passed to the handler at creation and remains valid across retries. + /// + public CancellationToken HandlerToken => _handlerCts.Token; + + /// + /// The handler task that is suspended awaiting input. + /// + public Task HandlerTask { get; } + + /// + /// The MRTR context for the handler's async flow. + /// + public MrtrContext MrtrContext { get; } + + /// + /// The exchange that is awaiting a response from the client. + /// Set each time the handler suspends on a new exchange. + /// + public MrtrExchange? PendingExchange { get; set; } + + /// + /// Cancels the handler. Safe to call multiple times and concurrently - + /// is thread-safe with itself. + /// The CTS is intentionally never disposed to avoid deadlock risks from + /// calling Cancel/Dispose inside synchronization primitives. + /// + public void CancelHandler() => _handlerCts.Cancel(); +} diff --git a/src/ModelContextProtocol.Core/Server/MrtrExchange.cs b/src/ModelContextProtocol.Core/Server/MrtrExchange.cs new file mode 100644 index 000000000..cf0a86af4 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/MrtrExchange.cs @@ -0,0 +1,41 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Represents a single exchange between the handler and the pipeline during an MRTR flow. +/// The handler creates the exchange and awaits the response TCS. The pipeline reads the exchange, +/// sends the to the client, and completes the TCS when the response arrives. +/// +internal sealed class MrtrExchange +{ + public MrtrExchange(string key, InputRequest inputRequest, TaskCompletionSource sourceTcs) + { + Key = key; + InputRequest = inputRequest; + SourceTcs = sourceTcs; + ResponseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + + /// + /// The unique key identifying this exchange within the MRTR round trip. + /// + public string Key { get; } + + /// + /// The input request that needs to be fulfilled by the client. + /// + public InputRequest InputRequest { get; } + + /// + /// The that this exchange was set as the result of. + /// Used by on retry to validate + /// the expected state via . + /// + internal TaskCompletionSource SourceTcs { get; } + + /// + /// The TCS that will be completed with the client's response. + /// + public TaskCompletionSource ResponseTcs { get; } +} diff --git a/tests/Common/Utils/NodeHelpers.cs b/tests/Common/Utils/NodeHelpers.cs index a30dd3fc3..ef1686abb 100644 --- a/tests/Common/Utils/NodeHelpers.cs +++ b/tests/Common/Utils/NodeHelpers.cs @@ -205,6 +205,44 @@ public static bool HasSep2243Scenarios() } } + /// + /// Checks whether the SEP-2322 (Multi Round-Trip Requests / IncompleteResult) + /// conformance scenarios are available by reading the conformance package version + /// from the repo's package.json. MRTR scenarios require a conformance package version + /// that includes SEP-2322 support (see + /// https://github.com/modelcontextprotocol/conformance/pull/188). + /// + public static bool HasMrtrScenarios() + { + try + { + var repoRoot = FindRepoRoot(); + var packageJsonPath = Path.Combine(repoRoot, "package.json"); + if (!File.Exists(packageJsonPath)) + { + return false; + } + + var json = System.Text.Json.JsonDocument.Parse(File.ReadAllText(packageJsonPath)); + if (json.RootElement.TryGetProperty("dependencies", out var deps) && + deps.TryGetProperty("@modelcontextprotocol/conformance", out var versionElement)) + { + var versionStr = versionElement.GetString(); + if (versionStr is not null && Version.TryParse(versionStr, out var version)) + { + // SEP-2322 scenarios are expected in conformance package >= 0.2.0 + return version >= new Version(0, 2, 0); + } + } + + return false; + } + catch + { + return false; + } + } + private static ProcessStartInfo NpmStartInfo(string arguments, string workingDirectory) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) diff --git a/tests/Common/Utils/ServerMessageTracker.cs b/tests/Common/Utils/ServerMessageTracker.cs new file mode 100644 index 000000000..66a80c681 --- /dev/null +++ b/tests/Common/Utils/ServerMessageTracker.cs @@ -0,0 +1,95 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Text.Json.Nodes; +using Xunit; + +namespace ModelContextProtocol.Tests.Utils; + +/// +/// Tracks MRTR protocol mode via incoming and outgoing message filters. +/// Used by MRTR tests to verify the correct protocol mode (MRTR vs legacy) was used. +/// +internal sealed class ServerMessageTracker +{ + private static readonly HashSet LegacyMrtrMethods = + [ + RequestMethods.ElicitationCreate, + RequestMethods.SamplingCreateMessage, + RequestMethods.RootsList, + ]; + + private readonly ConcurrentBag _legacyRequestMethods = []; + private int _mrtrRetryCount; + private int _incompleteResultCount; + + /// + /// Adds incoming and outgoing message filters to track MRTR protocol usage. + /// Call this in services.Configure<McpServerOptions> or AddMcpServer callbacks. + /// + public void AddFilters(McpMessageFilters messageFilters) + { + // Track outgoing legacy JSON-RPC requests and InputRequiredResult responses. + messageFilters.OutgoingFilters.Add(next => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcRequest request && LegacyMrtrMethods.Contains(request.Method)) + { + _legacyRequestMethods.Add(request.Method); + } + else if (context.JsonRpcMessage is JsonRpcResponse response && + response.Result is JsonObject resultObj && + resultObj.TryGetPropertyValue("resultType", out var resultTypeNode) && + resultTypeNode?.GetValue() == "input_required") + { + Interlocked.Increment(ref _incompleteResultCount); + } + + await next(context, cancellationToken); + }); + + // Track incoming MRTR retries (requests with inputResponses or requestState in params). + messageFilters.IncomingFilters.Add(next => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcRequest request && + request.Params is JsonObject paramsObj && + (paramsObj.ContainsKey("inputResponses") || paramsObj.ContainsKey("requestState"))) + { + Interlocked.Increment(ref _mrtrRetryCount); + } + + await next(context, cancellationToken); + }); + } + + /// + /// Asserts that MRTR was used: at least one InputRequiredResult response was sent + /// and no legacy JSON-RPC requests (elicitation/create, sampling/createMessage, roots/list) were sent. + /// + public void AssertMrtrUsed() + { + Assert.True(_incompleteResultCount > 0, + "Expected at least one InputRequiredResult response (MRTR mode), but none were detected."); + Assert.Empty(_legacyRequestMethods); + } + + /// + /// Asserts that MRTR was used at least once (at least one InputRequiredResult response was sent), + /// independent of whether the session also issued any legacy server-to-client requests. + /// + public void AssertMrtrUsedAtLeastOnce() + { + Assert.True(_incompleteResultCount > 0, + "Expected at least one InputRequiredResult response (MRTR mode), but none were detected."); + } + + /// + /// Asserts that legacy mode was used: at least one legacy JSON-RPC request was sent + /// and no MRTR retries or InputRequiredResult responses were detected. + /// + public void AssertMrtrNotUsed() + { + Assert.NotEmpty(_legacyRequestMethods); + Assert.Equal(0, _mrtrRetryCount); + Assert.Equal(0, _incompleteResultCount); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs index 5552b5395..ac19953bf 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs @@ -1,7 +1,45 @@ -namespace ModelContextProtocol.AspNetCore.Tests; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore.Tests; public class MapMcpStatelessTests(ITestOutputHelper outputHelper) : MapMcpStreamableHttpTests(outputHelper) { protected override bool UseStreamableHttp => true; protected override bool Stateless => true; + + [Fact] + public async Task EnablePollingAsync_ThrowsInvalidOperationException_InStatelessMode() + { + InvalidOperationException? capturedException = null; + var pollingTool = McpServerTool.Create(async (RequestContext context) => + { + try + { + await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); + } + catch (InvalidOperationException ex) + { + capturedException = ex; + } + + return "Complete"; + }, options: new() { Name = "polling_tool" }); + + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools([pollingTool]); + + await using var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync(); + + await mcpClient.CallToolAsync("polling_tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(capturedException); + Assert.Contains("stateless", capturedException.Message, StringComparison.OrdinalIgnoreCase); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index 40b9e8217..3d532802b 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -348,9 +348,9 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectAsync(clientOptions: new() + await using var mcpClient = await ConnectAsync(configureClient: options => { - ProtocolVersion = "2025-06-18", + options.ProtocolVersion = "2025-06-18"; }); Assert.Equal("2025-06-18", mcpClient.NegotiatedProtocolVersion); @@ -458,41 +458,6 @@ public async Task CanResumeSessionWithMapMcpAndRunSessionHandler() Assert.Equal(1, runSessionCount); } - [Fact] - public async Task EnablePollingAsync_ThrowsInvalidOperationException_InStatelessMode() - { - Assert.SkipUnless(Stateless, "This test only applies to stateless mode."); - - InvalidOperationException? capturedException = null; - var pollingTool = McpServerTool.Create(async (RequestContext context) => - { - try - { - await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); - } - catch (InvalidOperationException ex) - { - capturedException = ex; - } - - return "Complete"; - }, options: new() { Name = "polling_tool" }); - - Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools([pollingTool]); - - await using var app = Builder.Build(); - app.MapMcp(); - - await app.StartAsync(TestContext.Current.CancellationToken); - - await using var mcpClient = await ConnectAsync(); - - await mcpClient.CallToolAsync("polling_tool", cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(capturedException); - Assert.Contains("stateless", capturedException.Message, StringComparison.OrdinalIgnoreCase); - } - [Fact] public async Task EnablePollingAsync_ThrowsInvalidOperationException_WhenNoEventStreamStoreConfigured() { @@ -793,13 +758,13 @@ public async Task EndpointFilter_CanReadSessionId_BeforeAndAfterHandler() await using var app = Builder.Build(); - // This is the pattern documented in sessions.md — verify it actually works. + // This is the pattern documented in sessions.md - verify it actually works. // Tag before next() so child spans inherit the value. app.MapMcp().AddEndpointFilter(async (context, next) => { var httpContext = context.HttpContext; - // Read from request headers — available on all non-initialize requests in stateful mode. + // Read from request headers - available on all non-initialize requests in stateful mode. string? beforeSessionId = httpContext.Request.Headers["Mcp-Session-Id"]; // Tag before next() so child activities created during the handler inherit it. @@ -828,7 +793,7 @@ public async Task EndpointFilter_CanReadSessionId_BeforeAndAfterHandler() await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); // The filter must have observed at least one MCP request. Don't assert an exact - // minimum — the initialized notification or GET stream may not have completed yet. + // minimum - the initialized notification or GET stream may not have completed yet. Assert.NotEmpty(capturedSessionIds); if (Stateless) @@ -855,7 +820,7 @@ public async Task EndpointFilter_CanReadSessionId_BeforeAndAfterHandler() }); // At least one POST should have the session ID in the request header too - // (the initialized notification or list_tools — but not the initial initialize request). + // (the initialized notification or list_tools - but not the initial initialize request). Assert.Contains(postCaptures, c => c.BeforeNext == client.SessionId); // Verify Activity.Current was available and the AddTag pattern works before next(). diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.Mrtr.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.Mrtr.cs new file mode 100644 index 000000000..ddae6c66b --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.Mrtr.cs @@ -0,0 +1,779 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public abstract partial class MapMcpTests +{ + private ServerMessageTracker ConfigureServer(params Delegate[] tools) + { + var messageTracker = new ServerMessageTracker(); + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation { Name = "MrtrTestServer", Version = "1" }; + // Do not pin a protocol version - let it be negotiated based on what the client requests. + // DRAFT-2026-v1 is in SupportedProtocolVersions, so an opt-in client gets it; others get + // the latest non-draft. + messageTracker.AddFilters(options.Filters.Message); + }) + .WithHttpTransport(ConfigureStateless) + .WithTools(tools.Select(t => McpServerTool.Create(t))); + return messageTracker; + } + + private Task ConnectExperimentalAsync() => + ConnectAsync(configureClient: options => + { + ConfigureMrtrHandlers(options); + options.ProtocolVersion = "DRAFT-2026-v1"; + }); + + private Task ConnectDefaultAsync() => + ConnectAsync(configureClient: ConfigureMrtrHandlers); + + /// Configures elicitation, sampling, and roots handlers on client options. + private static void ConfigureMrtrHandlers(McpClientOptions options) + { + options.Handlers.ElicitationHandler = (request, ct) => + { + var message = request?.Message ?? ""; + var answer = message.Contains("name", StringComparison.OrdinalIgnoreCase) ? "Alice" + : message.Contains("greet", StringComparison.OrdinalIgnoreCase) ? "Hello" + : "yes"; + + return new ValueTask(new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + ["answer"] = JsonDocument.Parse($"\"{answer}\"").RootElement.Clone() + } + }); + }; + options.Handlers.SamplingHandler = (request, progress, ct) => + { + var prompt = request?.Messages?.LastOrDefault()?.Content + .OfType().FirstOrDefault()?.Text ?? ""; + return new ValueTask(new CreateMessageResult + { + Content = [new TextContentBlock { Text = $"LLM:{prompt}" }], + Model = "test-model" + }); + }; + options.Handlers.RootsHandler = (request, ct) => + { + return new ValueTask(new ListRootsResult + { + Roots = [ + new Root { Uri = "file:///project", Name = "Project" }, + new Root { Uri = "file:///data", Name = "Data" } + ] + }); + }; + } + + // ===================================================================== + // MRTR tests: experimental (native), backcompat (legacy JSON-RPC), and edge cases. + // Each test creates its own server with DRAFT-2026-v1 enabled. + // ===================================================================== + + [McpServerTool(Name = "mrtr-mixed")] + private static async Task MrtrMixed(McpServer server, RequestContext context, CancellationToken ct) + { + var state = context.Params!.RequestState; + var responses = context.Params!.InputResponses; + + // Round 3 entry: confirmation from round 2 available. Transition to await API. + if (state == "round-2" && responses?.TryGetValue("confirm", out var confirmResponse) == true) + { + var confirmation = confirmResponse.Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action ?? "unknown"; + + // Await API: sequential sampling then elicitation + var sampleResult = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Write greeting" }] }], + MaxTokens = 100 + }, ct); + var greeting = sampleResult.Content.OfType().FirstOrDefault()?.Text ?? ""; + + var signoffResult = await server.ElicitAsync(new ElicitRequestParams + { + Message = "Sign off as?", + RequestedSchema = new() + }, ct); + var signoff = signoffResult.Action; + + return $"{confirmation}|{greeting}|{signoff}"; + } + + // Round 2 entry: parallel results from round 1 available. + if (state == "round-1" && responses is not null) + { + var name = responses["name"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + var weather = responses["weather"].Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)?.Content + .OfType().FirstOrDefault()?.Text ?? ""; + var root = responses["roots"].Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots?.FirstOrDefault()?.Name ?? ""; + + // Exception API: single elicitation with requestState + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = $"Confirm {name} in {weather} near {root}?", + RequestedSchema = new() + }) + }, + requestState: "round-2"); + } + + // Round 1: Exception API with 3 PARALLEL input requests + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new() + }), + ["weather"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Describe the weather" }] }], + MaxTokens = 100 + }), + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "round-1"); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Mrtr_MixedExceptionAndAwaitStyle(bool experimentalClient) + { + // The server always supports DRAFT-2026-v1 (it's in SupportedProtocolVersions). The + // client opts in by pinning ProtocolVersion = "DRAFT-2026-v1"; otherwise it negotiates + // the latest non-draft version and the server falls back to the exception path with + // legacy JSON-RPC resolution. + var messageTracker = ConfigureServer(MrtrMixed); + + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + Action configureClient = experimentalClient + ? options => { ConfigureMrtrHandlers(options); options.ProtocolVersion = "DRAFT-2026-v1"; } + : ConfigureMrtrHandlers; + + // The await-style portion of this tool calls server.SampleAsync/ElicitAsync on round 3. + // In stateless mode, those calls succeed only when the request is still open on the same + // SSE stream - which it is - so the tool runs end-to-end as long as the input requests + // themselves can be resolved (MRTR client) or replayed via legacy JSON-RPC (stateful + legacy). + if (Stateless && !experimentalClient) + { + // Stateless + legacy client: InputRequiredException cannot be resolved (no MRTR wire + // and no persistent server instance for the backcompat retry loop). The server returns + // a JSON-RPC error. + await using var client = await ConnectAsync(configureClient: configureClient); + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-mixed", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.Equal(McpErrorCode.InternalError, ex.ErrorCode); + Assert.Contains("stateless", ex.Message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("MRTR", ex.Message); + return; + } + + if (Stateless && experimentalClient) + { + // Stateless + MRTR client: the await-style portion (server.SampleAsync on round 3) + // requires handler suspension across requests, which only works in stateful mode. + // Skip this combination - the await API is documented as stateful-only. + Assert.SkipWhen(true, "Await-style API requires handler suspension (stateful only)."); + return; + } + + // Stateful path - both client modes complete all 3 rounds. + await using var statefulClient = await ConnectAsync(configureClient: configureClient); + + Assert.Equal(experimentalClient ? "DRAFT-2026-v1" : "2025-11-25", + statefulClient.NegotiatedProtocolVersion); + + var result = await statefulClient.CallToolAsync("mrtr-mixed", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.True(result.IsError is not true); + var parts = text.Split('|'); + Assert.Equal(3, parts.Length); + Assert.Equal("accept", parts[0]); + Assert.StartsWith("LLM:", parts[1]); + Assert.Equal("accept", parts[2]); + + if (experimentalClient) + { + // Rounds 1-2 use wire-format MRTR (InputRequiredResult), but round 3's await calls + // still issue legacy elicitation/create + sampling/createMessage requests, so this + // configuration is mixed-mode. + messageTracker.AssertMrtrUsedAtLeastOnce(); + } + else + { + messageTracker.AssertMrtrNotUsed(); + } + } + + [McpServerTool(Name = "mrtr-parallel-await")] + private static async Task MrtrParallelAwait(McpServer server, CancellationToken ct) + { + var elicitTask = server.ElicitAsync(new ElicitRequestParams + { + Message = "Parallel elicit", + RequestedSchema = new() + }, ct); + + // Start the second await - with MRTR, this throws InvalidOperationException + // because MrtrContext only supports one pending exchange at a time. + try + { + var sampleTask = server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Parallel sample" }] }], + MaxTokens = 100 + }, ct); + + // If we get here, both calls succeeded (non-MRTR path) + var sampleResult = await sampleTask; + var elicitResult = await elicitTask; + return $"parallel-ok:{elicitResult.Action}:{sampleResult.Content.OfType().First().Text}"; + } + catch (InvalidOperationException ex) + { + return ex.Message; + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Mrtr_ParallelAwaits(bool experimentalClient) + { + // Parallel awaits work with regular JSON-RPC but fail with MRTR because + // MrtrContext only supports one exchange at a time (TrySetResult gate). + Assert.SkipWhen(Stateless, "Await-style API requires handler suspension (stateful only)."); + + ConfigureServer(MrtrParallelAwait); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + Action configureClient = experimentalClient + ? options => { ConfigureMrtrHandlers(options); options.ProtocolVersion = "DRAFT-2026-v1"; } + : ConfigureMrtrHandlers; + + await using var client = await ConnectAsync(configureClient: configureClient); + + if (experimentalClient) + { + // MRTR active. Parallel awaits hit the MrtrContext concurrency gate and the second + // call throws InvalidOperationException, which the tool catches and returns as text. + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-parallel-await", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Contains("Concurrent server-to-client requests are not supported", text); + Assert.True(result.IsError is not true); + } + else + { + // Non-MRTR: awaits go through regular JSON-RPC - concurrent calls work. + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-parallel-await", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.StartsWith("parallel-ok:", text); + Assert.True(result.IsError is not true); + } + } + + [McpServerTool(Name = "mrtr-elicit")] + private static string MrtrElicit(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("user_input", out var response)) + { + return $"elicit-ok:{response.Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["user_input"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Please confirm", + RequestedSchema = new() + }) + }, + requestState: "elicit-state"); + } + + [Fact] + public async Task Mrtr_Roots_CompletesViaMrtr() + { + var messageTracker = ConfigureServer( + [McpServerTool(Name = "mrtr-roots")] (RequestContext context) => + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("roots", out var response)) + { + var roots = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots; + return $"roots-ok:{string.Join(",", roots?.Select(r => r.Uri) ?? [])}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "roots-state"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectExperimentalAsync(); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-roots", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("roots-ok:file:///project,file:///data", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrUsed(); + } + + [McpServerTool(Name = "mrtr-multi")] + private static string MrtrMulti(RequestContext context) + { + var requestState = context.Params!.RequestState; + var inputResponses = context.Params!.InputResponses; + + if (requestState == "round-2" && inputResponses is not null) + { + var greeting = inputResponses["greeting"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action; + return $"multi-done:greeting={greeting}"; + } + + if (requestState == "round-1" && inputResponses is not null) + { + var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["greeting"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = $"How should I greet {name}?", + RequestedSchema = new() + }) + }, + requestState: "round-2"); + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new() + }) + }, + requestState: "round-1"); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Mrtr_MultiRoundTrip_Completes(bool experimentalClient) + { + var messageTracker = ConfigureServer(MrtrMulti); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + // Configure client - experimental or default based on parameter. + Action configureClient = experimentalClient + ? options => { ConfigureMrtrHandlers(options); options.ProtocolVersion = "DRAFT-2026-v1"; } + : ConfigureMrtrHandlers; + await using var client = await ConnectAsync(configureClient: configureClient); + + if (!experimentalClient && Stateless) + { + // Stateless without MRTR: InputRequiredException can't be resolved + // (no MRTR negotiated and no stateful backcompat path). + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-multi", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + Assert.Equal(McpErrorCode.InternalError, ex.ErrorCode); + return; + } + + var result = await client.CallToolAsync("mrtr-multi", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("multi-done:greeting=accept", text); + Assert.True(result.IsError is not true); + + if (experimentalClient) + { + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + messageTracker.AssertMrtrUsed(); + } + else + { + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + messageTracker.AssertMrtrNotUsed(); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Mrtr_IsMrtrSupported(bool experimentalClient) + { + ConfigureServer([McpServerTool(Name = "mrtr-check")] (McpServer server) => server.IsMrtrSupported.ToString()); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + // Configure client - experimental or default based on parameter. + Action configureClient = experimentalClient + ? options => { ConfigureMrtrHandlers(options); options.ProtocolVersion = "DRAFT-2026-v1"; } + : ConfigureMrtrHandlers; + await using var client = await ConnectAsync(configureClient: configureClient); + Assert.Equal(experimentalClient ? "DRAFT-2026-v1" : "2025-11-25", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-check", + cancellationToken: TestContext.Current.CancellationToken); + + // IsMrtrSupported is false only when stateless AND client didn't negotiate MRTR + // (no backcompat path available). All other combos have MRTR or backcompat support. + var expected = Stateless && !experimentalClient ? "False" : "True"; + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal(expected, text); + } + + [McpServerTool(Name = "mrtr-concurrent-three")] + private static string MrtrConcurrentThree(RequestContext context) + { + if (context.Params!.InputResponses is { Count: 3 } responses && + responses.ContainsKey("elicit") && + responses.ContainsKey("sample") && + responses.ContainsKey("roots")) + { + var elicitAction = responses["elicit"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action; + var sampleText = responses["sample"].Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)? + .Content.OfType().FirstOrDefault()?.Text; + var rootUris = string.Join(",", + responses["roots"].Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots.Select(r => r.Uri) ?? []); + return $"all-ok:elicit={elicitAction},sample={sampleText},roots={rootUris}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["elicit"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Confirm action", + RequestedSchema = new() + }), + ["sample"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "Generate summary" }] + }], + MaxTokens = 50 + }), + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "concurrent-state"); + } + + [Fact] + public async Task Mrtr_ConcurrentThreeInputs_ResolvedSimultaneously() + { + var messageTracker = ConfigureServer(MrtrConcurrentThree); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + var elicitCalled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var samplingCalled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var rootsCalled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await using var client = await ConnectAsync(configureClient: options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + options.Handlers.ElicitationHandler = async (request, ct) => + { + elicitCalled.TrySetResult(); + await Task.WhenAll(samplingCalled.Task.WaitAsync(ct), rootsCalled.Task.WaitAsync(ct)); + return new ElicitResult { Action = "accept" }; + }; + options.Handlers.SamplingHandler = async (request, progress, ct) => + { + samplingCalled.TrySetResult(); + await Task.WhenAll(elicitCalled.Task.WaitAsync(ct), rootsCalled.Task.WaitAsync(ct)); + return new CreateMessageResult + { + Content = [new TextContentBlock { Text = "AI-summary" }], + Model = "test-model" + }; + }; + options.Handlers.RootsHandler = async (request, ct) => + { + rootsCalled.TrySetResult(); + await Task.WhenAll(elicitCalled.Task.WaitAsync(ct), samplingCalled.Task.WaitAsync(ct)); + return new ListRootsResult + { + Roots = [new Root { Uri = "file:///workspace", Name = "Workspace" }] + }; + }; + }); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-concurrent-three", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("all-ok:elicit=accept,sample=AI-summary,roots=file:///workspace", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task Mrtr_LoadShedding_RequestStateOnly_CompletesViaMrtr() + { + var messageTracker = ConfigureServer( + [McpServerTool(Name = "mrtr-loadshed")] (RequestContext context) => + { + if (context.Params!.RequestState is { } state) + { + return $"resumed:{state}"; + } + + // requestState-only InputRequiredException (no inputRequests) + throw new InputRequiredException(requestState: "deferred-work"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectExperimentalAsync(); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-loadshed", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("resumed:deferred-work", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task Mrtr_Backcompat_Roots_ResolvedViaLegacyJsonRpc() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + var messageTracker = ConfigureServer( + [McpServerTool(Name = "mrtr-roots-backcompat")] (RequestContext context) => + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("roots", out var response)) + { + var roots = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots; + return $"roots-ok:{roots?.FirstOrDefault()?.Name}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "roots-state"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectDefaultAsync(); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-roots-backcompat", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("roots-ok:Project", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrNotUsed(); + } + + [Fact] + public async Task Mrtr_Backcompat_MultipleInputRequests_ResolvedViaLegacyJsonRpc() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + var messageTracker = ConfigureServer( + [McpServerTool(Name = "mrtr-multi-input")] (RequestContext context) => + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("confirm", out var elicitResponse) && + responses.TryGetValue("summarize", out var sampleResponse)) + { + var action = elicitResponse.Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action; + var text = sampleResponse.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)?.Content.OfType().FirstOrDefault()?.Text; + return $"both:{action}:{text}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Please confirm", + RequestedSchema = new() + }), + ["summarize"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "Summarize" }] + }], + MaxTokens = 100 + }) + }, + requestState: "multi-input-state"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectDefaultAsync(); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-multi-input", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("both:accept:LLM:Summarize", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrNotUsed(); + } + + [Fact] + public async Task Mrtr_Backcompat_AlwaysIncomplete_FailsAfterMaxRetries() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + int elicitCallCount = 0; + + ConfigureServer( + [McpServerTool(Name = "mrtr-always-incomplete")] (RequestContext context) => + { + // Always throw - never complete + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Confirm again", + RequestedSchema = new() + }) + }, + requestState: "infinite"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectAsync(configureClient: options => + { + ConfigureMrtrHandlers(options); + var originalHandler = options.Handlers.ElicitationHandler!; + options.Handlers.ElicitationHandler = (request, ct) => + { + Interlocked.Increment(ref elicitCallCount); + return originalHandler(request, ct); + }; + }); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-always-incomplete", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.Contains("exceeded", ex.Message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("10", ex.Message); + Assert.Equal(10, elicitCallCount); + } + + [Fact] + public async Task Mrtr_Backcompat_EmptyInputRequests_FailsWithError() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + ConfigureServer( + [McpServerTool(Name = "mrtr-empty-inputs")] (RequestContext context) => + { + throw new InputRequiredException( + inputRequests: new Dictionary(), + requestState: "empty"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectDefaultAsync(); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-empty-inputs", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.Contains("without input requests", ex.Message, StringComparison.OrdinalIgnoreCase); + Assert.Equal(McpErrorCode.InternalError, ex.ErrorCode); + } + + [Fact] + public async Task Mrtr_Backcompat_ClientHandlerThrows_PropagatesError() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + + ConfigureServer(MrtrElicit); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectAsync(configureClient: options => + { + ConfigureMrtrHandlers(options); + options.Handlers.ElicitationHandler = (request, ct) => + { + throw new InvalidOperationException("Client-side elicitation failure"); + }; + }); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + // Handler exception propagates through the backcompat JSON-RPC round-trip. + // The original exception message gets wrapped in "Request failed (remote)" during backcompat. + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-elicit", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + Assert.Equal(McpErrorCode.InternalError, ex.ErrorCode); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 678b27022..b9b8381ca 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -14,7 +14,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public abstract class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +public abstract partial class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) { protected abstract bool UseStreamableHttp { get; } protected abstract bool Stateless { get; } @@ -27,9 +27,8 @@ protected virtual void ConfigureStateless(HttpServerTransportOptions options) protected async Task ConnectAsync( string? path = null, HttpClientTransportOptions? transportOptions = null, - McpClientOptions? clientOptions = null) + Action? configureClient = null) { - // Default behavior when no options are provided path ??= UseStreamableHttp ? "/" : "/sse"; await using var transport = new HttpClientTransport(transportOptions ?? new HttpClientTransportOptions @@ -38,6 +37,8 @@ protected async Task ConnectAsync( TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }, HttpClient, LoggerFactory); + var clientOptions = new McpClientOptions(); + configureClient?.Invoke(clientOptions); return await McpClient.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); } @@ -156,29 +157,24 @@ public async Task Sampling_DoesNotCloseStreamPrematurely() await app.StartAsync(TestContext.Current.CancellationToken); var sampleCount = 0; - var clientOptions = new McpClientOptions() + await using var mcpClient = await ConnectAsync(configureClient: options => { - Handlers = new() + options.Handlers.SamplingHandler = async (parameters, _, _) => { - SamplingHandler = async (parameters, _, _) => - { - Assert.NotNull(parameters?.Messages); - var message = Assert.Single(parameters.Messages); - Assert.Equal(Role.User, message.Role); - Assert.Equal("Test prompt for sampling", Assert.IsType(Assert.Single(message.Content)).Text); + Assert.NotNull(parameters?.Messages); + var message = Assert.Single(parameters.Messages); + Assert.Equal(Role.User, message.Role); + Assert.Equal("Test prompt for sampling", Assert.IsType(Assert.Single(message.Content)).Text); - sampleCount++; - return new CreateMessageResult - { - Model = "test-model", - Role = Role.Assistant, - Content = [new TextContentBlock { Text = "Sampling response from client" }], - }; - } - } - }; - - await using var mcpClient = await ConnectAsync(clientOptions: clientOptions); + sampleCount++; + return new CreateMessageResult + { + Model = "test-model", + Role = Role.Assistant, + Content = [new TextContentBlock { Text = "Sampling response from client" }], + }; + }; + }); var result = await mcpClient.CallToolAsync("sampling-tool", new Dictionary { @@ -375,7 +371,11 @@ public async Task OutgoingFilter_SeesResponsesAndRequests() }, }; - await using var client = await ConnectAsync(clientOptions: clientOptions); + await using var client = await ConnectAsync(configureClient: opts => + { + opts.Capabilities = clientOptions.Capabilities; + opts.Handlers = clientOptions.Handlers; + }); await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); await client.CallToolAsync("echo_claims_principal", @@ -385,10 +385,12 @@ await client.CallToolAsync("sampling-tool", new Dictionary { ["prompt"] = "Hello" }, cancellationToken: TestContext.Current.CancellationToken); - Assert.Contains("initialize-response", observedMessageTypes); - Assert.Contains("tools-list-response", observedMessageTypes); - Assert.Contains("tool-call-response", observedMessageTypes); - Assert.Contains($"request:{RequestMethods.SamplingCreateMessage}", observedMessageTypes); + // Exact counts catch regressions where the outgoing filter pipeline gets applied more than once + // per outbound message (e.g., SendRequestAsync double-wrapping SendToRelatedTransportAsync). + Assert.Equal(1, observedMessageTypes.Count(m => m == "initialize-response")); + Assert.Equal(1, observedMessageTypes.Count(m => m == "tools-list-response")); + Assert.Equal(2, observedMessageTypes.Count(m => m == "tool-call-response")); // one per CallToolAsync + Assert.Equal(2, observedMessageTypes.Count(m => m == $"request:{RequestMethods.SamplingCreateMessage}")); // sampling-tool makes two SampleAsync calls } [Fact] @@ -496,6 +498,7 @@ public async Task OutgoingFilter_CanSendAdditionalMessages() Assert.Equal("injected", extraMessage); } + private ClaimsPrincipal CreateUser(string name) => new(new ClaimsIdentity( [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)], @@ -566,4 +569,5 @@ public static async Task LongRunningOperation( return $"Operation completed after {durationMs}ms"; } } + } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MrtrProtocolTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MrtrProtocolTests.cs new file mode 100644 index 000000000..6be82aec0 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MrtrProtocolTests.cs @@ -0,0 +1,469 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Net; +using System.Net.ServerSentEvents; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Protocol-level tests for Multi Round-Trip Requests (MRTR). +/// These tests send raw JSON-RPC requests via HTTP and verify protocol-level behavior +/// including InputRequiredResult structure, retry with inputResponses, and error handling. +/// +public class MrtrProtocolTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +{ + private WebApplication? _app; + + private async Task StartAsync() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation + { + Name = nameof(MrtrProtocolTests), + Version = "1", + }; + options.ProtocolVersion = "DRAFT-2026-v1"; + }).WithTools([ + McpServerTool.Create( + async (string message, McpServer server, CancellationToken ct) => + { + var result = await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = new() + }, ct); + + return $"{result.Action}:{result.Content?.FirstOrDefault().Value}"; + }, + new McpServerToolCreateOptions + { + Name = "elicit-tool", + Description = "Elicits from client" + }), + McpServerTool.Create( + static string (McpServer _) => throw new McpProtocolException("Tool validation failed", McpErrorCode.InvalidParams), + new McpServerToolCreateOptions + { + Name = "throwing-tool", + Description = "A tool that throws immediately" + }), + ]).WithHttpTransport(); + + _app = Builder.Build(); + _app.MapMcp(); + await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + } + + public async ValueTask DisposeAsync() + { + if (_app is not null) + { + await _app.DisposeAsync(); + } + base.Dispose(); + } + + [Fact] + public async Task ToolThatThrows_ReturnsJsonRpcError_NotIncompleteResult() + { + await StartAsync(); + await InitializeWithMrtrAsync(); + + var response = await PostJsonRpcAsync(CallTool("throwing-tool")); + + // Should be a JSON-RPC error, not an InputRequiredResult + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var sseData = Assert.Single(await ReadSseAsync(response.Content).ToListAsync(TestContext.Current.CancellationToken)); + var message = JsonSerializer.Deserialize(sseData, McpJsonUtilities.DefaultOptions); + var error = Assert.IsType(message); + Assert.Equal((int)McpErrorCode.InvalidParams, error.Error.Code); + Assert.Contains("Tool validation failed", error.Error.Message); + } + + [Fact] + public async Task RetryWithInvalidRequestState_ReturnsJsonRpcError() + { + await StartAsync(); + await InitializeWithMrtrAsync(); + + // Send a retry with a requestState that doesn't match any active continuation + var retryParams = new JsonObject + { + ["name"] = "elicit-tool", + ["arguments"] = new JsonObject { ["message"] = "test" }, + ["inputResponses"] = new JsonObject { ["key1"] = new JsonObject { ["action"] = "confirm" } }, + ["requestState"] = "nonexistent-state-id" + }; + + var response = await PostJsonRpcAsync(Request("tools/call", retryParams.ToJsonString())); + + // Read as a generic JsonRpcMessage to check if it's an error + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + var sseData = Assert.Single(await ReadSseAsync(response.Content).ToListAsync(TestContext.Current.CancellationToken)); + var message = JsonSerializer.Deserialize(sseData, McpJsonUtilities.DefaultOptions); + + // Invalid requestState should result in a fresh tool invocation + // (the tool will return InputRequiredResult since it calls ElicitAsync) + // or an error, depending on the implementation. + // In our implementation, unrecognized requestState triggers a new invocation. + Assert.True( + message is JsonRpcResponse or JsonRpcError, + $"Expected JsonRpcResponse or JsonRpcError, got {message?.GetType().Name}"); + } + + [Fact] + public async Task SessionDelete_CancelsPendingMrtrContinuation() + { + await StartAsync(); + await InitializeWithMrtrAsync(); + + // 1. Call a tool that suspends at ElicitAsync (implicit MRTR path). + var response = await PostJsonRpcAsync(CallTool("elicit-tool", """{"message":"Please confirm"}""")); + var rpcResponse = await AssertSingleSseResponseAsync(response); + + // Verify we got an InputRequiredResult (handler is now suspended, continuation stored). + var resultObj = Assert.IsType(rpcResponse.Result); + Assert.Equal("input_required", resultObj["resultType"]?.GetValue()); + var requestState = resultObj["requestState"]!.GetValue(); + Assert.False(string.IsNullOrEmpty(requestState)); + + // 2. DELETE the session while the handler is suspended. + using var deleteResponse = await HttpClient.DeleteAsync("", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, deleteResponse.StatusCode); + + // Poll for the async cancellation to propagate through the handler task. + // Under thread pool starvation, this can take significantly longer than 100ms. + var deadline = DateTime.UtcNow.AddSeconds(30); + while (true) + { + if (MockLoggerProvider.LogMessages.Any(m => m.Message.Contains("pending MRTR continuation")) + || DateTime.UtcNow >= deadline) + { + break; + } + + await Task.Delay(100, TestContext.Current.CancellationToken); + } + + // 3. Verify that the MRTR cancellation was logged at Debug level. + var mrtrCancelledLog = MockLoggerProvider.LogMessages + .Where(m => m.Message.Contains("pending MRTR continuation")) + .ToList(); + Assert.Single(mrtrCancelledLog); + Assert.Equal(LogLevel.Debug, mrtrCancelledLog[0].LogLevel); + Assert.Contains("1", mrtrCancelledLog[0].Message); + + // 4. Verify no error-level log was emitted for the cancellation. + // The handler's OperationCanceledException should be silently observed, not logged as an error. + var errorLogs = MockLoggerProvider.LogMessages + .Where(m => m.LogLevel >= LogLevel.Error && m.Message.Contains("elicit")) + .ToList(); + Assert.Empty(errorLogs); + } + + [Fact] + public async Task SessionDelete_RetryAfterDelete_ReturnsSessionNotFound() + { + await StartAsync(); + await InitializeWithMrtrAsync(); + + // 1. Call a tool that suspends at ElicitAsync. + var response = await PostJsonRpcAsync(CallTool("elicit-tool", """{"message":"Please confirm"}""")); + var rpcResponse = await AssertSingleSseResponseAsync(response); + + var resultObj = Assert.IsType(rpcResponse.Result); + var requestState = resultObj["requestState"]!.GetValue(); + var inputRequests = resultObj["inputRequests"]!.AsObject(); + var inputKey = inputRequests.First().Key; + + // 2. DELETE the session. + using var deleteResponse = await HttpClient.DeleteAsync("", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, deleteResponse.StatusCode); + + // 3. Attempt to retry with the old requestState - session is gone. + var inputResponse = InputResponse.FromElicitResult(new ElicitResult { Action = "accept" }); + var retryParams = new JsonObject + { + ["name"] = "elicit-tool", + ["arguments"] = new JsonObject { ["message"] = "Please confirm" }, + ["requestState"] = requestState, + ["inputResponses"] = new JsonObject + { + [inputKey] = JsonSerializer.SerializeToNode(inputResponse, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InputResponse))) + }, + }; + + using var retryResponse = await PostJsonRpcAsync(Request("tools/call", retryParams.ToJsonString())); + + // The session was deleted, so we should get a 404 with a JSON-RPC error. + Assert.Equal(HttpStatusCode.NotFound, retryResponse.StatusCode); + Assert.Equal("application/json", retryResponse.Content.Headers.ContentType?.MediaType); + } + + /// + /// Regression test for a CI hang where the server-side MRTR backcompat resolver routed its + /// outgoing roots/list request through the session-level transport, which silently + /// dropped the message when the client's GET stream had not been established yet. The + /// outgoing request must instead go through the POST's response stream (the request's + /// ) so it + /// reaches the client without depending on the GET stream at all. + /// + /// This test deliberately never opens a GET stream - it only POSTs the initialize, the + /// initialized notification, the tools/call, and the roots/list response. If the + /// server falls back to _transport.SendMessageAsync, the test times out instead of + /// reading the expected roots/list SSE event off the tools/call POST response. + /// + [Fact] + public async Task BackcompatResolver_SendsServerRequestOverPostStream_WithoutGetStream() + { + // Configure a server that does NOT pin DRAFT-2026-v1 so it can negotiate the current + // protocol with a legacy client. The backcompat resolver path only runs when the + // negotiated version is not DRAFT-2026-v1. + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation + { + Name = nameof(MrtrProtocolTests), + Version = "1", + }; + }).WithTools([ + McpServerTool.Create( + static string (RequestContext context) => + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("roots", out var response)) + { + var roots = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots; + return $"roots-ok:{roots?.FirstOrDefault()?.Name}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "roots-state"); + }, + new McpServerToolCreateOptions + { + Name = "backcompat-roots-tool", + Description = "Throws InputRequiredException so the server's backcompat resolver issues a roots/list", + }), + ]).WithHttpTransport(); + + _app = Builder.Build(); + _app.MapMcp(); + await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + + // Initialize with the current (non-draft) protocol so the server's backcompat resolver runs. + var initJson = """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{"roots":{}},"clientInfo":{"name":"BackcompatTestClient","version":"1.0.0"}}} + """; + + string sessionId; + using (var initResponse = await PostJsonRpcAsync(initJson)) + { + var initRpcResponse = await AssertSingleSseResponseAsync(initResponse); + Assert.NotNull(initRpcResponse.Result); + Assert.Equal("2025-11-25", initRpcResponse.Result["protocolVersion"]?.GetValue()); + + sessionId = Assert.Single(initResponse.Headers.GetValues("mcp-session-id")); + } + + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); + HttpClient.DefaultRequestHeaders.Remove("MCP-Protocol-Version"); + HttpClient.DefaultRequestHeaders.Add("MCP-Protocol-Version", "2025-11-25"); + + // Send the initialized notification. + using (var initializedResponse = await PostJsonRpcAsync( + """{"jsonrpc":"2.0","method":"notifications/initialized"}""")) + { + Assert.True(initializedResponse.IsSuccessStatusCode); + } + + _lastRequestId = 1; + + // POST the tools/call and start reading the response SSE stream. We deliberately do NOT + // open a GET stream - the server-to-client roots/list must be delivered on this POST's + // response. Use HttpCompletionOption.ResponseHeadersRead so the POST returns as soon as + // the response headers arrive instead of waiting for the SSE stream to close. + var callRequest = new HttpRequestMessage(HttpMethod.Post, (string?)null) + { + Content = JsonContent(CallTool("backcompat-roots-tool")), + }; + callRequest.Content.Headers.Add("Mcp-Method", "tools/call"); + callRequest.Content.Headers.Add("Mcp-Name", "backcompat-roots-tool"); + + using var callResponse = await HttpClient.SendAsync( + callRequest, + HttpCompletionOption.ResponseHeadersRead, + TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, callResponse.StatusCode); + Assert.Equal("text/event-stream", callResponse.Content.Headers.ContentType?.MediaType); + + var sseEvents = ReadSseAsync(callResponse.Content) + .GetAsyncEnumerator(TestContext.Current.CancellationToken); + + try + { + // First SSE event on this POST should be the server-initiated roots/list request. + Assert.True(await sseEvents.MoveNextAsync(), + "Server did not send a roots/list request on the tools/call POST response stream. " + + "If this hangs/times out, the MRTR backcompat resolver is routing the outgoing request " + + "through the session-level transport instead of the POST's RelatedTransport."); + + var rootsRequestNode = JsonNode.Parse(sseEvents.Current) as JsonObject; + Assert.NotNull(rootsRequestNode); + Assert.Equal("roots/list", rootsRequestNode["method"]?.GetValue()); + var rootsRequestId = rootsRequestNode["id"]; + Assert.NotNull(rootsRequestId); + + // POST the roots/list response on a separate connection. The server's pending + // RequestRootsAsync await will complete and the backcompat resolver will retry the tool. + var rootsIdLiteral = rootsRequestId.ToJsonString(); + var rootsResponseJson = + "{\"jsonrpc\":\"2.0\",\"id\":" + rootsIdLiteral + + ",\"result\":{\"roots\":[{\"uri\":\"file:///workspace\",\"name\":\"Workspace\"}]}}"; + using (var rootsResponseHttp = await PostJsonRpcAsync(rootsResponseJson)) + { + Assert.True(rootsResponseHttp.IsSuccessStatusCode); + } + + // Next SSE event on the original POST should be the final tools/call response. + Assert.True(await sseEvents.MoveNextAsync(), "Server did not return the final tools/call response."); + var finalResponse = JsonSerializer.Deserialize(sseEvents.Current, GetJsonTypeInfo()); + Assert.NotNull(finalResponse); + Assert.NotNull(finalResponse.Result); + + var content = finalResponse.Result["content"]?.AsArray(); + Assert.NotNull(content); + var firstContent = Assert.Single(content); + Assert.Equal("roots-ok:Workspace", firstContent?["text"]?.GetValue()); + } + finally + { + await sseEvents.DisposeAsync(); + } + } + + // --- Helpers --- + + private static StringContent JsonContent(string json) => new(json, Encoding.UTF8, "application/json"); + private static JsonTypeInfo GetJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + + private static async IAsyncEnumerable ReadSseAsync(HttpContent responseContent) + { + var responseStream = await responseContent.ReadAsStreamAsync(TestContext.Current.CancellationToken); + await foreach (var sseItem in SseParser.Create(responseStream).EnumerateAsync(TestContext.Current.CancellationToken)) + { + Assert.Equal("message", sseItem.EventType); + yield return sseItem.Data; + } + } + + private static async Task AssertSingleSseResponseAsync(HttpResponseMessage response) + { + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("text/event-stream", response.Content.Headers.ContentType?.MediaType); + + var sseItem = Assert.Single(await ReadSseAsync(response.Content).ToListAsync(TestContext.Current.CancellationToken)); + var jsonRpcResponse = JsonSerializer.Deserialize(sseItem, GetJsonTypeInfo()); + + Assert.NotNull(jsonRpcResponse); + return jsonRpcResponse; + } + + private Task PostJsonRpcAsync(string json) + { + var content = JsonContent(json); + + // DRAFT-2026-v1 requires Mcp-Method and (for tools/call) Mcp-Name headers per SEP-2243. + // Parse the body to derive them and attach to this request only. + var bodyNode = JsonNode.Parse(json); + if (bodyNode is JsonObject obj) + { + if (obj["method"]?.GetValue() is { } method) + { + content.Headers.Add("Mcp-Method", method); + + if (obj["params"] is JsonObject paramsObj) + { + string? mcpName = method switch + { + "tools/call" or "prompts/get" => paramsObj["name"]?.GetValue(), + "resources/read" => paramsObj["uri"]?.GetValue(), + _ => null, + }; + if (mcpName is not null) + { + content.Headers.Add("Mcp-Name", mcpName); + } + } + } + } + + return HttpClient.PostAsync("", content, TestContext.Current.CancellationToken); + } + + private long _lastRequestId = 1; + + private string Request(string method, string parameters = "{}") + { + var id = Interlocked.Increment(ref _lastRequestId); + return $$""" + {"jsonrpc":"2.0","id":{{id}},"method":"{{method}}","params":{{parameters}}} + """; + } + + private string CallTool(string toolName, string arguments = "{}") => + Request("tools/call", $$""" + {"name":"{{toolName}}","arguments":{{arguments}}} + """); + + /// + /// Initialize a session requesting the experimental protocol version that enables MRTR. + /// + private async Task InitializeWithMrtrAsync() + { + var initJson = """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"DRAFT-2026-v1","capabilities":{"sampling":{},"elicitation":{},"roots":{}},"clientInfo":{"name":"MrtrTestClient","version":"1.0.0"}}} + """; + + using var response = await PostJsonRpcAsync(initJson); + var rpcResponse = await AssertSingleSseResponseAsync(response); + Assert.NotNull(rpcResponse.Result); + + // Verify the server negotiated to the experimental version + var protocolVersion = rpcResponse.Result["protocolVersion"]?.GetValue(); + Assert.Equal("DRAFT-2026-v1", protocolVersion); + + var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); + + // Set the MCP-Protocol-Version header for subsequent requests + HttpClient.DefaultRequestHeaders.Remove("MCP-Protocol-Version"); + HttpClient.DefaultRequestHeaders.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + + // Reset request ID counter since initialize used ID 1 + _lastRequestId = 1; + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs index 98cc5971a..ea4187a95 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs @@ -159,6 +159,34 @@ public async Task RunConformanceTest_HttpCustomHeaderServerValidation() $"Conformance test failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); } + // SEP-2322 (Multi Round-Trip Requests / IncompleteResult) conformance scenarios. + // The csharp-sdk ConformanceServer surfaces the matching tools/prompts via + // ConformanceServer.Tools.IncompleteResultTools and ConformanceServer.Prompts.IncompleteResultPrompts. + // Each scenario uses the conformance harness's RawMcpSession, which negotiates DRAFT-2026-v1 + // so the csharp-sdk emits InputRequiredResult on the wire. These tests skip until the + // upstream conformance package ships with SEP-2322 scenarios + // (https://github.com/modelcontextprotocol/conformance/pull/188). + [Theory] + [InlineData("incomplete-result-basic-elicitation")] + [InlineData("incomplete-result-basic-sampling")] + [InlineData("incomplete-result-basic-list-roots")] + [InlineData("incomplete-result-request-state")] + [InlineData("incomplete-result-multiple-input-requests")] + [InlineData("incomplete-result-multi-round")] + [InlineData("incomplete-result-missing-input-response")] + [InlineData("incomplete-result-non-tool-request")] + public async Task RunMrtrConformanceTest(string scenario) + { + Assert.SkipWhen(!NodeHelpers.IsNodeInstalled(), "Node.js is not installed. Skipping conformance tests."); + Assert.SkipWhen(!NodeHelpers.HasMrtrScenarios(), "SEP-2322 MRTR conformance scenarios not yet available in the published @modelcontextprotocol/conformance package."); + + var result = await RunConformanceTestsAsync( + $"server --url {fixture.ServerUrl} --scenario {scenario}"); + + Assert.True(result.Success, + $"MRTR conformance test '{scenario}' failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); + } + private async Task<(bool Success, string Output, string Error)> RunConformanceTestsAsync(string arguments) { var startInfo = NodeHelpers.ConformanceTestStartInfo(arguments); diff --git a/tests/ModelContextProtocol.ConformanceServer/Program.cs b/tests/ModelContextProtocol.ConformanceServer/Program.cs index 017ec235f..f30d58a4d 100644 --- a/tests/ModelContextProtocol.ConformanceServer/Program.cs +++ b/tests/ModelContextProtocol.ConformanceServer/Program.cs @@ -31,6 +31,7 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide .WithHttpTransport() .WithDistributedCacheEventStreamStore() .WithTools() + .WithTools() .WithTools([ConformanceTools.CreateJsonSchema202012Tool()]) .WithRequestFilters(filters => filters.AddCallToolFilter(next => async (request, cancellationToken) => { @@ -47,6 +48,7 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide return result; })) .WithPrompts() + .WithPrompts() .WithResources() .WithSubscribeToResourcesHandler(async (ctx, ct) => { diff --git a/tests/ModelContextProtocol.ConformanceServer/Prompts/IncompleteResultPrompts.cs b/tests/ModelContextProtocol.ConformanceServer/Prompts/IncompleteResultPrompts.cs new file mode 100644 index 000000000..4dfe6dfb0 --- /dev/null +++ b/tests/ModelContextProtocol.ConformanceServer/Prompts/IncompleteResultPrompts.cs @@ -0,0 +1,68 @@ +#pragma warning disable MCPEXP001 // MRTR (SEP-2322) is experimental. + +using ModelContextProtocol; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Text.Json; + +namespace ConformanceServer.Prompts; + +/// +/// Prompt implementing the SEP-2322 D1 conformance scenario (incomplete-result-non-tool-request), +/// proving that prompts/get can return an just like +/// tools/call. +/// +[McpServerPromptType] +public sealed class IncompleteResultPrompts +{ + [McpServerPrompt(Name = "test_incomplete_result_prompt")] + [Description("SEP-2322 D1: prompts/get returns IncompleteResult until user_context is supplied.")] + public static GetPromptResult IncompleteResultPrompt(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("user_context", out var response)) + { + var elicit = response.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + var contextValue = TryReadString(elicit?.Content, "context") ?? "(unknown)"; + return new GetPromptResult + { + Description = "Prompt customized with elicited user context.", + Messages = + [ + new PromptMessage + { + Role = Role.User, + Content = new TextContentBlock { Text = $"Please continue using context: {contextValue}" }, + }, + ], + }; + } + + throw new InputRequiredException( + new Dictionary + { + ["user_context"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What context should the prompt use?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["context"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["context"], + }, + }), + }); + } + + private static string? TryReadString(IDictionary? content, string key) + { + if (content is null || !content.TryGetValue(key, out var element)) + { + return null; + } + return element.ValueKind == JsonValueKind.String ? element.GetString() : element.ToString(); + } +} diff --git a/tests/ModelContextProtocol.ConformanceServer/Tools/IncompleteResultTools.cs b/tests/ModelContextProtocol.ConformanceServer/Tools/IncompleteResultTools.cs new file mode 100644 index 000000000..caf91237a --- /dev/null +++ b/tests/ModelContextProtocol.ConformanceServer/Tools/IncompleteResultTools.cs @@ -0,0 +1,279 @@ +#pragma warning disable MCPEXP001 // MRTR (SEP-2322) is experimental. + +using ModelContextProtocol; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace ConformanceServer.Tools; + +/// +/// Tools implementing the SEP-2322 (MRTR / IncompleteResult) conformance scenarios from +/// incomplete-result.ts in the conformance test suite. All tools use the +/// API so they work both in stateful sessions with +/// MRTR-aware clients and in legacy-resolve mode (the SDK will translate exceptions to the +/// proper wire shape based on negotiated protocol version). +/// +[McpServerToolType] +public sealed class IncompleteResultTools +{ + // ──── A1: Basic Elicitation ───────────────────────────────────────────── + [McpServerTool(Name = "test_tool_with_elicitation")] + [Description("SEP-2322 A1: returns IncompleteResult with elicitation/create keyed 'user_name'.")] + public static CallToolResult ToolWithElicitation(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("user_name", out var response)) + { + var elicit = response.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + var name = TryReadString(elicit?.Content, "name") ?? "world"; + return TextResult($"Hello, {name}!"); + } + + throw new InputRequiredException( + new Dictionary + { + ["user_name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["name"], + }, + }), + }); + } + + // ──── A2: Basic Sampling ──────────────────────────────────────────────── + [McpServerTool(Name = "test_incomplete_result_sampling")] + [Description("SEP-2322 A2: returns IncompleteResult with sampling/createMessage keyed 'capital_question'.")] + public static CallToolResult ToolWithSampling(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("capital_question", out var response)) + { + var text = response.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)?.Content?.OfType().FirstOrDefault()?.Text ?? "(no text)"; + return TextResult($"Sampling said: {text}"); + } + + throw new InputRequiredException( + new Dictionary + { + ["capital_question"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "What is the capital of France?" }], + }, + ], + MaxTokens = 100, + }), + }); + } + + // ──── A3: Basic ListRoots ─────────────────────────────────────────────── + [McpServerTool(Name = "test_incomplete_result_list_roots")] + [Description("SEP-2322 A3: returns IncompleteResult with roots/list keyed 'client_roots'.")] + public static CallToolResult ToolWithListRoots(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("client_roots", out var response)) + { + var count = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots?.Count ?? 0; + return TextResult($"Got {count} root(s) from the client."); + } + + throw new InputRequiredException( + new Dictionary + { + ["client_roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()), + }); + } + + // ──── B1: requestState round-trip ─────────────────────────────────────── + private const string RequestStateToken = "mrtr-conformance-state-v1"; + + [McpServerTool(Name = "test_incomplete_result_request_state")] + [Description("SEP-2322 B1: round-trips a requestState string; R2 echoes 'state-ok' on success.")] + public static CallToolResult ToolWithRequestState(RequestContext context) + { + if (context.Params!.RequestState is { } state) + { + if (state != RequestStateToken) + { + return TextResult("state-mismatch: client echoed an unexpected requestState"); + } + return TextResult("state-ok: server received and validated the echoed requestState"); + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Please confirm", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["ok"] = new ElicitRequestParams.BooleanSchema(), + }, + Required = ["ok"], + }, + }), + }, + requestState: RequestStateToken); + } + + // ──── B2: Multiple input requests in one round ────────────────────────── + [McpServerTool(Name = "test_incomplete_result_multiple_inputs")] + [Description("SEP-2322 B2: returns 3 simultaneous inputRequests (elicit + sampling + roots) plus requestState.")] + public static CallToolResult ToolWithMultipleInputs(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && responses.Count >= 3) + { + return TextResult("multiple-inputs-ok: received elicit + sampling + roots responses"); + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["user_name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["name"], + }, + }), + ["greeting"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "Generate a greeting" }], + }, + ], + MaxTokens = 50, + }), + ["client_roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()), + }, + requestState: "multi-input-state"); + } + + // ──── B3: Multi-round (R1 -> incomplete, R2 -> incomplete (new state), R3 -> complete) ───── + [McpServerTool(Name = "test_incomplete_result_multi_round")] + [Description("SEP-2322 B3: three-round flow whose requestState changes between rounds.")] + public static CallToolResult ToolWithMultiRound(RequestContext context) + { + var state = context.Params!.RequestState; + if (state is null) + { + // Round 1: elicit name. + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["step1"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Step 1: What is your name?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["name"], + }, + }), + }, + requestState: "round-1"); + } + + if (state == "round-1") + { + // Round 2: elicit color (new state). + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["step2"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Step 2: What is your favorite color?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["color"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["color"], + }, + }), + }, + requestState: "round-2"); + } + + // Round 3: complete. + return TextResult("multi-round-ok"); + } + + // ──── C1: Missing/wrong inputResponses key - re-request rather than error ──── + [McpServerTool(Name = "test_incomplete_result_elicitation")] + [Description("SEP-2322 C1: re-requests missing inputResponses key instead of erroring.")] + public static CallToolResult ToolForMissingResponse(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("user_name", out var response)) + { + var elicit = response.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + var name = TryReadString(elicit?.Content, "name") ?? "world"; + return TextResult($"Hello, {name}!"); + } + + // Either no inputResponses or wrong key - re-request via a fresh InputRequiredResult + // (per SEP-2322 recommendation in scenario C1). + throw new InputRequiredException( + new Dictionary + { + ["user_name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["name"], + }, + }), + }); + } + + private static CallToolResult TextResult(string text) => new() + { + Content = [new TextContentBlock { Text = text }], + }; + + private static string? TryReadString(IDictionary? content, string key) + { + if (content is null || !content.TryGetValue(key, out var element)) + { + return null; + } + return element.ValueKind == JsonValueKind.String ? element.GetString() : element.ToString(); + } +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index 262efbd40..749ef51eb 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -587,6 +587,14 @@ public async Task ReturnsNegotiatedProtocolVersion(string? protocolVersion) Assert.Equal(protocolVersion ?? "2025-11-25", client.NegotiatedProtocolVersion); } + [Fact] + public async Task ReturnsNegotiatedProtocolVersion_WithExperimentalProtocol() + { + Server.ServerOptions.ProtocolVersion = "DRAFT-2026-v1"; + await using McpClient client = await CreateMcpClientForServer(new() { ProtocolVersion = "DRAFT-2026-v1" }); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + } + [Fact] public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionInvocation_ClientHandlesSamplingWithIChatClient() { diff --git a/tests/ModelContextProtocol.Tests/Client/MrtrIntegrationTests.cs b/tests/ModelContextProtocol.Tests/Client/MrtrIntegrationTests.cs new file mode 100644 index 000000000..90864d393 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/MrtrIntegrationTests.cs @@ -0,0 +1,570 @@ +#if !NET472 +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.IO.Pipelines; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Tests.Client; + +/// +/// Edge-case and guardrail tests for MRTR over in-memory pipe transport. These focus on +/// scenarios not easily covered by +/// which provides broad happy-path coverage across StreamableHttp, SSE, and Stateless transports. +/// +public class MrtrIntegrationTests : ClientServerTestBase +{ + private readonly ServerMessageTracker _messageTracker = new(); + + public MrtrIntegrationTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.AddLogging(builder => builder.SetMinimumLevel(LogLevel.Debug)); + services.Configure(options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + _messageTracker.AddFilters(options.Filters.Message); + }); + + mcpServerBuilder.WithTools([ + McpServerTool.Create( + async (string message, McpServer server, CancellationToken ct) => + { + var result = await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = new() + }, ct); + + return $"{result.Action}:{result.Content?.FirstOrDefault().Value}"; + }, + new McpServerToolCreateOptions + { + Name = "elicitation-tool", + Description = "A tool that requests elicitation from the client" + }), + McpServerTool.Create( + async (McpServer server) => + { + // Attempt to send a JsonRpcRequest via SendMessageAsync - should always throw + // since requests must go through SendRequestAsync for response correlation. + try + { + await server.SendMessageAsync(new JsonRpcRequest + { + Id = new RequestId(999), + Method = RequestMethods.ElicitationCreate, + Params = JsonSerializer.SerializeToNode(new ElicitRequestParams + { + Message = "Bypass attempt", + RequestedSchema = new() + }, McpJsonUtilities.DefaultOptions) + }); + return "NOT BLOCKED - expected InvalidOperationException"; + } + catch (InvalidOperationException ex) + { + return $"blocked:{ex.Message}"; + } + }, + new McpServerToolCreateOptions + { + Name = "sendmessage-bypass-tool", + Description = "A tool that attempts to bypass MRTR via SendMessageAsync" + }) + ]); + } + + [Fact] + public async Task ClientHandlerException_DuringMrtrInputResolution_SurfacesToCaller() + { + // When the CLIENT's elicitation handler throws during MRTR input resolution, + // the retry never reaches the server - the server's handler remains suspended + // on ElicitAsync(). The exception should surface to the CallToolAsync caller, + // and the server's orphaned handler should be cleaned up on disposal. + // This is a fundamental MRTR limitation: the client has no channel to communicate + // input resolution failures back to the server. + StartServer(); + + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + { + throw new InvalidOperationException("Client-side elicitation failure"); + }; + + await using var client = await CreateMcpClientForServer(clientOptions); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + // The client handler throws during input resolution, so the exception + // escapes ResolveInputRequestAsync and surfaces directly to the caller. + var ex = await Assert.ThrowsAsync(async () => + await client.CallToolAsync("elicitation-tool", + new Dictionary { ["message"] = "Will fail" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Client-side elicitation failure", ex.Message); + + // Dispose the server to trigger cleanup of the orphaned MRTR continuation. + // The server should cancel the handler suspended on ElicitAsync() and log + // the cancelled continuation at Debug level. + await Server.DisposeAsync(); + + Assert.Contains(MockLoggerProvider.LogMessages, m => + m.LogLevel == LogLevel.Debug && + m.Message.Contains("Cancelled") && + m.Message.Contains("MRTR continuation")); + } + + [Fact] + public async Task SendMessageAsync_WithJsonRpcRequest_ThrowsAlways() + { + // SendMessageAsync should throw InvalidOperationException if the message is a + // JsonRpcRequest, regardless of MRTR state. Use SendRequestAsync for requests. + StartServer(); + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }); + + await using var client = await CreateMcpClientForServer(clientOptions); + + var result = await client.CallToolAsync("sendmessage-bypass-tool", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.StartsWith("blocked:", text); + Assert.Contains("SendMessageAsync", text); + Assert.Contains("SendRequestAsync", text); + } + + [Fact] + public async Task LegacyRequestOnMrtrSession_LogsWarning() + { + // This test simulates a non-compliant server that negotiates MRTR + // but sends legacy elicitation/create JSON-RPC requests instead of + // using InputRequiredResult. The client should handle it but log a warning. + StartServer(); // Required for base class DisposeAsync cleanup + var clientToServer = new Pipe(); + var serverToClient = new Pipe(); + + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }); + clientOptions.Handlers.SamplingHandler = (request, progress, ct) => + new ValueTask(new CreateMessageResult + { + Content = [new TextContentBlock { Text = "sampled" }], + Model = "test-model" + }); + + // Start the client task - it will send initialize and block waiting for response + var clientTask = McpClient.CreateAsync( + new StreamClientTransport( + clientToServer.Writer.AsStream(), + serverToClient.Reader.AsStream(), + LoggerFactory), + clientOptions, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + + // Simulate server: read initialize request, respond with experimental version + var serverReader = new StreamReader(clientToServer.Reader.AsStream()); + var serverWriter = serverToClient.Writer.AsStream(); + + // Read the initialize request from client + var initLine = await serverReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(initLine); + var initRequest = JsonSerializer.Deserialize(initLine, McpJsonUtilities.DefaultOptions); + Assert.NotNull(initRequest); + Assert.Equal("initialize", initRequest.Method); + + // Respond with experimental protocol version (MRTR negotiated) + var initResponse = new JsonRpcResponse + { + Id = initRequest.Id, + Result = JsonSerializer.SerializeToNode(new InitializeResult + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new ServerCapabilities(), + ServerInfo = new Implementation { Name = "MockMrtrServer", Version = "1.0" } + }, McpJsonUtilities.DefaultOptions), + }; + await WriteJsonRpcAsync(serverWriter, initResponse); + + // Read the initialized notification from client + var initializedLine = await serverReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(initializedLine); + + // Client is now connected with MRTR negotiated + await using var client = await clientTask; + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + // Now simulate the non-compliant server sending a legacy elicitation/create request + var legacyRequest = new JsonRpcRequest + { + Id = new RequestId(42), + Method = RequestMethods.ElicitationCreate, + Params = JsonSerializer.SerializeToNode(new ElicitRequestParams + { + Message = "Legacy elicitation from non-compliant server", + RequestedSchema = new() + }, McpJsonUtilities.DefaultOptions), + }; + await WriteJsonRpcAsync(serverWriter, legacyRequest); + + // Read the client's response to the legacy request + var responseLine = await serverReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(responseLine); + var clientResponse = JsonSerializer.Deserialize(responseLine, McpJsonUtilities.DefaultOptions); + Assert.NotNull(clientResponse); + Assert.Equal(new RequestId(42), clientResponse.Id); + + // Verify the client handled the request (returned ElicitResult) + var elicitResult = JsonSerializer.Deserialize(clientResponse.Result, McpJsonUtilities.DefaultOptions); + Assert.NotNull(elicitResult); + Assert.Equal("accept", elicitResult.Action); + + // Verify the warning was logged + Assert.Contains(MockLoggerProvider.LogMessages, m => + m.LogLevel == LogLevel.Warning && + m.Message.Contains("elicitation/create") && + m.Message.Contains("MRTR")); + + // Clean up + clientToServer.Writer.Complete(); + serverToClient.Writer.Complete(); + } + + [Fact] + public async Task IncompleteResultOnNonMrtrSession_LogsWarning() + { + // This test simulates a non-compliant server that sends an InputRequiredResult + // to a client that did NOT negotiate MRTR. The client should still process it + // (resilience), but log a warning about the unexpected protocol behavior. + StartServer(); // Required for base class DisposeAsync cleanup + var clientToServer = new Pipe(); + var serverToClient = new Pipe(); + + // Client does NOT set DRAFT-2026-v1 - standard protocol only + var clientOptions = new McpClientOptions(); + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + ["confirmed"] = JsonDocument.Parse("\"yes\"").RootElement.Clone() + } + }); + + // Start the client task - it will send initialize and block waiting for response + var clientTask = McpClient.CreateAsync( + new StreamClientTransport( + clientToServer.Writer.AsStream(), + serverToClient.Reader.AsStream(), + LoggerFactory), + clientOptions, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + + var serverReader = new StreamReader(clientToServer.Reader.AsStream()); + var serverWriter = serverToClient.Writer.AsStream(); + + // Read the initialize request from client + var initLine = await serverReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(initLine); + var initRequest = JsonSerializer.Deserialize(initLine, McpJsonUtilities.DefaultOptions); + Assert.NotNull(initRequest); + Assert.Equal("initialize", initRequest.Method); + + // Respond with standard protocol version (no MRTR) + var initResponse = new JsonRpcResponse + { + Id = initRequest.Id, + Result = JsonSerializer.SerializeToNode(new InitializeResult + { + ProtocolVersion = "2025-03-26", + Capabilities = new ServerCapabilities { Tools = new() }, + ServerInfo = new Implementation { Name = "NonCompliantServer", Version = "1.0" } + }, McpJsonUtilities.DefaultOptions), + }; + await WriteJsonRpcAsync(serverWriter, initResponse); + + // Read the initialized notification from client + var initializedLine = await serverReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(initializedLine); + + // Client is now connected with standard protocol (no MRTR) + await using var client = await clientTask; + Assert.Equal("2025-03-26", client.NegotiatedProtocolVersion); + + // Start a background task to handle the client's tools/call request + var cancellationToken = TestContext.Current.CancellationToken; + var serverLoop = Task.Run(async () => + { + // Read tools/call request from client + var callLine = await serverReader.ReadLineAsync(cancellationToken); + Assert.NotNull(callLine); + var callRequest = JsonSerializer.Deserialize(callLine, McpJsonUtilities.DefaultOptions); + Assert.NotNull(callRequest); + Assert.Equal("tools/call", callRequest.Method); + + // Non-compliant server sends InputRequiredResult on standard protocol session! + var InputRequiredResult = new JsonObject + { + ["resultType"] = "input_required", + ["inputRequests"] = new JsonObject + { + ["confirm_1"] = JsonSerializer.SerializeToNode( + InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Unexpected elicitation from non-compliant server", + RequestedSchema = new() + }), McpJsonUtilities.DefaultOptions) + }, + ["requestState"] = "non-mrtr-state" + }; + + var incompleteResponse = new JsonRpcResponse + { + Id = callRequest.Id, + Result = InputRequiredResult, + }; + await WriteJsonRpcAsync(serverWriter, incompleteResponse); + + // Read the retry request with inputResponses from client + var retryLine = await serverReader.ReadLineAsync(cancellationToken); + Assert.NotNull(retryLine); + var retryRequest = JsonSerializer.Deserialize(retryLine, McpJsonUtilities.DefaultOptions); + Assert.NotNull(retryRequest); + Assert.Equal("tools/call", retryRequest.Method); + + // Verify the retry contains inputResponses and requestState + var retryParams = retryRequest.Params as JsonObject; + Assert.NotNull(retryParams); + Assert.NotNull(retryParams["inputResponses"]); + Assert.Equal("non-mrtr-state", retryParams["requestState"]?.GetValue()); + + // Now respond with a normal result + var normalResult = new JsonRpcResponse + { + Id = retryRequest.Id, + Result = JsonSerializer.SerializeToNode(new CallToolResult + { + Content = [new TextContentBlock { Text = "completed-without-mrtr" }] + }, McpJsonUtilities.DefaultOptions), + }; + await WriteJsonRpcAsync(serverWriter, normalResult); + }, cancellationToken); + + // Client calls the tool - the non-compliant server will send InputRequiredResult + var response = await client.SendRequestAsync( + new JsonRpcRequest + { + Method = "tools/call", + Params = JsonSerializer.SerializeToNode(new CallToolRequestParams + { + Name = "any-tool", + }, McpJsonUtilities.DefaultOptions) + }, + cancellationToken); + + await serverLoop; + + Assert.NotNull(response.Result); + var result = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.DefaultOptions); + Assert.NotNull(result); + var content = Assert.Single(result.Content); + Assert.Equal("completed-without-mrtr", Assert.IsType(content).Text); + + // Verify the warning was logged about InputRequiredResult on non-MRTR session + Assert.Contains(MockLoggerProvider.LogMessages, m => + m.LogLevel == LogLevel.Warning && + m.Message.Contains("InputRequiredResult") && + m.Message.Contains("did not negotiate MRTR")); + + // Clean up + clientToServer.Writer.Complete(); + serverToClient.Writer.Complete(); + } + + [Fact] + public async Task IncompleteResultRetry_OmittingRequestState_StripsStaleStateFromRetryParams() + { + // Regression test for #1458 review feedback: when the server returns InputRequiredResult + // with requestState on round 1 and then InputRequiredResult WITHOUT requestState on round 2, + // the client's third retry must NOT carry the stale round-1 requestState forward via the + // params deep clone. Without the fix, the third retry's params contain {"requestState": "round1-state"} + // even though the round-2 InputRequiredResult cleared it. + StartServer(); // base-class disposal hook + var clientToServer = new Pipe(); + var serverToClient = new Pipe(); + + var clientOptions = new McpClientOptions(); + clientOptions.Handlers.ElicitationHandler = (_, _) => + new ValueTask(new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + ["confirmed"] = JsonDocument.Parse("\"yes\"").RootElement.Clone() + } + }); + + var clientTask = McpClient.CreateAsync( + new StreamClientTransport( + clientToServer.Writer.AsStream(), + serverToClient.Reader.AsStream(), + LoggerFactory), + clientOptions, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + + var serverReader = new StreamReader(clientToServer.Reader.AsStream()); + var serverWriter = serverToClient.Writer.AsStream(); + + // Initialize handshake - negotiate DRAFT-2026-v1 so the client treats InputRequiredResult as MRTR. + var initLine = await serverReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(initLine); + var initRequest = JsonSerializer.Deserialize(initLine, McpJsonUtilities.DefaultOptions); + Assert.NotNull(initRequest); + Assert.Equal("initialize", initRequest.Method); + + var initResponse = new JsonRpcResponse + { + Id = initRequest.Id, + Result = JsonSerializer.SerializeToNode(new InitializeResult + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new ServerCapabilities { Tools = new() }, + ServerInfo = new Implementation { Name = "MrtrServer", Version = "1.0" } + }, McpJsonUtilities.DefaultOptions), + }; + await WriteJsonRpcAsync(serverWriter, initResponse); + + var initializedLine = await serverReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(initializedLine); + + await using var client = await clientTask; + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var cancellationToken = TestContext.Current.CancellationToken; + + // Capture the retry payloads sent by the client so we can inspect them after the call completes. + JsonObject? retry1Params = null; + JsonObject? retry2Params = null; + + var serverLoop = Task.Run(async () => + { + // --- Round 1: receive original tools/call, respond with InputRequiredResult + requestState="round1-state". + var call1Line = await serverReader.ReadLineAsync(cancellationToken); + Assert.NotNull(call1Line); + var call1 = JsonSerializer.Deserialize(call1Line, McpJsonUtilities.DefaultOptions); + Assert.NotNull(call1); + Assert.Equal("tools/call", call1.Method); + + var round1Result = new JsonObject + { + ["resultType"] = "input_required", + ["inputRequests"] = new JsonObject + { + ["q1"] = JsonSerializer.SerializeToNode( + InputRequest.ForElicitation(new ElicitRequestParams { Message = "round1", RequestedSchema = new() }), + McpJsonUtilities.DefaultOptions), + }, + ["requestState"] = "round1-state", + }; + await WriteJsonRpcAsync(serverWriter, new JsonRpcResponse { Id = call1.Id, Result = round1Result }); + + // --- Round 2: receive first retry (should include requestState="round1-state" + inputResponses). + var call2Line = await serverReader.ReadLineAsync(cancellationToken); + Assert.NotNull(call2Line); + var call2 = JsonSerializer.Deserialize(call2Line, McpJsonUtilities.DefaultOptions); + Assert.NotNull(call2); + retry1Params = call2.Params as JsonObject; + + // Respond with another InputRequiredResult - this time WITHOUT requestState - to force the + // client to clear any stale state on the next retry params clone. + var round2Result = new JsonObject + { + ["resultType"] = "input_required", + ["inputRequests"] = new JsonObject + { + ["q2"] = JsonSerializer.SerializeToNode( + InputRequest.ForElicitation(new ElicitRequestParams { Message = "round2", RequestedSchema = new() }), + McpJsonUtilities.DefaultOptions), + }, + // Intentionally NO "requestState" key. + }; + await WriteJsonRpcAsync(serverWriter, new JsonRpcResponse { Id = call2.Id, Result = round2Result }); + + // --- Round 3: receive second retry - assertion target. Must NOT contain "requestState". + var call3Line = await serverReader.ReadLineAsync(cancellationToken); + Assert.NotNull(call3Line); + var call3 = JsonSerializer.Deserialize(call3Line, McpJsonUtilities.DefaultOptions); + Assert.NotNull(call3); + retry2Params = call3.Params as JsonObject; + + // Final success response so the client's call completes cleanly. + await WriteJsonRpcAsync(serverWriter, new JsonRpcResponse + { + Id = call3.Id, + Result = JsonSerializer.SerializeToNode(new CallToolResult + { + Content = [new TextContentBlock { Text = "done" }] + }, McpJsonUtilities.DefaultOptions), + }); + }, cancellationToken); + + var response = await client.SendRequestAsync( + new JsonRpcRequest + { + Method = "tools/call", + Params = JsonSerializer.SerializeToNode(new CallToolRequestParams { Name = "any-tool" }, McpJsonUtilities.DefaultOptions), + }, + cancellationToken); + + await serverLoop; + + // Sanity check the final result reached us. + Assert.NotNull(response.Result); + var result = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.DefaultOptions); + Assert.NotNull(result); + Assert.Equal("done", Assert.IsType(Assert.Single(result.Content)).Text); + + // The first retry must carry requestState="round1-state". + Assert.NotNull(retry1Params); + Assert.NotNull(retry1Params!["inputResponses"]); + Assert.Equal("round1-state", retry1Params["requestState"]?.GetValue()); + + // The second retry must NOT carry a stale requestState. Pre-fix, the deep clone of the + // round-1 request kept "round1-state" in paramsObj because the client only OVERWROTE it + // when InputRequiredResult.RequestState was non-null. With the fix, it explicitly removes + // the key whenever the server's new InputRequiredResult clears it. + Assert.NotNull(retry2Params); + Assert.NotNull(retry2Params!["inputResponses"]); + Assert.False(retry2Params.ContainsKey("requestState"), + "Retry params must not carry a stale requestState from the previous round."); + + clientToServer.Writer.Complete(); + serverToClient.Writer.Complete(); + } + + private static async Task WriteJsonRpcAsync(Stream writer, JsonRpcMessage message) + { + var bytes = JsonSerializer.SerializeToUtf8Bytes(message, McpJsonUtilities.DefaultOptions); + await writer.WriteAsync(bytes, TestContext.Current.CancellationToken); + await writer.WriteAsync("\n"u8.ToArray(), TestContext.Current.CancellationToken); + await writer.FlushAsync(TestContext.Current.CancellationToken); + } +} + +#endif diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs index 171c6bead..a39d4896f 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs @@ -87,8 +87,14 @@ public async Task AddIncomingMessageFilter_Intercepts_Request_Messages() await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - // The message filter should intercept JsonRpcRequest messages - Assert.Contains("JsonRpcRequest", messageTypes); + // The message filter should intercept JsonRpcRequest messages. + // Use strict counts so a regression that invokes the filter pipeline more than once per + // incoming message (analogous to the SendRequestAsync double-wrap regression on the outgoing + // side) would fail this test instead of slipping through Assert.Contains. + // A single ListToolsAsync drives three server-bound messages: initialize (request), + // notifications/initialized (notification), and tools/list (request). + Assert.Equal(2, messageTypes.Count(m => m == nameof(JsonRpcRequest))); + Assert.Equal(1, messageTypes.Count(m => m == nameof(JsonRpcNotification))); } [Fact] @@ -142,6 +148,13 @@ public async Task AddIncomingMessageFilter_Multiple_Filters_Execute_In_Order() Assert.True(idx1Before < idx2Before); Assert.True(idx2Before < idx2After); Assert.True(idx2After < idx1After); + + // Verify each filter ran exactly once per incoming message (initialize + notifications/initialized + tools/list). + // Strict counts catch regressions where the incoming filter pipeline gets invoked more than once per message. + Assert.Equal(3, logMessages.Count(m => m == "MessageFilter1 before")); + Assert.Equal(3, logMessages.Count(m => m == "MessageFilter2 before")); + Assert.Equal(3, logMessages.Count(m => m == "MessageFilter2 after")); + Assert.Equal(3, logMessages.Count(m => m == "MessageFilter1 after")); } [Fact] @@ -372,15 +385,20 @@ public async Task AddOutgoingMessageFilter_Sees_Responses_Notifications_And_Requ await client.CallToolAsync("sampling-tool", new Dictionary { ["prompt"] = "Hello" }, cancellationToken: TestContext.Current.CancellationToken); + // Exact counts catch regressions where the outgoing filter pipeline gets applied more than once + // per outbound message (e.g., SendRequestAsync double-wrapping SendToRelatedTransportAsync). + Assert.Equal(1, observedMessages.Count(m => m == "initialize")); + Assert.Equal(2, observedMessages.Count(m => m == "progress")); // ProgressTool sends two NotifyProgressAsync calls + Assert.Equal(2, observedMessages.Count(m => m == "response")); // one tool-call response per CallToolAsync + Assert.Equal(1, observedMessages.Count(m => m == $"request:{RequestMethods.SamplingCreateMessage}")); + + // Preserve the original ordering intent: initialize first, then progress, then the final response. int initializeIndex = observedMessages.IndexOf("initialize"); int progressIndex = observedMessages.IndexOf("progress"); int responseIndex = observedMessages.LastIndexOf("response"); - int requestIndex = observedMessages.IndexOf($"request:{RequestMethods.SamplingCreateMessage}"); - Assert.True(initializeIndex >= 0); Assert.True(progressIndex > initializeIndex); Assert.True(responseIndex > progressIndex); - Assert.True(requestIndex >= 0); } [Fact] @@ -516,7 +534,7 @@ public async Task AddIncomingMessageFilter_SkipNext_DoesNotLogSendingResponse() McpServerBuilder .WithMessageFilters(filters => filters.AddIncomingFilter((next) => (context, cancellationToken) => { - // Skip processing tools/list requests — handler never runs, no response sent + // Skip processing tools/list requests - handler never runs, no response sent if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) { return Task.CompletedTask; @@ -552,7 +570,7 @@ public async Task AddIncomingMessageFilter_CallsNext_LogsSendingResponse() McpServerBuilder .WithMessageFilters(filters => filters.AddIncomingFilter((next) => (context, cancellationToken) => { - // Pass through — handler runs, response is sent + // Pass through - handler runs, response is sent return next(context, cancellationToken); })) .WithTools(); diff --git a/tests/ModelContextProtocol.Tests/Protocol/MrtrSerializationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/MrtrSerializationTests.cs new file mode 100644 index 000000000..e44f6527c --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/MrtrSerializationTests.cs @@ -0,0 +1,298 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Tests.Protocol; + +public static class MrtrSerializationTests +{ + [Fact] + public static void IncompleteResult_SerializationRoundTrip_PreservesAllProperties() + { + var original = new InputRequiredResult + { + InputRequests = new Dictionary + { + ["input_1"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new() + }), + ["input_2"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Hello" }] }], + MaxTokens = 100 + }) + }, + RequestState = "correlation-123", + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("input_required", deserialized.ResultType); + Assert.Equal("correlation-123", deserialized.RequestState); + Assert.NotNull(deserialized.InputRequests); + Assert.Equal(2, deserialized.InputRequests.Count); + Assert.True(deserialized.InputRequests.ContainsKey("input_1")); + Assert.True(deserialized.InputRequests.ContainsKey("input_2")); + } + + [Fact] + public static void IncompleteResult_HasResultTypeIncomplete() + { + var result = new InputRequiredResult(); + Assert.Equal("input_required", result.ResultType); + } + + [Fact] + public static void IncompleteResult_ResultType_AppearsInJson() + { + var result = new InputRequiredResult + { + RequestState = "abc", + }; + + string json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + Assert.NotNull(node); + Assert.Equal("input_required", (string?)node["resultType"]); + Assert.Equal("abc", (string?)node["requestState"]); + } + + [Fact] + public static void InputRequest_ForElicitation_SerializesCorrectly() + { + var inputRequest = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Enter name", + RequestedSchema = new() + }); + + string json = JsonSerializer.Serialize(inputRequest, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + Assert.NotNull(node); + Assert.Equal("elicitation/create", (string?)node["method"]); + Assert.NotNull(node["params"]); + Assert.Equal("Enter name", (string?)node["params"]!["message"]); + } + + [Fact] + public static void InputRequest_ForSampling_SerializesCorrectly() + { + var inputRequest = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Prompt" }] }], + MaxTokens = 50 + }); + + string json = JsonSerializer.Serialize(inputRequest, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + Assert.NotNull(node); + Assert.Equal("sampling/createMessage", (string?)node["method"]); + Assert.NotNull(node["params"]); + Assert.Equal(50, (int?)node["params"]!["maxTokens"]); + } + + [Fact] + public static void InputRequest_ForRootsList_SerializesCorrectly() + { + var inputRequest = InputRequest.ForRootsList(new ListRootsRequestParams()); + + string json = JsonSerializer.Serialize(inputRequest, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + Assert.NotNull(node); + Assert.Equal("roots/list", (string?)node["method"]); + } + + [Fact] + public static void InputRequest_Elicitation_RoundTrip() + { + var original = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "test message", + RequestedSchema = new() + }); + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("elicitation/create", deserialized.Method); + Assert.NotNull(deserialized.ElicitationParams); + Assert.Equal("test message", deserialized.ElicitationParams.Message); + } + + [Fact] + public static void InputRequest_Sampling_RoundTrip() + { + var original = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Hello" }] }], + MaxTokens = 200 + }); + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("sampling/createMessage", deserialized.Method); + Assert.NotNull(deserialized.SamplingParams); + Assert.Equal(200, deserialized.SamplingParams.MaxTokens); + } + + [Fact] + public static void InputRequest_RootsList_RoundTrip() + { + var original = InputRequest.ForRootsList(new ListRootsRequestParams()); + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("roots/list", deserialized.Method); + Assert.NotNull(deserialized.RootsParams); + } + + [Fact] + public static void InputResponse_FromSamplingResult_RoundTrip() + { + var samplingResult = new CreateMessageResult + { + Content = [new TextContentBlock { Text = "Response text" }], + Model = "test-model" + }; + + var inputResponse = InputResponse.FromSamplingResult(samplingResult); + + // Serialize → deserialize + string json = JsonSerializer.Serialize(inputResponse, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + var sampling = deserialized.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo); + Assert.NotNull(sampling); + Assert.Equal("test-model", sampling.Model); + } + + [Fact] + public static void InputResponse_FromElicitResult_RoundTrip() + { + var elicitResult = new ElicitResult + { + Action = "confirm", + Content = new Dictionary + { + ["key"] = JsonDocument.Parse("\"value\"").RootElement.Clone() + } + }; + + var inputResponse = InputResponse.FromElicitResult(elicitResult); + + string json = JsonSerializer.Serialize(inputResponse, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + var elicit = deserialized.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + Assert.NotNull(elicit); + Assert.Equal("confirm", elicit.Action); + } + + [Fact] + public static void InputResponse_FromRootsResult_RoundTrip() + { + var rootsResult = new ListRootsResult + { + Roots = [new Root { Uri = "file:///test", Name = "Test" }] + }; + + var inputResponse = InputResponse.FromRootsResult(rootsResult); + + string json = JsonSerializer.Serialize(inputResponse, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + var roots = deserialized.Deserialize(InputResponse.ListRootsResultJsonTypeInfo); + Assert.NotNull(roots); + Assert.Single(roots.Roots); + Assert.Equal("file:///test", roots.Roots[0].Uri); + } + + [Fact] + public static void InputRequestDictionary_SerializationRoundTrip() + { + IDictionary requests = new Dictionary + { + ["a"] = InputRequest.ForElicitation(new ElicitRequestParams { Message = "q1", RequestedSchema = new() }), + ["b"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "q2" }] }], + MaxTokens = 50 + }), + }; + + string json = JsonSerializer.Serialize(requests, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize>(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(2, deserialized.Count); + Assert.Equal("elicitation/create", deserialized["a"].Method); + Assert.Equal("sampling/createMessage", deserialized["b"].Method); + } + + [Fact] + public static void InputResponseDictionary_SerializationRoundTrip() + { + IDictionary responses = new Dictionary + { + ["a"] = InputResponse.FromElicitResult(new ElicitResult { Action = "confirm" }), + ["b"] = InputResponse.FromSamplingResult(new CreateMessageResult + { + Content = [new TextContentBlock { Text = "AI" }], + Model = "m1" + }), + }; + + string json = JsonSerializer.Serialize(responses, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize>(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(2, deserialized.Count); + } + + [Fact] + public static void Result_ResultType_DefaultsToNull() + { + var result = new CallToolResult + { + Content = [new TextContentBlock { Text = "test" }] + }; + + string json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + // result_type should not appear for normal results + Assert.Null(node?["resultType"]); + } + + [Fact] + public static void RequestParams_InputResponses_NotSerializedByDefault() + { + var callParams = new CallToolRequestParams + { + Name = "test-tool", + }; + + string json = JsonSerializer.Serialize(callParams, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + // inputResponses and requestState should not appear when null + Assert.Null(node?["inputResponses"]); + Assert.Null(node?["requestState"]); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/DraftProtocolBackcompatTests.cs b/tests/ModelContextProtocol.Tests/Server/DraftProtocolBackcompatTests.cs new file mode 100644 index 000000000..662ffdb27 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/DraftProtocolBackcompatTests.cs @@ -0,0 +1,151 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Verifies that the server-to-client request methods (, +/// , +/// ) keep working when the negotiated protocol revision is +/// DRAFT-2026-v1 on a stateful session - for example, stdio. +/// +/// +/// Under DRAFT-2026-v1 the spec removes the corresponding server-to-client request methods, but +/// the SDK only fails fast in stateless mode (where the existing ThrowIf*Unsupported guards already +/// throw "X is not supported in stateless mode" because is +/// ). Stdio is implicitly stateful - one per process - so the +/// legacy elicitation/create / sampling/createMessage / roots/list flow still works. +/// A future PR is expected to force DRAFT-2026-v1 Streamable HTTP servers to stateless mode, at which +/// point those configurations will start throwing through the existing stateless guard. +/// +public sealed class DraftProtocolBackcompatTests : ClientServerTestBase +{ + public DraftProtocolBackcompatTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.Configure(options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + }); + + mcpServerBuilder.WithTools([ + McpServerTool.Create(ElicitToolAsync, new() { Name = "elicit-tool" }), + McpServerTool.Create(SampleToolAsync, new() { Name = "sample-tool" }), + McpServerTool.Create(RootsToolAsync, new() { Name = "roots-tool" }), + ]); + } + + [Fact] + public async Task ElicitAsync_OnStatefulDraftSession_ResolvesViaLegacyRequest() + { + StartServer(); + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new ClientCapabilities + { + Elicitation = new ElicitationCapability(), + }, + Handlers = new McpClientHandlers + { + ElicitationHandler = (_, _) => new ValueTask(new ElicitResult { Action = "accept" }), + }, + }); + + var result = await client.CallToolAsync("elicit-tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("elicit-ok:accept", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task SampleAsync_OnStatefulDraftSession_ResolvesViaLegacyRequest() + { + StartServer(); + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new ClientCapabilities + { + Sampling = new SamplingCapability(), + }, + Handlers = new McpClientHandlers + { + SamplingHandler = (_, _, _) => new ValueTask(new CreateMessageResult + { + Model = "test-model", + Role = Role.Assistant, + Content = [new TextContentBlock { Text = "hello back" }], + }), + }, + }); + + var result = await client.CallToolAsync("sample-tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("sample-ok:hello back", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task RequestRootsAsync_OnStatefulDraftSession_ResolvesViaLegacyRequest() + { + StartServer(); + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new ClientCapabilities + { + Roots = new RootsCapability(), + }, + Handlers = new McpClientHandlers + { + RootsHandler = (_, _) => new ValueTask(new ListRootsResult + { + Roots = [new Root { Uri = "file:///home", Name = "home" }], + }), + }, + }); + + var result = await client.CallToolAsync("roots-tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("roots-ok:file:///home", Assert.IsType(result.Content[0]).Text); + } + + private static async Task ElicitToolAsync(McpServer server, CancellationToken cancellationToken) + { + var elicit = await server.ElicitAsync(new ElicitRequestParams + { + Message = "Need input", + RequestedSchema = new(), + }, cancellationToken); + return $"elicit-ok:{elicit.Action}"; + } + + private static async Task SampleToolAsync(McpServer server, CancellationToken cancellationToken) + { + var sample = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "ping" }], + }, + ], + MaxTokens = 16, + }, cancellationToken); + var text = sample.Content.OfType().FirstOrDefault()?.Text; + return $"sample-ok:{text}"; + } + + private static async Task RootsToolAsync(McpServer server, CancellationToken cancellationToken) + { + var roots = await server.RequestRootsAsync(new ListRootsRequestParams(), cancellationToken); + return $"roots-ok:{roots.Roots.FirstOrDefault()?.Uri}"; + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/MrtrHandlerLifecycleTests.cs b/tests/ModelContextProtocol.Tests/Server/MrtrHandlerLifecycleTests.cs new file mode 100644 index 000000000..9a408ce78 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/MrtrHandlerLifecycleTests.cs @@ -0,0 +1,438 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Tests for the server's MRTR handler lifecycle management - cancellation, disposal, and error +/// logging during multi round-trip request processing. +/// +public class MrtrHandlerLifecycleTests : ClientServerTestBase +{ + private readonly TaskCompletionSource _handlerTokenCancelled = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _handlerStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _handlerResumed = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _releaseHandler = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly ServerMessageTracker _messageTracker = new(); + + public MrtrHandlerLifecycleTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.AddLogging(builder => builder.SetMinimumLevel(LogLevel.Debug)); + services.Configure(options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + _messageTracker.AddFilters(options.Filters.Message); + }); + + mcpServerBuilder.WithTools([ + McpServerTool.Create( + async (string message, McpServer server, CancellationToken ct) => + { + var result = await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = new() + }, ct); + + return $"{result.Action}:{result.Content?.FirstOrDefault().Value}"; + }, + new McpServerToolCreateOptions + { + Name = "elicitation-tool", + Description = "A tool that requests elicitation from the client" + }), + McpServerTool.Create( + async (McpServer server, CancellationToken ct) => + { + var handlerTokenCancelled = _handlerTokenCancelled; + ct.Register(static state => ((TaskCompletionSource)state!).TrySetResult(true), handlerTokenCancelled); + _handlerStarted.TrySetResult(true); + + await server.ElicitAsync(new ElicitRequestParams + { + Message = "Cancellation test", + RequestedSchema = new() + }, ct); + + return "done"; + }, + new McpServerToolCreateOptions + { + Name = "cancellation-test-tool", + Description = "A tool that monitors its CancellationToken during MRTR" + }), + McpServerTool.Create( + async (string message, McpServer server, CancellationToken ct) => + { + // Elicit first, then block forever - the retry request stays in-flight + // until the client cancels, verifying that notifications/cancelled for + // the retry's request ID flows through to cancel this handler. + _handlerStarted.TrySetResult(true); + var result = await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = new() + }, ct); + + // Signal that we resumed after ElicitAsync, then block. + _handlerResumed.TrySetResult(true); + await Task.Delay(Timeout.Infinite, ct); + return "unreachable"; + }, + new McpServerToolCreateOptions + { + Name = "elicit-then-block-tool", + Description = "A tool that elicits then blocks forever for cancellation testing" + }), + McpServerTool.Create( + async (McpServer server, CancellationToken ct) => + { + // Two sequential MRTR rounds. The client will inject a stale cancellation + // notification for the original request ID between round 1 and round 2. + var r1 = await server.ElicitAsync(new ElicitRequestParams + { + Message = "First elicitation", + RequestedSchema = new() + }, ct); + + // Signal that round 1 completed so the test can inject the stale notification. + _handlerResumed.TrySetResult(true); + + var r2 = await server.ElicitAsync(new ElicitRequestParams + { + Message = "Second elicitation", + RequestedSchema = new() + }, ct); + + return $"{r1.Action},{r2.Action}"; + }, + new McpServerToolCreateOptions + { + Name = "double-elicit-tool", + Description = "A tool that elicits twice for stale cancellation testing" + }), + McpServerTool.Create( + async (string message, McpServer server, CancellationToken ct) => + { + // Elicit, resume, then wait on _releaseHandler for the dispose test. + _handlerStarted.TrySetResult(true); + await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = new() + }, ct); + + _handlerResumed.TrySetResult(true); + await _releaseHandler.Task; + return "handler-completed"; + }, + new McpServerToolCreateOptions + { + Name = "dispose-wait-tool", + Description = "A tool that elicits, resumes, then waits on a signal for disposal testing" + }), + McpServerTool.Create( + async (McpServer server, CancellationToken ct) => + { + await server.ElicitAsync(new ElicitRequestParams + { + Message = "elicit-then-throw", + RequestedSchema = new() + }, ct); + + throw new InvalidOperationException("Deliberate MRTR handler error for testing"); + }, + new McpServerToolCreateOptions + { + Name = "elicit-then-throw-tool", + Description = "A tool that elicits then throws an exception for error logging testing" + }), + McpServerTool.Create( + (McpServer server) => + { + // Low-level MRTR: throw InputRequiredException directly instead of using ElicitAsync. + // This should NOT be logged at Error level - it's normal MRTR control flow. + throw new InputRequiredException(new InputRequiredResult + { + InputRequests = new Dictionary + { + ["input_1"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "low-level elicit", + RequestedSchema = new() + }) + } + }); + }, + new McpServerToolCreateOptions + { + Name = "incomplete-result-tool", + Description = "A tool that throws InputRequiredException for low-level MRTR" + }) + ]); + } + + [Fact] + public async Task CallToolAsync_CancellationDuringMrtrRetry_ThrowsOperationCanceled() + { + // Verify that cancelling the CancellationToken during the MRTR retry loop + // (specifically during the elicitation handler callback) stops the loop. + StartServer(); + var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + { + // Cancel the token during the callback. The retry loop will throw + // OperationCanceledException on the next await after this handler returns. + cts.Cancel(); + return new ValueTask(new ElicitResult { Action = "accept" }); + }; + + await using var client = await CreateMcpClientForServer(clientOptions); + + await Assert.ThrowsAsync(async () => + await client.CallToolAsync("elicitation-tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: cts.Token)); + + _messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task ServerDisposal_CancelsHandlerCancellationToken_DuringMrtr() + { + // Verify that disposing the server cancels the handler's own CancellationToken + // (the `ct` parameter), not just the exchange ResponseTcs. Before the HandlerCts fix, + // the handler's CT was from a disposed CTS and could never be triggered. + StartServer(); + var elicitHandlerCalled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = async (request, ct) => + { + // Signal that the MRTR round trip reached the client, then block indefinitely. + elicitHandlerCalled.TrySetResult(true); + await Task.Delay(Timeout.Infinite, ct); + throw new OperationCanceledException(ct); + }; + + await using var client = await CreateMcpClientForServer(clientOptions); + + // Start the tool call in the background. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(30)); + var callTask = client.CallToolAsync("cancellation-test-tool", cancellationToken: cts.Token).AsTask(); + + // Wait for the handler to start on the server. + await _handlerStarted.Task.WaitAsync(TimeSpan.FromSeconds(30), TestContext.Current.CancellationToken); + + // Wait for the MRTR round trip to reach the client's elicitation handler. + await elicitHandlerCalled.Task.WaitAsync(TimeSpan.FromSeconds(30), TestContext.Current.CancellationToken); + + // Dispose the server - HandlerCts.Cancel() should trigger the handler's CancellationToken. + await Server.DisposeAsync(); + + // Verify the handler's CancellationToken was actually cancelled via HandlerCts, + // not just the exchange ResponseTcs.TrySetCanceled(). + await _handlerTokenCancelled.Task.WaitAsync(TimeSpan.FromSeconds(30), TestContext.Current.CancellationToken); + + // The client call should fail (server disposed mid-MRTR). + await Assert.ThrowsAnyAsync(async () => await callTask); + } + + [Fact] + public async Task CancellationNotification_DuringInFlightMrtrRetry_CancelsHandler() + { + // Verify that cancelling the client's CancellationToken while a retry request is in-flight + // sends notifications/cancelled with the retry's request ID, and the server correctly + // routes it to cancel the handler. This proves end-to-end that: + // (a) the client sends the notification with the CURRENT request ID (not the original), + // (b) the server's _handlingRequests lookup finds the retry's CTS, + // (c) the cancellation registration in AwaitMrtrHandlerAsync bridges to handlerCts. + StartServer(); + + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }); + + await using var client = await CreateMcpClientForServer(clientOptions); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(30)); + var callTask = client.CallToolAsync( + "elicit-then-block-tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: cts.Token).AsTask(); + + // Wait for the handler to resume after ElicitAsync - at this point the retry + // request is in-flight (server is awaiting WhenAny in AwaitMrtrHandlerAsync). + await _handlerResumed.Task.WaitAsync(TimeSpan.FromSeconds(30), TestContext.Current.CancellationToken); + + // Cancel the client's token. The client is inside _sessionHandler.SendRequestAsync + // awaiting the retry response. RegisterCancellation fires and sends + // notifications/cancelled with the retry's request ID. + cts.Cancel(); + + // The call should throw OperationCanceledException. + await Assert.ThrowsAnyAsync(async () => await callTask); + + _messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task CancellationNotification_ForExpiredRequestId_DoesNotAffectHandler() + { + // Verify that a stale cancellation notification for the original (now-completed) + // request ID does not interfere with an active MRTR handler. The original request's + // entry was removed from _handlingRequests when it returned InputRequiredResult, so + // the notification should be a no-op. + StartServer(); + + int elicitationCount = 0; + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + { + Interlocked.Increment(ref elicitationCount); + return new ValueTask(new ElicitResult { Action = "accept" }); + }; + + await using var client = await CreateMcpClientForServer(clientOptions); + + // Start the double-elicit tool. Between round 1 and round 2, we'll inject a stale + // cancellation notification for a fake (expired) request ID. + var callTask = client.CallToolAsync( + "double-elicit-tool", + cancellationToken: TestContext.Current.CancellationToken).AsTask(); + + // Wait for handler to resume after the first ElicitAsync. + await _handlerResumed.Task.WaitAsync(TimeSpan.FromSeconds(30), TestContext.Current.CancellationToken); + + // Send a stale cancellation notification for a non-existent request ID. + // This simulates a delayed notification for the original request that already completed. + await client.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.CancelledNotification, + Params = JsonSerializer.SerializeToNode( + new CancelledNotificationParams { RequestId = new RequestId("stale-id-999"), Reason = "stale test" }, + McpJsonUtilities.DefaultOptions), + }, TestContext.Current.CancellationToken); + + // The tool should complete successfully - the stale notification didn't affect it. + var result = await callTask; + Assert.Contains("accept", result.Content.OfType().First().Text); + + _messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task DisposeAsync_WaitsForMrtrHandler_BeforeReturning() + { + // Verify that McpServer.DisposeAsync() waits for an MRTR handler to complete + // before returning, similar to RunAsync_WaitsForInFlightHandlersBeforeReturning + // which tests the same invariant for regular request handlers in McpSessionHandler. + StartServer(); + bool handlerCompleted = false; + + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }); + + await using var client = await CreateMcpClientForServer(clientOptions); + + // Start the tool call that calls ElicitAsync, then blocks on _releaseHandler. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(30)); + _ = client.CallToolAsync( + "dispose-wait-tool", + new Dictionary { ["message"] = "dispose-wait-test" }, + cancellationToken: cts.Token); + + // Wait for the handler to resume after ElicitAsync - it's now blocking on _releaseHandler. + await _handlerResumed.Task.WaitAsync(TimeSpan.FromSeconds(30), TestContext.Current.CancellationToken); + + // Dispose the server. The handler is still running (blocked on _releaseHandler). + // Release the handler after a delay - DisposeAsync must wait for it. + var ct = TestContext.Current.CancellationToken; + _ = Task.Run(async () => + { + await Task.Delay(200, ct); + handlerCompleted = true; + _releaseHandler.SetResult(true); + }, ct); + + await Server.DisposeAsync(); + + // DisposeAsync should not have returned until the handler completed. + Assert.True(handlerCompleted, "DisposeAsync should wait for MRTR handlers to complete before returning."); + + _messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task HandlerException_DuringMrtr_IsLoggedAtErrorLevel() + { + // Verify that when a tool handler throws an unhandled exception during MRTR + // (after resuming from ElicitAsync), the error is logged at Error level. + StartServer(); + + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }); + + await using var client = await CreateMcpClientForServer(clientOptions); + + // Call the tool that elicits then throws. The retry returns an error result. + var result = await client.CallToolAsync( + "elicit-then-throw-tool", + cancellationToken: TestContext.Current.CancellationToken); + Assert.True(result.IsError); + + // Verify the tool error was logged at Error level during the MRTR retry. + // The ToolsCall handler catches the exception, logs it via ToolCallError, + // and converts it to an error result - so the error is properly surfaced. + Assert.Contains(MockLoggerProvider.LogMessages, m => + m.LogLevel == LogLevel.Error && + m.Message.Contains("elicit-then-throw-tool") && + m.Exception is InvalidOperationException); + + _messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task IncompleteResultException_IsNotLoggedAtErrorLevel() + { + // InputRequiredException is normal MRTR control flow (low-level API), + // not an error. It should not be logged via ToolCallError at Error level. + StartServer(); + + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }); + + await using var client = await CreateMcpClientForServer(clientOptions); + + // The tool always throws InputRequiredException (low-level MRTR path), + // so the client will retry until hitting the max retry limit. + await Assert.ThrowsAsync(() => client.CallToolAsync( + "incomplete-result-tool", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.DoesNotContain(MockLoggerProvider.LogMessages, m => + m.LogLevel == LogLevel.Error && + m.Exception is InputRequiredException); + + _messageTracker.AssertMrtrUsed(); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/MrtrInputRequiredExceptionTests.cs b/tests/ModelContextProtocol.Tests/Server/MrtrInputRequiredExceptionTests.cs new file mode 100644 index 000000000..664429b13 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/MrtrInputRequiredExceptionTests.cs @@ -0,0 +1,61 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Tests for the MRTR server API - IsMrtrSupported, InputRequiredException, +/// and client auto-retry of incomplete results. +/// +public class MrtrInputRequiredExceptionTests : ClientServerTestBase +{ + private readonly ServerMessageTracker _messageTracker = new(); + + public MrtrInputRequiredExceptionTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.Configure(options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + _messageTracker.AddFilters(options.Filters.Message); + }); + + mcpServerBuilder.WithTools([ + McpServerTool.Create( + static string (McpServer server) => + { + throw new InputRequiredException(requestState: "should-not-work"); + }, + new McpServerToolCreateOptions + { + Name = "always-incomplete", + Description = "Tool that always throws InputRequiredException" + }), + ]); + } + + [Fact] + public async Task InputRequiredException_WithoutInputRequests_ExhaustsRetries() + { + StartServer(); + var clientOptions = new McpClientOptions(); + + await using var client = await CreateMcpClientForServer(clientOptions); + + // The always-incomplete tool throws InputRequiredException with only requestState + // and no inputRequests. The client has nothing to dispatch, so it keeps retrying + // with the same requestState until the retry budget is exhausted. + var exception = await Assert.ThrowsAsync(() => + client.CallToolAsync("always-incomplete", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.Contains("more than", exception.Message); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/MrtrMessageFilterTests.cs b/tests/ModelContextProtocol.Tests/Server/MrtrMessageFilterTests.cs new file mode 100644 index 000000000..fd9098734 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/MrtrMessageFilterTests.cs @@ -0,0 +1,149 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Tests that message filters correctly observe MRTR protocol behavior - verifying that +/// InputRequiredResult responses are visible to outgoing filters, and that no legacy +/// elicitation/sampling requests are sent when MRTR is active. +/// +public class MrtrMessageFilterTests : ClientServerTestBase +{ + private readonly ServerMessageTracker _messageTracker = new(); + + public MrtrMessageFilterTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.Configure(options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + _messageTracker.AddFilters(options.Filters.Message); + }); + + mcpServerBuilder + .WithTools([ + McpServerTool.Create( + async (string message, McpServer server, CancellationToken ct) => + { + var result = await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = new() + }, ct); + + return $"{result.Action}"; + }, + new McpServerToolCreateOptions + { + Name = "elicit-tool", + Description = "A tool that requests elicitation" + }), + McpServerTool.Create( + async (string prompt, McpServer server, CancellationToken ct) => + { + var result = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = prompt }] }], + MaxTokens = 100 + }, ct); + + return result.Content.OfType().FirstOrDefault()?.Text ?? ""; + }, + new McpServerToolCreateOptions + { + Name = "sample-tool", + Description = "A tool that requests sampling" + }), + ]); + } + + [Fact] + public async Task MrtrActive_NoOldStyleElicitationRequests_SentOverWire() + { + // When both sides are on the experimental protocol, the server should use MRTR + // (InputRequiredResult) instead of sending old-style elicitation/create JSON-RPC requests. + StartServer(); + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + { + return new ValueTask(new ElicitResult { Action = "accept" }); + }; + + await using var client = await CreateMcpClientForServer(clientOptions); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("elicit-tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(result.Content); + Assert.Equal("accept", Assert.IsType(content).Text); + _messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task MrtrActive_NoOldStyleSamplingRequests_SentOverWire() + { + StartServer(); + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.SamplingHandler = (request, progress, ct) => + { + var text = request?.Messages[^1].Content.OfType().FirstOrDefault()?.Text; + return new ValueTask(new CreateMessageResult + { + Content = [new TextContentBlock { Text = $"Sampled: {text}" }], + Model = "test-model" + }); + }; + + await using var client = await CreateMcpClientForServer(clientOptions); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("sample-tool", + new Dictionary { ["prompt"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(result.Content); + Assert.Equal("Sampled: test", Assert.IsType(content).Text); + _messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task OutgoingFilter_SeesIncompleteResultResponse() + { + // Verify that transport middleware can observe the raw InputRequiredResult + // in outgoing JSON-RPC responses (validates MRTR transport visibility). + var sawIncompleteResult = false; + + StartServer(); + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + { + // If we reach this handler, it means the client received an InputRequiredResult + // from the server, resolved the elicitation, and is retrying. + sawIncompleteResult = true; + return new ValueTask(new ElicitResult { Action = "accept" }); + }; + + await using var client = await CreateMcpClientForServer(clientOptions); + + await client.CallToolAsync("elicit-tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + // The elicitation handler was called, confirming MRTR round-trip occurred + // (InputRequiredResult was sent by server and processed by client). + Assert.True(sawIncompleteResult, "Expected MRTR round-trip with InputRequiredResult"); + _messageTracker.AssertMrtrUsed(); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/MrtrServerBackcompatTests.cs b/tests/ModelContextProtocol.Tests/Server/MrtrServerBackcompatTests.cs new file mode 100644 index 000000000..d8fa6f32b --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/MrtrServerBackcompatTests.cs @@ -0,0 +1,113 @@ +using System.Text.Json; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Tests for the legacy MRTR backcompat resolver in McpServerImpl.InvokeWithInputRequiredResultHandlingAsync. +/// This path runs only when the client did NOT negotiate MRTR (DRAFT-2026-v1) and the session is stateful - +/// the server dispatches each input request to the client via standard JSON-RPC and re-invokes the handler +/// with the merged responses. To exercise it the server must NOT pin a protocol version; the client picks +/// a non-draft version during initialize negotiation. +/// +public class MrtrServerBackcompatTests : ClientServerTestBase +{ + private readonly List _observedRequestStates = []; + private int _attempt; + + public MrtrServerBackcompatTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder.WithTools([ + McpServerTool.Create( + (RequestContext context) => + { + var attempt = Interlocked.Increment(ref _attempt); + _observedRequestStates.Add(context.Params?.RequestState); + + return attempt switch + { + // Round 1: caller has no state; emit one and request elicitation. + 1 => throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "round1", + RequestedSchema = new() + }) + }, + requestState: "round1"), + // Round 2: deliberately clear the state by passing requestState: null while still + // asking for another elicitation. This exercises the params clone path that + // previously preserved the stale "round1" carry-over from round 1's deep clone. + 2 => throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "round2", + RequestedSchema = new() + }) + }, + requestState: null), + // Round 3 (final): report what the handler observed so the test can assert it. + _ => $"final-state:{context.Params?.RequestState ?? ""}", + }; + }, + new McpServerToolCreateOptions + { + Name = "requeststate-transition", + Description = "Tool that transitions requestState from set to null across MRTR rounds." + }), + ]); + } + + [Fact] + public async Task InputRequiredException_TransitioningRequestStateToNull_DoesNotLeakStaleState() + { + StartServer(); + + // Non-MRTR client → server falls into the legacy backcompat resolver path on InputRequiredException. + var clientOptions = new McpClientOptions + { + ProtocolVersion = "2025-06-18", + Capabilities = new ClientCapabilities { Elicitation = new() }, + }; + clientOptions.Handlers.ElicitationHandler = (_, _) => + new ValueTask(new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + ["answer"] = JsonDocument.Parse("\"ok\"").RootElement, + }, + }); + + await using var client = await CreateMcpClientForServer(clientOptions); + + var result = await client.CallToolAsync( + "requeststate-transition", + cancellationToken: TestContext.Current.CancellationToken); + + // Three attempts: round 1 (no state) → round 2 (state="round1") → round 3 (state=null after fix). + // Without the fix, the third observed state would erroneously remain "round1" because the deep-clone + // of the prior request params carried it forward when InputRequiredException.RequestState was null. + Assert.Equal(3, _observedRequestStates.Count); + Assert.Null(_observedRequestStates[0]); + Assert.Equal("round1", _observedRequestStates[1]); + Assert.Null(_observedRequestStates[2]); + + var content = Assert.Single(result.Content); + var text = Assert.IsType(content).Text; + Assert.Equal("final-state:", text); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/MrtrSessionLimitTests.cs b/tests/ModelContextProtocol.Tests/Server/MrtrSessionLimitTests.cs new file mode 100644 index 000000000..1836d4d13 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/MrtrSessionLimitTests.cs @@ -0,0 +1,183 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Collections.Concurrent; +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Tests for session-scoped MRTR resource governance - verifying that outgoing message +/// filters can track and limit MRTR round trips per session. +/// +public class MrtrSessionLimitTests : ClientServerTestBase +{ + /// + /// Tracks the number of pending MRTR flows per session. Incremented when an InputRequiredResult + /// is sent (outgoing filter), decremented when a retry with requestState arrives (incoming filter). + /// + private readonly ConcurrentDictionary _pendingFlowsPerSession = new(); + + /// + /// Records every (sessionId, pendingCount) observation from the outgoing filter, + /// so the test can verify the tracking was correct. + /// + private readonly ConcurrentBag<(string SessionId, int PendingCount)> _observations = []; + + private readonly ServerMessageTracker _messageTracker = new(); + + /// + /// Maximum allowed concurrent MRTR flows per session. If exceeded, the outgoing filter + /// replaces the InputRequiredResult with an error response. + /// + private int _maxFlowsPerSession = int.MaxValue; + + /// + /// Counts how many IncompleteResults were blocked by the per-session limit. + /// + private int _blockedFlowCount; + + public MrtrSessionLimitTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.Configure(options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + _messageTracker.AddFilters(options.Filters.Message); + + // Outgoing filter: detect InputRequiredResult responses and track per session. + options.Filters.Message.OutgoingFilters.Add(next => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcResponse response && + response.Result is JsonObject resultObj && + resultObj.TryGetPropertyValue("resultType", out var resultTypeNode) && + resultTypeNode?.GetValue() is "input_required") + { + var sessionId = context.Server.SessionId ?? "unknown"; + var newCount = _pendingFlowsPerSession.AddOrUpdate(sessionId, 1, (_, c) => c + 1); + _observations.Add((sessionId, newCount)); + + // Enforce per-session limit: if exceeded, replace the InputRequiredResult + // with a JSON-RPC error. This prevents the client from receiving the + // InputRequiredResult and starting another retry cycle. + if (newCount > _maxFlowsPerSession) + { + // Undo the increment since we're blocking this flow. + _pendingFlowsPerSession.AddOrUpdate(sessionId, 0, (_, c) => Math.Max(0, c - 1)); + Interlocked.Increment(ref _blockedFlowCount); + + // Replace the outgoing message with a JSON-RPC error. + context.JsonRpcMessage = new JsonRpcError + { + Id = response.Id, + Error = new JsonRpcErrorDetail + { + Code = (int)McpErrorCode.InvalidRequest, + Message = $"Too many pending MRTR flows for this session (limit: {_maxFlowsPerSession}).", + } + }; + } + } + + await next(context, cancellationToken); + }); + + // Incoming filter: detect retries (requests with requestState) and decrement. + options.Filters.Message.IncomingFilters.Add(next => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcRequest request && + request.Params is JsonObject paramsObj && + paramsObj.TryGetPropertyValue("requestState", out var stateNode) && + stateNode is not null) + { + var sessionId = context.Server.SessionId ?? "unknown"; + _pendingFlowsPerSession.AddOrUpdate(sessionId, 0, (_, c) => Math.Max(0, c - 1)); + } + + await next(context, cancellationToken); + }); + }); + + mcpServerBuilder.WithTools([ + McpServerTool.Create( + async (string message, McpServer server, CancellationToken ct) => + { + var result = await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = new() + }, ct); + + return $"{result.Action}"; + }, + new McpServerToolCreateOptions + { + Name = "elicit-tool", + Description = "A tool that requests elicitation" + }), + ]); + } + + [Fact] + public async Task OutgoingFilter_TracksIncompleteResultsPerSession() + { + // Verify that an outgoing message filter can observe InputRequiredResult responses + // and track the pending MRTR flow count per session using context.Server.SessionId. + StartServer(); + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }); + + await using var client = await CreateMcpClientForServer(clientOptions); + + // Call the tool - triggers one MRTR round-trip. + var result = await client.CallToolAsync("elicit-tool", + new Dictionary { ["message"] = "confirm?" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("accept", Assert.IsType(Assert.Single(result.Content)).Text); + + // Verify the filter observed exactly one InputRequiredResult and tracked it. + Assert.Single(_observations); + var (sessionId, pendingCount) = _observations.First(); + Assert.NotNull(sessionId); + Assert.Equal(1, pendingCount); + + // After the retry completed, the count should be back to 0. + Assert.Equal(0, _pendingFlowsPerSession.GetValueOrDefault(sessionId)); + + _messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task OutgoingFilter_CanEnforcePerSessionMrtrLimit() + { + // Verify that an outgoing message filter can enforce a per-session MRTR flow limit + // by replacing the InputRequiredResult with a JSON-RPC error when the limit is exceeded. + // Set the limit to 0 so the very first MRTR flow is blocked. + _maxFlowsPerSession = 0; + + StartServer(); + var clientOptions = new McpClientOptions { ProtocolVersion = "DRAFT-2026-v1" }; + clientOptions.Handlers.ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }); + + await using var client = await CreateMcpClientForServer(clientOptions); + + // The tool call should fail because the outgoing filter blocks the InputRequiredResult. + var ex = await Assert.ThrowsAsync(async () => + await client.CallToolAsync("elicit-tool", + new Dictionary { ["message"] = "confirm?" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Contains("Too many pending MRTR flows", ex.Message); + Assert.Equal(1, _blockedFlowCount); + } +}