Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(
| None = None,
on_call_tool: Callable[
[ServerRequestContext[LifespanResultT], types.CallToolRequestParams],
Awaitable[types.CallToolResult],
Awaitable[types.CallToolResult | types.InputRequiredResult],
]
| None = None,
on_list_resources: Callable[
Expand All @@ -163,7 +163,7 @@ def __init__(
| None = None,
on_read_resource: Callable[
[ServerRequestContext[LifespanResultT], types.ReadResourceRequestParams],
Awaitable[types.ReadResourceResult],
Awaitable[types.ReadResourceResult | types.InputRequiredResult],
]
| None = None,
on_subscribe_resource: Callable[
Expand All @@ -176,14 +176,19 @@ def __init__(
Awaitable[types.EmptyResult],
]
| None = None,
on_subscriptions_listen: Callable[
[ServerRequestContext[LifespanResultT], types.SubscriptionsListenRequestParams],
Awaitable[types.EmptyResult],
]
| None = None,
on_list_prompts: Callable[
[ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None],
Awaitable[types.ListPromptsResult],
]
| None = None,
on_get_prompt: Callable[
[ServerRequestContext[LifespanResultT], types.GetPromptRequestParams],
Awaitable[types.GetPromptResult],
Awaitable[types.GetPromptResult | types.InputRequiredResult],
]
| None = None,
on_completion: Callable[
Expand Down Expand Up @@ -242,6 +247,7 @@ def __init__(
("resources/read", types.ReadResourceRequestParams, on_read_resource),
("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource),
("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource),
("subscriptions/listen", types.SubscriptionsListenRequestParams, on_subscriptions_listen),
("tools/list", types.PaginatedRequestParams, on_list_tools),
("tools/call", types.CallToolRequestParams, on_call_tool),
("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level),
Expand Down
11 changes: 9 additions & 2 deletions src/mcp/types/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
Field,
FileUrl,
TypeAdapter,
model_validator,
)
from pydantic.alias_generators import to_camel
from typing_extensions import NotRequired, TypedDict
from typing_extensions import NotRequired, Self, TypedDict

from mcp.types.jsonrpc import RequestId

Expand Down Expand Up @@ -2052,7 +2053,7 @@ class InputRequiredResult(Result):
(`tools/call`, `prompts/get`, `resources/read`). The client fulfills
`input_requests` and retries the original request, carrying the responses
and the echoed `request_state`. At least one of those two fields is
present on the wire (spec MUST; not enforced by the model).
present on the wire (spec MUST).
"""

result_type: Literal["input_required"] = "input_required"
Expand All @@ -2064,6 +2065,12 @@ class InputRequiredResult(Result):
request_state: str | None = None
"""Opaque state to pass back verbatim when the client retries the original request."""

@model_validator(mode="after")
def _require_one_field(self) -> Self:
if self.input_requests is None and self.request_state is None:
raise ValueError("InputRequiredResult requires at least one of input_requests or request_state")
return self
Comment thread
claude[bot] marked this conversation as resolved.


# Forward refs to InputResponses; rebuild at import time rather than first use.
InputResponseRequestParams.model_rebuild()
Expand Down
12 changes: 11 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from inline_snapshot import snapshot
from pydantic import ValidationError

from mcp.types import (
LATEST_PROTOCOL_VERSION,
Expand Down Expand Up @@ -436,4 +437,13 @@ def test_empty_result_dumps_result_type_only_when_explicitly_tagged():


def test_input_required_result_dumps_its_discriminating_tag():
assert _wire_dump(InputRequiredResult()) == snapshot({"resultType": "input_required"})
assert _wire_dump(InputRequiredResult(request_state="s")) == snapshot(
{"resultType": "input_required", "requestState": "s"}
)


def test_input_required_result_requires_at_least_one_of_input_requests_or_request_state():
with pytest.raises(ValidationError):
InputRequiredResult()
assert InputRequiredResult(input_requests={}).request_state is None
assert InputRequiredResult(request_state="s").input_requests is None
Loading