diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 7e79c94326..0e917aefd0 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -148,7 +148,7 @@ def __init__( | None = None, on_call_tool: Callable[ [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], - Awaitable[types.CallToolResult], + Awaitable[types.CallToolResult | types.InputRequiredResult], ] | None = None, on_list_resources: Callable[ @@ -163,7 +163,7 @@ def __init__( | None = None, on_read_resource: Callable[ [ServerRequestContext[LifespanResultT], types.ReadResourceRequestParams], - Awaitable[types.ReadResourceResult], + Awaitable[types.ReadResourceResult | types.InputRequiredResult], ] | None = None, on_subscribe_resource: Callable[ @@ -176,6 +176,11 @@ def __init__( Awaitable[types.EmptyResult], ] | None = None, + on_subscriptions_listen: Callable[ + [ServerRequestContext[LifespanResultT], types.SubscriptionsListenRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, on_list_prompts: Callable[ [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], Awaitable[types.ListPromptsResult], @@ -183,7 +188,7 @@ def __init__( | None = None, on_get_prompt: Callable[ [ServerRequestContext[LifespanResultT], types.GetPromptRequestParams], - Awaitable[types.GetPromptResult], + Awaitable[types.GetPromptResult | types.InputRequiredResult], ] | None = None, on_completion: Callable[ @@ -242,6 +247,7 @@ def __init__( ("resources/read", types.ReadResourceRequestParams, on_read_resource), ("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource), ("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource), + ("subscriptions/listen", types.SubscriptionsListenRequestParams, on_subscriptions_listen), ("tools/list", types.PaginatedRequestParams, on_list_tools), ("tools/call", types.CallToolRequestParams, on_call_tool), ("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level), diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 82b4a084d5..f0b13f0bd0 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -16,9 +16,10 @@ Field, FileUrl, TypeAdapter, + model_validator, ) from pydantic.alias_generators import to_camel -from typing_extensions import NotRequired, TypedDict +from typing_extensions import NotRequired, Self, TypedDict from mcp.types.jsonrpc import RequestId @@ -2052,7 +2053,7 @@ class InputRequiredResult(Result): (`tools/call`, `prompts/get`, `resources/read`). The client fulfills `input_requests` and retries the original request, carrying the responses and the echoed `request_state`. At least one of those two fields is - present on the wire (spec MUST; not enforced by the model). + present on the wire (spec MUST). """ result_type: Literal["input_required"] = "input_required" @@ -2064,6 +2065,12 @@ class InputRequiredResult(Result): request_state: str | None = None """Opaque state to pass back verbatim when the client retries the original request.""" + @model_validator(mode="after") + def _require_one_field(self) -> Self: + if self.input_requests is None and self.request_state is None: + raise ValueError("InputRequiredResult requires at least one of input_requests or request_state") + return self + # Forward refs to InputResponses; rebuild at import time rather than first use. InputResponseRequestParams.model_rebuild() diff --git a/tests/test_types.py b/tests/test_types.py index 3756bd893d..dffbc918dc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -2,6 +2,7 @@ import pytest from inline_snapshot import snapshot +from pydantic import ValidationError from mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -436,4 +437,13 @@ def test_empty_result_dumps_result_type_only_when_explicitly_tagged(): def test_input_required_result_dumps_its_discriminating_tag(): - assert _wire_dump(InputRequiredResult()) == snapshot({"resultType": "input_required"}) + assert _wire_dump(InputRequiredResult(request_state="s")) == snapshot( + {"resultType": "input_required", "requestState": "s"} + ) + + +def test_input_required_result_requires_at_least_one_of_input_requests_or_request_state(): + with pytest.raises(ValidationError): + InputRequiredResult() + assert InputRequiredResult(input_requests={}).request_state is None + assert InputRequiredResult(request_state="s").input_requests is None