diff --git a/docs/client/callbacks.md b/docs/client/callbacks.md index e9787da8d..6b4e934cf 100644 --- a/docs/client/callbacks.md +++ b/docs/client/callbacks.md @@ -78,6 +78,8 @@ When a client connects it declares its `capabilities`, the mirror image of the s | `list_roots_callback=` | `"roots": {"listChanged": true}` | | none of them | `{}` | +Sampling sub-capabilities are the one refinement: pass `sampling_capabilities=SamplingCapability(tools=SamplingToolsCapability())` alongside `sampling_callback` when your sampler handles the `tools` / `tool_choice` parameters. Servers must see `sampling.tools` declared before they can send them. + `logging_callback` and `message_handler` are not in the table. They handle notifications, and notifications need no capability. The server reads the declaration back with `ctx.session.check_client_capability(...)`. Add a tool that does: diff --git a/docs/handlers/dependencies.md b/docs/handlers/dependencies.md index 6260b72f5..509b2635f 100644 --- a/docs/handlers/dependencies.md +++ b/docs/handlers/dependencies.md @@ -134,6 +134,18 @@ That's the right default for a precondition: no answer, no order. When declining to bind to. A question built from such volatile data makes every recorded answer look stale, so the server re-asks it on every round until the client's round limit ends the call. +## Ask the client, not the user + +Elicitation is one of the three questions a resolver can ask, and the multi-round-trip flow allows no others. The other two go to the **client** rather than the user: return `Sample(...)` to run an LLM call through the client (a `sampling/createMessage` request), or `ListRoots()` to fetch the client's current roots. 
Neither has an accept/decline outcome; the consumer annotates the result type directly, `CreateMessageResult` (`CreateMessageResultWithTools` when the request carries `tools` or `tool_choice`) or `ListRootsResult`: + +```python title="server.py" hl_lines="11-16 22" +--8<-- "docs_src/dependencies/tutorial004.py" +``` + +* The framework routes these exactly like `Elicit`: inside the multi-round-trip `tools/call` on **2026-07-28**, over the standalone server->client request on **2025-11-25**. An undeclared capability refuses the call with a `-32021` protocol error (`sampling`, `roots`, form-mode `elicitation`; `sampling.tools` when the request carries `tools` or `tool_choice`). +* Everything the info box above says about questions applies unchanged: a `Sample` request is matched to its recorded result by its exact rendering, so build it deterministically from the tool's arguments and earlier answers; the client then pays for the LLM call once per tool call, not once per round. The recorded result rides `request_state` for the rest of the call, so a very large completion makes every remaining round-trip heavier. +* The standalone sampling and roots *features* are deprecated at 2026-07-28 (SEP-2577). New servers that need the client's model ask through this carrier; servers that don't should integrate with an LLM provider directly. `include_context` values other than `"none"` are themselves deprecated; avoid them. + ## Recap * `Annotated[T, Resolve(fn)]` on a tool parameter: the SDK runs `fn` and injects its return value. @@ -141,5 +153,6 @@ That's the right default for a precondition: no answer, no order. When declining * A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once per round, however many consumers it has; each question is asked exactly once, and any resolver may run again when a call resumes after a question. 
* Bad graphs fail at registration with `InvalidSignature`, not mid-call. * Return `Elicit(message, Model)` to ask the user, only when you have to. Unwrapped annotations abort on decline; `ElicitationResult[T]` lets the tool branch. +* Return `Sample(...)` or `ListRoots()` to ask the client for an LLM completion or the roots list; the plain result is injected. The state your server builds once at startup, and how a handler reaches it, is the **[Lifespan](lifespan.md)** page. diff --git a/docs/handlers/index.md b/docs/handlers/index.md index eb2b5be41..daf9fde19 100644 --- a/docs/handlers/index.md +++ b/docs/handlers/index.md @@ -18,6 +18,9 @@ What it can do while it runs: * Ask the user for more input with **[Elicitation](elicitation.md)**, and **[Multi-round-trip requests](multi-round-trip.md)**, the 2026-07-28 pattern that carries it. +* Ask the client for an LLM completion or its workspace folders with + **[Sampling and roots](sampling-and-roots.md)**, deprecated but still + served. * Report **[Progress](progress.md)** on something slow. * Write logs (to standard error, for whoever operates the server) with **[Logging](logging.md)**. diff --git a/docs/handlers/multi-round-trip.md b/docs/handlers/multi-round-trip.md index d5451e231..e08903444 100644 --- a/docs/handlers/multi-round-trip.md +++ b/docs/handlers/multi-round-trip.md @@ -19,7 +19,7 @@ That's the whole protocol. Every leg is an ordinary request from the client to t ## The server side -On `@mcp.tool()` you rarely build this by hand: declare a dependency that asks the user and the SDK returns the `InputRequiredResult` for you - that form is the **[Dependencies](dependencies.md)** page. The two forms don't mix: a call has one `input_responses`/`request_state` channel, so a tool that uses `Resolve(...)` parameters cannot also return `InputRequiredResult` from its body. A declared `InputRequiredResult` return is rejected at registration (`InvalidSignature`), and an undeclared one fails the call at runtime. 
The manual form is the **low-level** `Server`, whose `on_call_tool` handler is allowed to return either result type: +On `@mcp.tool()` you rarely build this by hand: declare a dependency that asks the user (`Elicit`), samples the client's LLM (`Sample`), or lists its roots (`ListRoots`) and the SDK returns the `InputRequiredResult` for you; that form is the **[Dependencies](dependencies.md)** page. The two forms don't mix: a call has one `input_responses`/`request_state` channel, so a tool that uses `Resolve(...)` parameters cannot also return `InputRequiredResult` from its body. A declared `InputRequiredResult` return is rejected at registration (`InvalidSignature`), and an undeclared one fails the call at runtime. The manual form is the **low-level** `Server`, whose `on_call_tool` handler is allowed to return either result type: ```python title="server.py" hl_lines="44-47" --8<-- "docs_src/mrtr/tutorial001.py" diff --git a/docs/handlers/sampling-and-roots.md b/docs/handlers/sampling-and-roots.md new file mode 100644 index 000000000..6174f4258 --- /dev/null +++ b/docs/handlers/sampling-and-roots.md @@ -0,0 +1,46 @@ +# Sampling and roots + +A handler can ask the connected client for two more things: a completion from the client's own model (**sampling**), and the client's workspace folders (**roots**). + +Both still work, on every protocol version the SDK speaks. But read the warning before you design around them: + +!!! warning "Deprecated by the 2026-07-28 specification" + Sampling and roots are deprecated as of `2026-07-28` ([SEP-2577](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/2577)). They remain fully functional and stay in the specification for at least twelve months before becoming eligible for removal, but new implementations should not build on them. 
The suggested migrations: integrate directly with your LLM provider's API instead of sampling, and pass directories via tool parameters, resource URIs, or server configuration instead of roots. The SDK-wide list is in **[Deprecated features](../deprecated.md)**. + +## Sampling: borrow the client's model + +A resolver returns `Sample(...)` and the tool receives the completion, through the same dependency mechanism that runs `Elicit` in **[Dependencies](dependencies.md)**: + +```python title="server.py" hl_lines="11-16 20" +--8<-- "docs_src/sampling_and_roots/tutorial001.py" +``` + +* `Sample(messages, max_tokens=...)` mirrors the `sampling/createMessage` parameters. The injected value is the client's `CreateMessageResult`; pass `tools` or `tool_choice` and it becomes a `CreateMessageResultWithTools` instead. +* The client must have declared the `sampling` capability (`sampling.tools` if you pass `tools` or `tool_choice`). If it didn't, the call fails with a `-32021` protocol error instead of sending a request the client cannot handle. A pre-2026 session with no back-channel fails with its usual no-back-channel error, since there is nothing to send on. +* At `2026-07-28` the request is delivered inside the multi-round-trip flow (**[Multi-round-trip requests](multi-round-trip.md)**); on `2025-11-25` it is a standalone request to the client. The code is the same either way, but mind the multi-round-trip rule: the request must render identically across retry rounds, so build it only from the tool's arguments and other stable data. +* Leave `include_context` alone: values other than `"none"` are themselves deprecated (SEP-2596) and need a capability almost no client declares. + +## Roots: where should this go? + +Roots are the folders the client says the server may operate on. They are informational guidance, not an access-control mechanism. 
A resolver returns `ListRoots()`: + +```python title="server.py" hl_lines="11-12 16" +--8<-- "docs_src/sampling_and_roots/tutorial002.py" +``` + +* The injected `ListRootsResult` carries a list of `Root`s: a `file://` URI and an optional display name. +* The gate is the same as for sampling: without a declared `roots` capability the call fails with `-32021` instead of sending the request. + +On the other side of the wire, the client answers both requests with the callbacks it already has: `sampling_callback` and `list_roots_callback`, covered in **[Client callbacks](../client/callbacks.md)**. + +## On 2025-era connections + +`ctx.session.create_message(...)` and `ctx.session.list_roots()` still exist for code that drives the session directly. They only work where a back-channel exists (2025-era, non-stateless connections), and calling them raises a deprecation warning. The resolver markers above are the supported form: they pick the delivery from the negotiated version and don't warn. + +## Recap + +* Return `Sample(...)` or `ListRoots()` from a resolver; the tool receives the `CreateMessageResult` or `ListRootsResult` like any other dependency. +* The client must declare the matching capability, or the call fails with `-32021` instead of a request being sent. +* Both features are deprecated at `2026-07-28`: fully functional for now, wrong for new designs. Prefer provider APIs over sampling and explicit parameters over roots. + +Reporting how far along a slow tool is: **[Progress](progress.md)**. diff --git a/docs/migration.md b/docs/migration.md index 3c544d00e..811fa17d9 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -697,6 +697,24 @@ and raises `RuntimeError` if the resource requests input. The internal layers (`ToolManager.call_tool`, `Tool.run`, `Prompt.render`, `ResourceTemplate.create_resource`, etc.) now require `context` as a positional argument. 
+### Resolver-routed requests require the client capability on every protocol version + +A v1 server could send elicitation, sampling, and roots requests to clients +that never declared the matching capability; only tools-bearing sampling was +checked. In v2 the `Resolve(...)` markers (`Elicit`, `Sample`, `ListRoots`) +enforce the spec's egress rule: an undeclared capability (form-mode `elicitation`, +`sampling`, or `roots`, plus `sampling.tools` when the request carries `tools` +or `tool_choice`) fails the call with a `-32021` +`MISSING_REQUIRED_CLIENT_CAPABILITY` JSON-RPC error instead of sending a +request the client cannot handle. This applies on 2025-11-25 sessions with a +live back-channel too; a session with no back-channel keeps failing with its +no-back-channel error. To migrate, declare the capability: the SDK client +declares `elicitation`, `sampling`, and `roots` when the matching callback is +set, and `sampling.tools` needs an explicit +`Client(sampling_capabilities=SamplingCapability(tools=...))`. Direct +`ctx.elicit()` and `ctx.session.*` calls outside resolvers keep their previous +behavior, including the pre-existing tools check on `create_message`. + ### `MCPError` raised from an `@mcp.tool()` handler now surfaces as a JSON-RPC error Raising `MCPError` (or any subclass) inside an `@mcp.tool()` handler now diff --git a/docs_src/dependencies/tutorial004.py b/docs_src/dependencies/tutorial004.py new file mode 100644 index 000000000..ff55e5ce1 --- /dev/null +++ b/docs_src/dependencies/tutorial004.py @@ -0,0 +1,26 @@ +from typing import Annotated + +from mcp_types import CreateMessageResult, SamplingMessage, TextContent + +from mcp.server import MCPServer +from mcp.server.mcpserver import Resolve, Sample + +mcp = MCPServer("Bookshop") + + +def suggest_title(genre: str) -> Sample: + prompt = f"Suggest one {genre} book title. Answer with the title only." 
+ return Sample( + [SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], + max_tokens=50, + ) + + +@mcp.tool() +async def recommend_book( + genre: str, + suggestion: Annotated[CreateMessageResult, Resolve(suggest_title)], +) -> str: + """Recommend a book in the given genre.""" + title = suggestion.content.text if suggestion.content.type == "text" else "the classics" + return f"Today's {genre} pick: {title}" diff --git a/docs_src/sampling_and_roots/__init__.py b/docs_src/sampling_and_roots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs_src/sampling_and_roots/tutorial001.py b/docs_src/sampling_and_roots/tutorial001.py new file mode 100644 index 000000000..c1e041c32 --- /dev/null +++ b/docs_src/sampling_and_roots/tutorial001.py @@ -0,0 +1,22 @@ +from typing import Annotated + +from mcp_types import CreateMessageResult, SamplingMessage, TextContent + +from mcp.server import MCPServer +from mcp.server.mcpserver import Resolve, Sample + +mcp = MCPServer("Bookshop") + + +def draft_blurb(title: str) -> Sample: + prompt = f"Write a one-sentence blurb for the book {title!r}." + return Sample( + [SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], + max_tokens=60, + ) + + +@mcp.tool() +async def blurb(title: str, draft: Annotated[CreateMessageResult, Resolve(draft_blurb)]) -> str: + """Draft a blurb for a book.""" + return draft.content.text if draft.content.type == "text" else "No blurb." 
diff --git a/docs_src/sampling_and_roots/tutorial002.py b/docs_src/sampling_and_roots/tutorial002.py new file mode 100644 index 000000000..44a1d1057 --- /dev/null +++ b/docs_src/sampling_and_roots/tutorial002.py @@ -0,0 +1,20 @@ +from typing import Annotated + +from mcp_types import ListRootsResult + +from mcp.server import MCPServer +from mcp.server.mcpserver import ListRoots, Resolve + +mcp = MCPServer("Bookshop") + + +def workspace_roots() -> ListRoots: + return ListRoots() + + +@mcp.tool() +async def catalog_folder(roots: Annotated[ListRootsResult, Resolve(workspace_roots)]) -> str: + """Pick the folder the catalog export should go to.""" + if not roots.roots: + return "No workspace folders shared." + return str(roots.roots[0].uri) diff --git a/mkdocs.yml b/mkdocs.yml index 5f19b8982..ae0c57f3c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -36,6 +36,7 @@ nav: - Lifespan: handlers/lifespan.md - Elicitation: handlers/elicitation.md - Multi-round-trip requests: handlers/multi-round-trip.md + - Sampling and roots: handlers/sampling-and-roots.md - Progress: handlers/progress.md - Logging: handlers/logging.md - Subscriptions: handlers/subscriptions.md diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index d581fe6a5..fa78f15ea 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -303,6 +303,9 @@ async def main(): sampling_callback: SamplingFnT | None = None """Callback for handling sampling requests.""" + sampling_capabilities: types.SamplingCapability | None = None + """Sampling sub-capabilities (e.g. 
tools) declared alongside `sampling_callback`; no effect without it.""" + list_roots_callback: ListRootsFnT | None = None """Callback for handling list roots requests.""" @@ -418,6 +421,7 @@ async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: dispatcher=dispatcher, read_timeout_seconds=self.read_timeout_seconds, sampling_callback=self.sampling_callback, + sampling_capabilities=self.sampling_capabilities, list_roots_callback=self.list_roots_callback, logging_callback=self.logging_callback, message_handler=message_handler, diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index 0205df192..56d1c23cb 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -19,7 +19,9 @@ DeclinedElicitation, Elicit, ElicitationResult, + ListRoots, Resolve, + Sample, ) from .resources import DEFAULT_RESOURCE_SECURITY, ResourceSecurity from .server import MCPServer, require_client_extension @@ -33,6 +35,8 @@ "Icon", "Resolve", "Elicit", + "Sample", + "ListRoots", "ElicitationResult", "AcceptedElicitation", "DeclinedElicitation", diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index d752afc10..d4a744af3 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -3,17 +3,14 @@ A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the resolver `fn` before the tool body, instead of from the LLM-supplied arguments. Resolvers form a DAG: a resolver may declare its own `Resolve(...)` dependencies, -take tool arguments by name, and take the `Context`. A resolver may return -`Elicit[T]` to ask the client; the framework runs the elicitation and injects the -answer. - -The framework picks the elicitation transport from the negotiated protocol. 
At ->= 2026-07-28 it returns an `InputRequiredResult` carrying the batched questions -and resumes when the client retries with `input_responses`/`request_state` -(independent resolvers are asked in one round; a resolver depending on another's -answer is asked in a later round). At <= 2025-11-25 it issues a synchronous -`elicitation/create` request mid-call. Only *elicited* outcomes are carried in -`request_state` across rounds (so the user is asked each question once). Resolver +take tool arguments by name, and take the `Context`. A resolver may return a +request marker (`Elicit[T]` to ask the user, `Sample` to sample the client's +LLM, `ListRoots` to fetch its roots); the framework injects the response. + +The transport follows the negotiated protocol: >= 2026-07-28 batches the requests +into an `InputRequiredResult` and resumes when the client retries with +`input_responses`/`request_state`; <= 2025-11-25 sends each standalone server-to-client +request mid-call. Only *asked* outcomes ride `request_state`, so each question is asked once. Resolver bodies may re-run on every round; a recorded outcome is consulted only when the body asks its question again, so a resolver's own computation always wins over anything the client echoes back in `request_state`. @@ -24,6 +21,8 @@ - `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call. - `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the full outcome; the consumer branches on accept/decline/cancel. + +`Sample` and `ListRoots` have no decline arm; their consumers annotate the result type directly. 
""" from __future__ import annotations @@ -42,16 +41,30 @@ from mcp_types import ( MISSING_REQUIRED_CLIENT_CAPABILITY, ClientCapabilities, + CreateMessageRequest, + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, ElicitationCapability, ElicitRequest, ElicitRequestFormParams, ElicitResult, FormElicitationCapability, + IncludeContext, InputRequest, InputRequests, InputRequiredResult, InputResponses, + ListRootsRequest, + ListRootsResult, MissingRequiredClientCapabilityErrorData, + ModelPreferences, + RootsCapability, + SamplingCapability, + SamplingMessage, + SamplingToolsCapability, + Tool, + ToolChoice, ) from mcp_types.version import is_version_at_least from pydantic import BaseModel, ValidationError @@ -67,8 +80,10 @@ from mcp.server.mcpserver.context import Context from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError from mcp.server.request_state import compact_json +from mcp.server.validation import validate_tool_use_result_messages, wants_sampling_tools from mcp.shared._callable_inspection import is_async_callable from mcp.shared.exceptions import MCPError +from mcp.shared.message import ServerMessageMetadata T = TypeVar("T", bound=BaseModel) @@ -103,6 +118,53 @@ def __init__(self, message: str, schema: type[T]) -> None: self.schema = schema +class Sample: + """A resolver's request to sample the client's LLM via `sampling/createMessage`. + + The framework injects a `CreateMessageResult` (`CreateMessageResultWithTools` when `tools` or + `tool_choice` are given, which also requires the client's `sampling.tools`); requires the + `sampling` capability. On >= 2026-07-28 the request must render identically across retry + rounds, and the sampled result rides `request_state` on every later round. `include_context` + other than "none" is deprecated in the draft spec. 
+ """ + + def __init__( + self, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice | None = None, + ) -> None: + validate_tool_use_result_messages(messages) + self.params = CreateMessageRequestParams( + messages=messages, + max_tokens=max_tokens, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + ) + + +class ListRoots: + """A resolver's request for the client's roots via `roots/list`; the framework injects the `ListRootsResult`.""" + + +_Marker = Elicit[Any] | Sample | ListRoots +"""The request markers a resolver may return.""" + + class _ParamPlan: """How to fill one resolver parameter, decided once at registration.""" @@ -221,19 +283,23 @@ def _contains_resolve(annotation: Any) -> bool: def _check_elicit_return(return_annotation: Any, name: str) -> None: - """Validate the `Elicit[...]` arms of a resolver's return annotation. + """Validate the request-marker arms of a resolver's return annotation. Raises: - InvalidSignature: If the annotation has more than one `Elicit[...]` arm; - a resolver asks one question - a second arm means it should be split. + InvalidSignature: If the annotation has more than one marker arm. """ - # A bare `Elicit[T]` is itself a candidate; a union contributes its members. candidates = get_args(return_annotation) if _is_union(return_annotation) else (return_annotation,) # Typing dedupes equal union members, so two arms here are genuinely distinct. 
- arms = [c for c in candidates if get_origin(c) is Elicit] + arms: list[Any] = [ + c + for c in candidates + # Origin guard for 3.10: `dict[str, Any]` passes `isinstance(c, type)` there and would crash `issubclass`. + if get_origin(c) is Elicit + or (get_origin(c) is None and isinstance(c, type) and issubclass(c, Elicit | Sample | ListRoots)) + ] if len(arms) > 1: raise InvalidSignature( - f"Resolver {name!r} return annotation has multiple Elicit arms; " + f"Resolver {name!r} return annotation has multiple Elicit/Sample/ListRoots arms; " "a resolver asks one question - split it into separate resolvers" ) @@ -360,9 +426,9 @@ class _Pending(Exception): class _Resolution: """Per-`tools/call` resolution state, shared across the DAG walk. - `input_required` selects the transport: at >= 2026-07-28 elicitations are + `input_required` selects the transport: at >= 2026-07-28 requests are batched into `pending` and surfaced as an `InputRequiredResult`; at older - revisions each `Elicit` is answered synchronously via `ctx.elicit`. + revisions each marker is answered synchronously over the back-channel. """ def __init__( @@ -384,10 +450,9 @@ def __init__( self.asked = decoded.asked # In-call dedup keyed by resolver identity (distinguishes two instances of # the same bound method); `persist` holds the wire-shaped record of each - # elicited outcome, keyed by its wire key - exactly what the next round's - # `request_state` carries. Entries are the client's own (validated) wire - # data, never re-derived from a model, so encode-restore is the identity. - # Pure resolvers are cheap to re-run each round and are not persisted. + # asked outcome, keyed by its wire key - exactly what the next round's `request_state` + # carries: the client's own validated content (elicitation) or the validated result's + # dump (sample/roots). Pure resolvers are cheap to re-run each round and are not persisted. 
self.cache: dict[Hashable, ElicitationResult[Any]] = {} self.persist: dict[str, _StateEntry] = {} self.pending: InputRequests = {} @@ -490,8 +555,8 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul else: result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) - if _is_elicit(result): - outcome = await _elicit(result, wire_key, res) + if _is_marker(result): + outcome = await _fulfil(result, wire_key, res) else: # A resolver may return any type (not just `BaseModel`), so accept it as the # outcome without validating against the schema bound. Plain outcomes are not @@ -502,18 +567,29 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul return outcome -async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> ElicitationResult[Any]: - """Turn a resolver's `Elicit` into an outcome via the negotiated transport.""" +async def _fulfil(marker: _Marker, key: str, res: _Resolution) -> ElicitationResult[Any]: + """Turn a resolver's request marker into an outcome via the negotiated transport.""" if not res.input_required: - return await res.context.elicit(elicit.message, elicit.schema) + # Gate wherever the request could actually be sent; otherwise the send path + # itself reports the failure. + if res.context.session.can_send_request: + _require_capability(res.context, marker, key) + if isinstance(marker, Elicit): + return await res.context.elicit(marker.message, marker.schema) + result = await res.context.session.send_request( + _render_request(marker), + _result_type(marker), + metadata=ServerMessageMetadata(related_request_id=res.context.request_id), + ) + return _accepted(result) - request = _elicit_request(elicit) + request = _render_request(marker) q = _request_digest(request) # A recorded outcome from a prior round is consulted only here, after the body # decided to ask, so a `request_state` entry can never stand in for a resolver's # own computation. 
A recorded outcome wins over a re-sent answer. - outcome = _restore_outcome(res, key, elicit.schema, q) + outcome = _restore_outcome(res, key, marker, q) if outcome is not None: return outcome @@ -524,16 +600,25 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio logger.info("Discarding the answer for resolver %r: the question changed since it was asked", key) answer = None if answer is None: - _require_form_elicitation(res.context, key) + _require_capability(res.context, marker, key) res.pending[key] = request raise _Pending + if not isinstance(marker, Elicit): + # A no-tool-use answer to a tools request parses as the plain result; validate against the marker's model. + wire = answer.model_dump(mode="json", by_alias=True, exclude_none=True) + try: + result = _result_type(marker).model_validate(wire) + except ValidationError as e: + raise ToolError(f"Resolver {key!r} received a response of the wrong kind") from e + res.persist[key] = _StateEntry(action="accept", data=wire, q=q) + return _accepted(result) if not isinstance(answer, ElicitResult): raise ToolError(f"Resolver {key!r} received a non-elicitation response") if answer.action == "accept": if answer.content is None: raise ToolError(f"Resolver {key!r} received an accepted elicitation with no content") try: - data = elicit.schema.model_validate(answer.content) + data = marker.schema.model_validate(answer.content) except ValidationError as e: raise ToolError( f"Resolver {key!r} received an accepted elicitation whose content does not match the requested schema" @@ -555,9 +640,8 @@ def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") -def _is_elicit(value: Any) -> TypeGuard[Elicit[Any]]: - """Runtime narrow of a resolver's return value to a (parameter-erased) `Elicit`.""" - return isinstance(value, Elicit) +def _is_marker(value: Any) -> TypeGuard[_Marker]: + return 
isinstance(value, Elicit | Sample | ListRoots) def _accepted(data: Any) -> AcceptedElicitation[Any]: @@ -578,35 +662,65 @@ def _uses_input_required(protocol_version: str | None) -> bool: return protocol_version is not None and is_version_at_least(protocol_version, _INPUT_REQUIRED_VERSION) -def _require_form_elicitation(context: Context[Any, Any], key: str) -> None: - """Assert the client declared form elicitation before queueing a question for it. +def _require_capability(context: Context[Any, Any], marker: _Marker, key: str) -> None: + """Assert the client declared the capability `marker`'s request needs. - The spec forbids sending an `input_requests` entry the client has not declared a - capability for. A bare `elicitation: {}` declaration (the only shape before modes - existed) counts as form support; an explicit url-only declaration does not. + A bare `elicitation: {}` (the only shape before modes existed) counts as form support; url-only does not. Raises: MCPError: With code `MISSING_REQUIRED_CLIENT_CAPABILITY` and a - `requiredCapabilities` payload when form elicitation is not declared. + `requiredCapabilities` payload when the capability is not declared. 
""" capabilities = context.client_capabilities - elicitation = capabilities.elicitation if capabilities is not None else None - if elicitation is not None and (elicitation.form is not None or elicitation.url is None): - return - data = MissingRequiredClientCapabilityErrorData( - required_capabilities=ClientCapabilities(elicitation=ElicitationCapability(form=FormElicitationCapability())) - ) + if isinstance(marker, Elicit): + elicitation = capabilities.elicitation if capabilities is not None else None + if elicitation is not None and (elicitation.form is not None or elicitation.url is None): + return + required = ClientCapabilities(elicitation=ElicitationCapability(form=FormElicitationCapability())) + name = "form elicitation" + elif isinstance(marker, Sample): + sampling = capabilities.sampling if capabilities is not None else None + wants_tools = wants_sampling_tools(marker.params.tools, marker.params.tool_choice) + if sampling is not None and (not wants_tools or sampling.tools is not None): + return + required = ClientCapabilities( + sampling=SamplingCapability(tools=SamplingToolsCapability() if wants_tools else None) + ) + name = "sampling.tools" if wants_tools else "sampling" + else: + if capabilities is not None and capabilities.roots is not None: + return + required = ClientCapabilities(roots=RootsCapability()) + name = "roots" + data = MissingRequiredClientCapabilityErrorData(required_capabilities=required) raise MCPError( code=MISSING_REQUIRED_CLIENT_CAPABILITY, - message=f"Client did not declare the form elicitation capability required by resolver {key!r}", + message=f"Client did not declare the {name} capability required by resolver {key!r}", data=data.model_dump(by_alias=True, mode="json", exclude_none=True), ) -def _elicit_request(elicit: Elicit[Any]) -> ElicitRequest: - """Render an `Elicit[T]` as the embedded `elicitation/create` request for `input_requests`.""" - json_schema = render_elicitation_schema(elicit.schema) - return 
ElicitRequest(params=ElicitRequestFormParams(message=elicit.message, requested_schema=json_schema)) +def _render_request(marker: _Marker) -> InputRequest: + """Render a marker as its wire request - the same shape on both transports.""" + if isinstance(marker, Elicit): + json_schema = render_elicitation_schema(marker.schema) + return ElicitRequest(params=ElicitRequestFormParams(message=marker.message, requested_schema=json_schema)) + if isinstance(marker, Sample): + return CreateMessageRequest(params=marker.params) + return ListRootsRequest() + + +def _result_type( + marker: Sample | ListRoots, +) -> type[CreateMessageResult] | type[CreateMessageResultWithTools] | type[ListRootsResult]: + """The result model a `Sample`/`ListRoots` response must validate against.""" + if isinstance(marker, ListRoots): + return ListRootsResult + return ( + CreateMessageResultWithTools + if wants_sampling_tools(marker.params.tools, marker.params.tool_choice) + else CreateMessageResult + ) class _StateEntry(BaseModel): @@ -660,34 +774,33 @@ def _decode_state(request_state: str | None) -> _State: def _encode_state(outcomes: Mapping[str, _StateEntry], asked: Mapping[str, str]) -> str: """Encode recorded outcomes and asked-question digests for the next round. - Outcome entries already hold the client's wire-shaped data exactly as it was - sent (and validated), so encoding is pure wrapping: encode-restore is the - identity. + Outcome entries are already wire-shaped, so encoding is pure wrapping. """ state = _State(v=_STATE_VERSION, outcomes=dict(outcomes), asked=dict(asked)) return compact_json(state.model_dump(mode="json")) -def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel]) -> ElicitationResult[Any]: - """Rebuild an `ElicitationResult` from a decoded `request_state` entry. +def _outcome_from_state(entry: _StateEntry, marker: _Marker) -> ElicitationResult[Any]: + """Rebuild an outcome from a decoded `request_state` entry. 
Raises: - ValidationError: If an accepted entry's data does not validate against - `schema` (the live `Elicit.schema` of the question being asked). + ValidationError: If the entry does not fit the live marker. """ - if entry.action == "decline": - return DeclinedElicitation() - if entry.action == "cancel": - return CancelledElicitation() - return _accepted(schema.model_validate(entry.data)) + if isinstance(marker, Elicit): + if entry.action == "decline": + return DeclinedElicitation() + if entry.action == "cancel": + return CancelledElicitation() + return _accepted(marker.schema.model_validate(entry.data)) + return _accepted(_result_type(marker).model_validate(entry.data)) -def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel], q: str) -> ElicitationResult[Any] | None: +def _restore_outcome(res: _Resolution, key: str, marker: _Marker, q: str) -> ElicitationResult[Any] | None: """Restore `key`'s recorded outcome from a prior round, or `None` when absent. - An entry pinned to a question digest other than `q`, or whose accepted - data fails validation against the live `schema`, is dropped as if no - progress was recorded, so the question is asked again. + An entry pinned to a question digest other than `q`, or that fails + validation against the live marker, is dropped as if no progress was + recorded, so the question is asked again. 
Carries the original decoded entry forward unchanged in `res.persist`: if a later resolver is still pending, the next round's `request_state` is built from @@ -701,7 +814,7 @@ def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel], q: str del res.state[key] return None try: - outcome = _outcome_from_state(entry, schema) + outcome = _outcome_from_state(entry, marker) except ValidationError: del res.state[key] return None @@ -712,6 +825,8 @@ def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel], q: str __all__ = [ "Resolve", "Elicit", + "Sample", + "ListRoots", "ElicitationResult", "AcceptedElicitation", "DeclinedElicitation", diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ca62fb9c8..0a61689eb 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -14,7 +14,7 @@ from typing_extensions import deprecated from mcp.server.connection import Connection -from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages, wants_sampling_tools from mcp.shared.dispatcher import CallOptions, DispatchContext, ProgressFnT from mcp.shared.exceptions import MCPDeprecationWarning from mcp.shared.message import ServerMessageMetadata @@ -45,6 +45,11 @@ def client_params(self) -> types.InitializeRequestParams | None: """The client's `initialize` request params; `None` when no client info was supplied.""" return self._connection.client_params + @property + def can_send_request(self) -> bool: + """Whether this request's channel can currently deliver a server-initiated request.""" + return self._request_outbound.can_send_request + @property def protocol_version(self) -> str: """The protocol version this connection speaks. 
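Reviewer note on the `_restore_outcome` hunk above: the restore-or-re-ask rule it documents can be sketched with the standard library alone. This is a minimal illustration, not the SDK's code. The names `digest` and `restore`, the `{"q": ..., "data": ...}` entry layout, and the use of `ValueError` in place of pydantic's `ValidationError` are all assumptions made for the sketch; the real `_request_digest` and `_StateEntry` are not shown in full in this diff.

```python
from __future__ import annotations

import hashlib
import json
from typing import Any, Callable


def digest(request: dict[str, Any]) -> str:
    """Hypothetical question digest: a hash of a canonical JSON rendering."""
    rendered = json.dumps(request, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(rendered.encode("utf-8")).hexdigest()


def restore(
    state: dict[str, dict[str, Any]],
    key: str,
    q: str,
    validate: Callable[[Any], Any],
) -> Any | None:
    """Return the validated recorded answer for `key`, or None to ask again.

    Assumed entry layout: {"q": <digest it was recorded under>, "data": <answer>}.
    """
    entry = state.get(key)
    if entry is None:
        return None  # nothing recorded yet
    if entry["q"] != q:
        del state[key]  # the question changed: the old answer is stale
        return None
    try:
        return validate(entry["data"])
    except ValueError:
        del state[key]  # the answer no longer fits the live question
        return None


def require_text(data: Any) -> str:
    """Stand-in validator; the SDK validates against a result model instead."""
    if not isinstance(data, str):
        raise ValueError("expected text")
    return data


question = {"method": "sampling/createMessage", "params": {"prompt": "Capital of France?"}}
q = digest(question)
state = {"answer": {"q": q, "data": "Paris"}}
restored = restore(state, "answer", q, require_text)
```

Under these assumptions, a question rendered from volatile data gets a new digest every round, so every recorded answer is dropped and the question is asked again. That is the failure mode the docs hunk warns about when it says to build a `Sample` request deterministically.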
@@ -141,10 +146,10 @@ async def create_message( metadata: dict[str, Any] | None = None, model_preferences: types.ModelPreferences | None = None, tools: None = None, - tool_choice: types.ToolChoice | None = None, + tool_choice: None = None, related_request_id: types.RequestId | None = None, ) -> types.CreateMessageResult: - """Overload: Without tools, returns single content.""" + """Overload: Without tools or tool_choice, returns single content.""" ... @overload @@ -167,6 +172,26 @@ async def create_message( """Overload: With tools, returns array-capable content.""" ... + @overload + @deprecated("The sampling capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) + async def create_message( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + tools: list[types.Tool] | None = None, + tool_choice: types.ToolChoice, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResultWithTools: + """Overload: With tool_choice, returns array-capable content.""" + ... 
+ @deprecated("The sampling capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) async def create_message( self, @@ -231,7 +256,7 @@ async def create_message( ) metadata_obj = ServerMessageMetadata(related_request_id=related_request_id) - if tools is not None: + if wants_sampling_tools(tools, tool_choice): return await self.send_request( request=request, result_type=types.CreateMessageResultWithTools, diff --git a/src/mcp/server/validation.py b/src/mcp/server/validation.py index fd16beb95..a281f4d08 100644 --- a/src/mcp/server/validation.py +++ b/src/mcp/server/validation.py @@ -26,6 +26,11 @@ def check_sampling_tools_capability(client_caps: ClientCapabilities | None) -> b return True +def wants_sampling_tools(tools: list[Tool] | None, tool_choice: ToolChoice | None) -> bool: + """Whether a sampling request is tools-mode: `sampling.tools` gated, array-capable answer.""" + return tools is not None or tool_choice is not None + + def validate_sampling_tools( client_caps: ClientCapabilities | None, tools: list[Tool] | None, @@ -41,7 +46,7 @@ def validate_sampling_tools( Raises: MCPError: If tools/tool_choice are provided but client doesn't support them """ - if tools is not None or tool_choice is not None: + if wants_sampling_tools(tools, tool_choice): if not check_sampling_tools_capability(client_caps): raise MCPError(code=INVALID_PARAMS, message="Client does not support sampling tools capability") diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index ca59b56af..14e8fe1c2 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -98,7 +98,7 @@ async def sample( metadata: dict[str, Any] | None = None, model_preferences: ModelPreferences | None = None, tools: None = None, - tool_choice: ToolChoice | None = None, + tool_choice: None = None, meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResult: ... 
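Reviewer note on the `wants_sampling_tools` hunk above: the diff uses one predicate for both the capability gate and the choice of result model. The sketch below mirrors that logic on plain dictionaries so it runs without the SDK. `wants_sampling_tools` follows the body shown in `validation.py`; `result_model_name` and `missing_sampling_capability` are hypothetical helpers written for this illustration, and capabilities are modelled as a plain `dict` rather than `ClientCapabilities`.

```python
from __future__ import annotations

from typing import Any


def wants_sampling_tools(tools: list[Any] | None, tool_choice: Any | None) -> bool:
    """Same rule as the diff: either parameter makes the request tools-mode."""
    return tools is not None or tool_choice is not None


def result_model_name(tools: list[Any] | None, tool_choice: Any | None) -> str:
    """Hypothetical helper: which result model the answer validates against."""
    if wants_sampling_tools(tools, tool_choice):
        return "CreateMessageResultWithTools"
    return "CreateMessageResult"


def missing_sampling_capability(
    capabilities: dict[str, Any] | None,
    tools: list[Any] | None,
    tool_choice: Any | None,
) -> str | None:
    """Hypothetical helper: name of the undeclared capability, or None if the gate passes."""
    sampling = capabilities.get("sampling") if capabilities is not None else None
    wants_tools = wants_sampling_tools(tools, tool_choice)
    if sampling is not None and (not wants_tools or sampling.get("tools") is not None):
        return None
    return "sampling.tools" if wants_tools else "sampling"


# A client that declared base sampling only, asked with tool_choice alone.
gap = missing_sampling_capability({"sampling": {}}, None, {"mode": "none"})
```

Note that the predicate tests `is not None`, so an empty `tools` list still counts as tools-mode in this sketch, as it does in the body shown in the diff.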
@@ -120,6 +120,24 @@ async def sample( meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResultWithTools: ... + @overload + @deprecated("The sampling capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) + async def sample( + self, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResultWithTools: ... @deprecated("The sampling capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) async def sample( self, @@ -157,7 +175,7 @@ async def sample( tool_choice=tool_choice, ) result = await self.send_raw_request("sampling/createMessage", dump_params(params, meta), opts) - if tools is not None: + if tools is not None or tool_choice is not None: return CreateMessageResultWithTools.model_validate(result, by_name=False) return CreateMessageResult.model_validate(result, by_name=False) diff --git a/tests/docs_src/test_dependencies.py b/tests/docs_src/test_dependencies.py index 6dba9277e..8474d55e4 100644 --- a/tests/docs_src/test_dependencies.py +++ b/tests/docs_src/test_dependencies.py @@ -4,9 +4,9 @@ import pytest from inline_snapshot import snapshot -from mcp_types import ElicitRequestParams, ElicitResult, TextContent +from mcp_types import CreateMessageRequestParams, CreateMessageResult, ElicitRequestParams, ElicitResult, TextContent -from docs_src.dependencies import tutorial001, tutorial002, tutorial003 +from docs_src.dependencies import tutorial001, tutorial002, tutorial003, tutorial004 from mcp import Client from mcp.client import ClientRequestContext @@ -138,3 +138,21 @@ 
async def decline(context: ClientRequestContext, params: ElicitRequestParams) -> assert result.content[0].text == ( "Error executing tool order_book: Resolver for parameter 'backorder' could not resolve: elicitation was decline" ) + + +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_a_resolver_can_sample_the_clients_llm(mode: Literal["legacy", "auto"]) -> None: + """tutorial004: `suggest_title` runs through the client's sampling callback on both eras.""" + prompts: list[str] = [] + + async def sampler(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult: + content = params.messages[0].content + assert isinstance(content, TextContent) + prompts.append(content.text) + return CreateMessageResult(role="assistant", content=TextContent(type="text", text="Dune"), model="m") + + async with Client(tutorial004.mcp, mode=mode, sampling_callback=sampler) as client: + result = await client.call_tool("recommend_book", {"genre": "sci-fi"}) + + assert result.content == [TextContent(type="text", text="Today's sci-fi pick: Dune")] + assert prompts == ["Suggest one sci-fi book title. 
Answer with the title only."] diff --git a/tests/docs_src/test_sampling_and_roots.py b/tests/docs_src/test_sampling_and_roots.py new file mode 100644 index 000000000..7e4b9aea3 --- /dev/null +++ b/tests/docs_src/test_sampling_and_roots.py @@ -0,0 +1,62 @@ +"""`docs/handlers/sampling-and-roots.md`: every claim the page makes, proved against the real SDK.""" + +from typing import Literal + +import pytest +from mcp_types import ( + MISSING_REQUIRED_CLIENT_CAPABILITY, + CreateMessageRequestParams, + CreateMessageResult, + ListRootsResult, + Root, + TextContent, +) +from pydantic import FileUrl + +from docs_src.sampling_and_roots import tutorial001, tutorial002 +from mcp import Client +from mcp.client import ClientRequestContext +from mcp.shared.exceptions import MCPError + +pytestmark = [pytest.mark.anyio, pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning")] + + +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_a_sampling_dependency_receives_the_clients_completion(mode: Literal["legacy", "auto"]) -> None: + """tutorial001: `draft_blurb` runs through the client's model on both protocol versions.""" + prompts: list[str] = [] + + async def sampler(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult: + content = params.messages[0].content + assert isinstance(content, TextContent) + prompts.append(content.text) + return CreateMessageResult( + role="assistant", content=TextContent(type="text", text="A desert planet holds the key."), model="m" + ) + + async with Client(tutorial001.mcp, mode=mode, sampling_callback=sampler) as client: + result = await client.call_tool("blurb", {"title": "Dune"}) + + assert result.content == [TextContent(type="text", text="A desert planet holds the key.")] + assert prompts == ["Write a one-sentence blurb for the book 'Dune'."] + + +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_a_roots_dependency_receives_the_clients_folders(mode: Literal["legacy", 
"auto"]) -> None: + """tutorial002: `workspace_roots` fetches the client's roots list.""" + + async def client_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult(roots=[Root(uri=FileUrl("file:///workspace/catalog"), name="catalog")]) + + async with Client(tutorial002.mcp, mode=mode, list_roots_callback=client_roots) as client: + result = await client.call_tool("catalog_folder", {}) + + assert result.content == [TextContent(type="text", text="file:///workspace/catalog")] + + +async def test_an_undeclared_capability_fails_before_a_request_is_sent() -> None: + """The page's gate claim: no `sampling` capability means a -32021 protocol error.""" + async with Client(tutorial001.mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("blurb", {"title": "Dune"}) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index c28f12481..0ce267024 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -10,19 +10,36 @@ from mcp_types import ( MISSING_REQUIRED_CLIENT_CAPABILITY, CallToolResult, + CreateMessageRequest, + CreateMessageRequestParams, CreateMessageResult, + CreateMessageResultWithTools, + ElicitRequest, ElicitRequestFormParams, ElicitRequestParams, ElicitResult, InputRequiredResult, InputResponses, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + ListRootsResult, + Root, + SamplingCapability, + SamplingMessage, + SamplingToolsCapability, TextContent, + ToolChoice, ) -from pydantic import BaseModel, Field, ValidationError, create_model +from mcp_types import ( + Tool as SamplingTool, +) +from pydantic import BaseModel, Field, FileUrl, ValidationError, create_model from typing_extensions import TypeAliasType from mcp import Client, InputRequiredRoundsExceededError from mcp.client import ClientRequestContext +from mcp.client._memory import 
InMemoryTransport from mcp.server.context import ServerRequestContext from mcp.server.mcpserver import ( AcceptedElicitation, @@ -32,18 +49,20 @@ DeclinedElicitation, Elicit, ElicitationResult, + ListRoots, MCPServer, RequestStateBoundary, RequestStateSecurity, Resolve, + Sample, ) from mcp.server.mcpserver.exceptions import InvalidSignature from mcp.server.mcpserver.resolve import ( _check_elicit_return, _decode_state, - _elicit_request, _encode_state, _outcome_from_state, + _render_request, _request_digest, _resolver_key, _state_key, @@ -54,11 +73,12 @@ ) from mcp.server.mcpserver.tools.base import Tool from mcp.shared.exceptions import MCPError +from mcp.shared.message import SessionMessage def _question_digest(elicit: Elicit[Any]) -> str: - """The digest `_elicit` pins: the rendered request the client would be shown.""" - return _request_digest(_elicit_request(elicit)) + """The digest `_fulfil` pins: the rendered request the client would be shown.""" + return _request_digest(_render_request(elicit)) class Login(BaseModel): @@ -425,7 +445,7 @@ async def ambiguous(ctx: Context) -> Elicit[Login] | Elicit[Confirm]: async def tool(login: Annotated[Login, Resolve(ambiguous)]) -> str: return login.username # pragma: no cover - with pytest.raises(InvalidSignature, match="multiple Elicit arms"): + with pytest.raises(InvalidSignature, match="multiple Elicit/Sample/ListRoots arms"): Tool.from_function(tool) @@ -938,15 +958,16 @@ def test_state_round_trips_accept_decline_cancel(): assert decoded == entries # encode-restore is the identity on the stored entries assert state.asked == {"e": "asked-digest"} - accepted = _outcome_from_state(decoded["a"], Login) + ask = Elicit("q", Login) + accepted = _outcome_from_state(decoded["a"], ask) assert isinstance(accepted, AcceptedElicitation) and accepted.data == Login(username="octocat") # Decline/cancel entries carry no data; the schema is not consulted for them. 
- assert isinstance(_outcome_from_state(decoded["b"], Login), DeclinedElicitation) - assert isinstance(_outcome_from_state(decoded["c"], Login), CancelledElicitation) + assert isinstance(_outcome_from_state(decoded["b"], ask), DeclinedElicitation) + assert isinstance(_outcome_from_state(decoded["c"], ask), CancelledElicitation) # An accepted restore always validates against the question's live schema - # data that doesn't fit is rejected, never passed through raw. with pytest.raises(ValidationError): - _outcome_from_state(decoded["d"], Login) + _outcome_from_state(decoded["d"], ask) def test_check_elicit_return_allows_one_arm_and_rejects_two(): @@ -955,7 +976,7 @@ def test_check_elicit_return_allows_one_arm_and_rejects_two(): _check_elicit_return(Login, "r") # no Elicit arm _check_elicit_return(None, "r") # unannotated # A resolver asks one question: two distinct Elicit arms mean it should be split. - with pytest.raises(InvalidSignature, match="'r' return annotation has multiple Elicit arms"): + with pytest.raises(InvalidSignature, match="'r' return annotation has multiple Elicit/Sample/ListRoots arms"): _check_elicit_return(Elicit[Login] | Elicit[Confirm], "r") @@ -2365,3 +2386,397 @@ async def act(go: Annotated[Confirm, Resolve(ask)]) -> str: assert isinstance(final, CallToolResult) assert isinstance(final.content[0], TextContent) assert final.content[0].text == "went:True" + + +# --- Sample / ListRoots markers --- + + +async def _sample_never( # pragma: no cover - declares the capability; never invoked + context: ClientRequestContext, params: CreateMessageRequestParams +) -> CreateMessageResult: + raise AssertionError("should not be called") + + +async def _roots_never(context: ClientRequestContext) -> ListRootsResult: # pragma: no cover - see _sample_never + raise AssertionError("should not be called") + + +def _sample_capital(ctx: Context) -> Sample: + return Sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="Capital of France?"))], 
+ max_tokens=16, + ) + + +@pytest.mark.anyio +@pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning") +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_sample_resolver_injects_result(mode: Literal["legacy", "auto"]): + # The marker form is the 2026-blessed carrier: no SEP-2577 deprecation warning on either mode. + mcp = MCPServer(name="Sampler", request_state_security=RequestStateSecurity.ephemeral()) + prompts: list[str] = [] + + async def sampler(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult: + content = params.messages[0].content + assert isinstance(content, TextContent) + prompts.append(content.text) + return CreateMessageResult(role="assistant", content=TextContent(type="text", text="Paris"), model="m") + + @mcp.tool() + async def capital(answer: Annotated[CreateMessageResult, Resolve(_sample_capital)]) -> str: + assert isinstance(answer.content, TextContent) + return answer.content.text + + async with Client(mcp, mode=mode, sampling_callback=sampler) as client: + assert await _text(client, "capital", {}) == "Paris" + assert prompts == ["Capital of France?"] + + +@pytest.mark.anyio +@pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning") +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_list_roots_resolver_injects_result(mode: Literal["legacy", "auto"]): + mcp = MCPServer(name="Rooted", request_state_security=RequestStateSecurity.ephemeral()) + + async def client_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult(roots=[Root(uri=FileUrl("file:///workspace"))]) + + def fetch_roots(ctx: Context) -> ListRoots: + return ListRoots() + + @mcp.tool() + async def workspace(roots: Annotated[ListRootsResult, Resolve(fetch_roots)]) -> str: + return str(len(roots.roots)) + + async with Client(mcp, mode=mode, list_roots_callback=client_roots) as client: + assert await _text(client, "workspace", {}) == "1" + + +@pytest.mark.anyio +async def 
test_mixed_kinds_batch_into_one_round(): + mcp = MCPServer(name="Mixed", request_state_security=RequestStateSecurity.ephemeral()) + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + async def fetch_roots(ctx: Context) -> ListRoots: + return ListRoots() + + @mcp.tool() + async def combo( + login: Annotated[Login, Resolve(ask_name)], + answer: Annotated[CreateMessageResult, Resolve(_sample_capital)], + roots: Annotated[ListRootsResult, Resolve(fetch_roots)], + ) -> str: + assert isinstance(answer.content, TextContent) + return f"{login.username}/{answer.content.text}/{len(roots.roots)}" + + async with Client( + mcp, elicitation_callback=_never, sampling_callback=_sample_never, list_roots_callback=_roots_never + ) as client: + first = await client.session.call_tool("combo", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + kinds = sorted(type(req).__name__ for req in first.input_requests.values()) + assert kinds == ["CreateMessageRequest", "ElicitRequest", "ListRootsRequest"] + responses: InputResponses = {} + for key, req in first.input_requests.items(): + if isinstance(req, ElicitRequest): + responses[key] = ElicitResult(action="accept", content={"username": "octocat"}) + elif isinstance(req, CreateMessageRequest): + responses[key] = CreateMessageResult( + role="assistant", content=TextContent(type="text", text="hey"), model="m" + ) + else: + responses[key] = ListRootsResult(roots=[]) + final = await client.session.call_tool( + "combo", {}, input_responses=responses, request_state=first.request_state, allow_input_required=True + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat/hey/0" + + +@pytest.mark.anyio +async def test_sampling_tool_without_client_capability_is_a_protocol_error(): + mcp = MCPServer(name="NoSamplingCapability", 
request_state_security=RequestStateSecurity.ephemeral()) + + @mcp.tool() + async def capital(answer: Annotated[CreateMessageResult, Resolve(_sample_capital)]) -> str: + return "unreachable" # pragma: no cover + + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.session.call_tool("capital", {}, allow_input_required=True) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY + assert exc_info.value.error.data is not None + assert "sampling" in exc_info.value.error.data["requiredCapabilities"] + + +@pytest.mark.anyio +async def test_roots_tool_without_client_capability_is_a_protocol_error(): + mcp = MCPServer(name="NoRootsCapability", request_state_security=RequestStateSecurity.ephemeral()) + + def fetch_roots(ctx: Context) -> ListRoots: + return ListRoots() + + @mcp.tool() + async def workspace(roots: Annotated[ListRootsResult, Resolve(fetch_roots)]) -> str: + return "unreachable" # pragma: no cover + + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.session.call_tool("workspace", {}, allow_input_required=True) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY + assert exc_info.value.error.data is not None + assert "roots" in exc_info.value.error.data["requiredCapabilities"] + + +@pytest.mark.anyio +async def test_legacy_eliciting_tool_without_capability_is_a_protocol_error(): + # Same egress gate as the input_requests leg; the session stays usable after the refusal. 
+ mcp = MCPServer(name="LegacyGate", request_state_security=RequestStateSecurity.ephemeral()) + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username # pragma: no cover + + @mcp.tool() + def plain() -> str: + return "ok" + + async with Client(mcp, mode="legacy") as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("tool", {}) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY + assert await _text(client, "plain", {}) == "ok" + + +def _ask_with_tools(ctx: Context) -> Sample: + return Sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="2+2?"))], + max_tokens=16, + tools=[SamplingTool(name="calc", input_schema={"type": "object"})], + ) + + +def _ask_with_tool_choice(ctx: Context) -> Sample: + return Sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="2+2?"))], + max_tokens=16, + tool_choice=ToolChoice(mode="none"), + ) + + +@pytest.mark.anyio +@pytest.mark.parametrize("ask", [_ask_with_tools, _ask_with_tool_choice]) +async def test_sample_tools_require_the_tools_subcapability(ask: Callable[[Context], Sample]): + mcp = MCPServer(name="NoToolsSubcapability", request_state_security=RequestStateSecurity.ephemeral()) + + @mcp.tool() + async def calc(answer: Annotated[CreateMessageResultWithTools, Resolve(ask)]) -> str: + return "unreachable" # pragma: no cover + + # The callback declares base `sampling` but not `sampling.tools`. 
+ async with Client(mcp, sampling_callback=_sample_never) as client: + with pytest.raises(MCPError) as exc_info: + await client.session.call_tool("calc", {}, allow_input_required=True) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY + assert exc_info.value.error.data is not None + assert exc_info.value.error.data["requiredCapabilities"] == {"sampling": {"tools": {}}} + + +@pytest.mark.anyio +async def test_sample_with_tools_round_trips_with_declared_subcapability(): + mcp = MCPServer(name="ToolsSampling", request_state_security=RequestStateSecurity.ephemeral()) + + async def sampler( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResultWithTools: + assert params.tools is not None and params.tools[0].name == "calc" + return CreateMessageResultWithTools(role="assistant", content=[TextContent(type="text", text="4")], model="m") + + @mcp.tool() + async def calc(answer: Annotated[CreateMessageResultWithTools, Resolve(_ask_with_tools)]) -> str: + assert isinstance(answer.content, list) and isinstance(answer.content[0], TextContent) + return answer.content[0].text + + async with Client( + mcp, + sampling_callback=sampler, + sampling_capabilities=SamplingCapability(tools=SamplingToolsCapability()), + ) as client: + assert await _text(client, "calc", {}) == "4" + + +@pytest.mark.anyio +async def test_no_tool_use_answer_to_a_tools_request_is_accepted(): + # The answer parses off the wire as plain CreateMessageResult but must inject as CreateMessageResultWithTools. 
+ mcp = MCPServer(name="NoToolUse", request_state_security=RequestStateSecurity.ephemeral()) + + @mcp.tool() + async def calc(answer: Annotated[CreateMessageResultWithTools, Resolve(_ask_with_tools)]) -> str: + assert isinstance(answer, CreateMessageResultWithTools) + assert isinstance(answer.content, TextContent) + return answer.content.text + + async with Client( + mcp, + sampling_callback=_sample_never, + sampling_capabilities=SamplingCapability(tools=SamplingToolsCapability()), + ) as client: + first = await client.session.call_tool("calc", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (key,) = first.input_requests + final = await client.session.call_tool( + "calc", + {}, + input_responses={ + key: CreateMessageResult(role="assistant", content=TextContent(type="text", text="4"), model="m") + }, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert not final.is_error + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "4" + + +@pytest.mark.anyio +async def test_sample_outcome_persists_across_rounds(): + # The confirm arm depends on the sample, forcing extra rounds that restore the result instead of re-sampling. 
+ mcp = MCPServer(name="Chain", request_state_security=RequestStateSecurity.ephemeral()) + samples = 0 + + async def sampler(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult: + nonlocal samples + samples += 1 + return CreateMessageResult(role="assistant", content=TextContent(type="text", text="Paris"), model="m") + + async def confirm( + answer: Annotated[CreateMessageResult, Resolve(_sample_capital)], ctx: Context + ) -> Elicit[Confirm]: + return Elicit("Accept the model's answer?", Confirm) + + @mcp.tool() + async def tool( + ok: Annotated[Confirm, Resolve(confirm)], + answer: Annotated[CreateMessageResult, Resolve(_sample_capital)], + ) -> str: + assert isinstance(answer.content, TextContent) + return f"{answer.content.text}:{ok.ok}" + + async with Client(mcp, sampling_callback=sampler, elicitation_callback=_accept({"ok": True})) as client: + assert await _text(client, "tool", {}) == "Paris:True" + assert samples == 1 + + +@pytest.mark.anyio +async def test_wrong_kind_response_for_sample_raises(): + mcp = MCPServer(name="WrongKind", request_state_security=RequestStateSecurity.ephemeral()) + + @mcp.tool() + async def capital(answer: Annotated[CreateMessageResult, Resolve(_sample_capital)]) -> str: + return "unreachable" # pragma: no cover + + async with Client(mcp, sampling_callback=_sample_never) as client: + first = await client.session.call_tool("capital", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (key,) = first.input_requests + final = await client.session.call_tool( + "capital", + {}, + input_responses={key: ElicitResult(action="accept", content={"x": "y"})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert final.is_error + assert isinstance(final.content[0], TextContent) + assert "wrong kind" in final.content[0].text + + +def 
test_mixed_marker_arms_raise_at_registration(): + async def ambiguous(ctx: Context) -> Sample | Elicit[Login]: + raise NotImplementedError # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(ambiguous)]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="multiple Elicit/Sample/ListRoots arms"): + Tool.from_function(tool) + + +def test_marker_union_with_generic_alias_member_registers(): + # dict[str, Any] passes isinstance(c, type) on Python 3.10; the arm filter must not feed it to issubclass. + async def maybe_ask(ctx: Context) -> Sample | dict[str, Any]: + raise NotImplementedError # pragma: no cover + + async def tool(answer: Annotated[CreateMessageResult, Resolve(maybe_ask)]) -> str: + return "ok" # pragma: no cover + + Tool.from_function(tool) + + +def test_decline_entry_for_a_sample_marker_is_invalid(): + # Decline outcomes exist only for elicitations; for a Sample the entry's None data fails validation. + with pytest.raises(ValidationError): + _outcome_from_state(_StateEntry(action="decline"), _sample_capital(cast(Context, None))) + + +@pytest.mark.anyio +async def test_bare_initialized_session_is_still_gated(): + # notifications/initialized alone commits the handshake: a live back-channel, no declared capabilities. 
+    mcp = MCPServer(name="BareInit", request_state_security=RequestStateSecurity.ephemeral())
+
+    async def ask(ctx: Context) -> Elicit[Login]:
+        return Elicit("user?", Login)
+
+    @mcp.tool()
+    async def tool(login: Annotated[Login, Resolve(ask)]) -> str:
+        return login.username  # pragma: no cover
+
+    async with InMemoryTransport(mcp) as (read, write):
+        await write.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")))
+        await write.send(
+            SessionMessage(
+                JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": "tool", "arguments": {}})
+            )
+        )
+        with anyio.fail_after(5):
+            message = await read.receive()
+        assert isinstance(message, SessionMessage)
+        assert isinstance(message.message, JSONRPCError)
+        assert message.message.error.code == MISSING_REQUIRED_CLIENT_CAPABILITY
+
+
+@pytest.mark.anyio
+@pytest.mark.parametrize("mode", ["legacy", "auto"])
+async def test_tool_choice_only_sample_validates_as_tools_mode(mode: Literal["legacy", "auto"]):
+    # Gate and answer model share one predicate: tool_choice alone is tools-mode,
+    # so a single-content answer still validates (WithTools accepts both shapes).
+    mcp = MCPServer(name="ToolChoiceOnly", request_state_security=RequestStateSecurity.ephemeral())
+
+    async def sampler(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult:
+        assert params.tool_choice is not None and params.tools is None
+        return CreateMessageResult(role="assistant", content=TextContent(type="text", text="4"), model="m")
+
+    @mcp.tool()
+    async def calc(answer: Annotated[CreateMessageResultWithTools, Resolve(_ask_with_tool_choice)]) -> str:
+        assert isinstance(answer, CreateMessageResultWithTools)
+        assert isinstance(answer.content, TextContent)
+        return answer.content.text
+
+    async with Client(
+        mcp,
+        mode=mode,
+        sampling_callback=sampler,
+        sampling_capabilities=SamplingCapability(tools=SamplingToolsCapability()),
+    ) as client:
+        assert await _text(client, "calc", {}) == "4"
diff --git a/tests/server/test_session.py b/tests/server/test_session.py
index 25b8257eb..49e3b4615 100644
--- a/tests/server/test_session.py
+++ b/tests/server/test_session.py
@@ -225,6 +225,21 @@ async def test_create_message_with_tools_returns_with_tools_result():
     assert params is not None and params["tools"][0]["name"] == "t"
 
 
+@pytest.mark.anyio
+async def test_create_message_with_tool_choice_only_returns_with_tools_result():
+    # tool_choice alone is tools-mode: the answer may carry array content.
+    outbound = StubOutbound(result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"})
+    session = _make_session(
+        outbound, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability()))
+    )
+    result = await session.create_message(  # pyright: ignore[reportDeprecated]
+        messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hi"))],
+        max_tokens=10,
+        tool_choice=types.ToolChoice(mode="none"),
+    )
+    assert isinstance(result, types.CreateMessageResultWithTools)
+
+
 def test_check_client_capability_delegates_to_connection():
     outbound = StubOutbound()
     session = _make_session(outbound, capabilities=ClientCapabilities(sampling=SamplingCapability()))
diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py
index 2fc92e2c8..0bc990f51 100644
--- a/tests/shared/test_peer.py
+++ b/tests/shared/test_peer.py
@@ -18,6 +18,7 @@
     SamplingMessage,
     TextContent,
     Tool,
+    ToolChoice,
 )
 from mcp.shared.dispatcher import DispatchContext
 
@@ -91,6 +92,21 @@ async def test_peer_sample_with_tools_returns_with_tools_result():
     assert isinstance(result, CreateMessageResultWithTools)
 
 
+@pytest.mark.anyio
+async def test_peer_sample_with_tool_choice_only_returns_with_tools_result():
+    # tool_choice alone is tools-mode: the answer may carry array content.
+    rec = _Recorder({"role": "assistant", "content": [{"type": "text", "text": "x"}], "model": "m"})
+    async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_):
+        peer = ClientPeer(client)
+        with anyio.fail_after(5):
+            result = await peer.sample(  # pyright: ignore[reportDeprecated]
+                [SamplingMessage(role="user", content=TextContent(type="text", text="q"))],
+                max_tokens=5,
+                tool_choice=ToolChoice(mode="none"),
+            )
+    assert isinstance(result, CreateMessageResultWithTools)
+
+
 @pytest.mark.anyio
 async def test_peer_elicit_form_sends_elicitation_create_with_form_params():
     rec = _Recorder({"action": "accept", "content": {"name": "Max"}})