Skip to content

Commit 03681ed

Browse files
authored
Client call_tool: input_responses/request_state retry params; InputRequiredResult via allow_input_required (#2968)
1 parent 96bf22e commit 03681ed

6 files changed

Lines changed: 239 additions & 13 deletions

File tree

docs/migration.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,10 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer`
364364

365365
`Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it.
366366

367+
### `call_tool` can return `InputRequiredResult` (opt-in)
368+
369+
For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. By default `call_tool` (on `ClientSession`, `Client`, and `ClientSessionGroup`) still returns `CallToolResult` and raises `RuntimeError` if the server requests input. Pass `allow_input_required=True` to receive the `InputRequiredResult` instead, then retry with `input_responses=` / `request_state=`.
370+
367371
### `McpError` renamed to `MCPError`
368372

369373
The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK.

src/mcp/client/client.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Awaitable, Callable, Mapping
66
from contextlib import AsyncExitStack
77
from dataclasses import KW_ONLY, dataclass, field
8-
from typing import Any, Literal, TypeVar
8+
from typing import Any, Literal, TypeVar, overload
99

1010
import anyio
1111
from typing_extensions import deprecated
@@ -30,6 +30,8 @@
3030
EmptyResult,
3131
GetPromptResult,
3232
Implementation,
33+
InputRequiredResult,
34+
InputResponses,
3335
ListPromptsResult,
3436
ListResourcesResult,
3537
ListResourceTemplatesResult,
@@ -374,33 +376,79 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None
374376
"""Unsubscribe from resource updates."""
375377
return await self.session.unsubscribe_resource(uri, meta=meta)
376378

379+
@overload
377380
async def call_tool(
378381
self,
379382
name: str,
380383
arguments: dict[str, Any] | None = None,
381384
read_timeout_seconds: float | None = None,
382385
progress_callback: ProgressFnT | None = None,
383386
*,
387+
input_responses: InputResponses | None = None,
388+
request_state: str | None = None,
384389
meta: RequestParamsMeta | None = None,
385-
) -> CallToolResult:
390+
allow_input_required: Literal[False] = False,
391+
) -> CallToolResult: ...
392+
393+
@overload
394+
async def call_tool(
395+
self,
396+
name: str,
397+
arguments: dict[str, Any] | None = None,
398+
read_timeout_seconds: float | None = None,
399+
progress_callback: ProgressFnT | None = None,
400+
*,
401+
input_responses: InputResponses | None = None,
402+
request_state: str | None = None,
403+
meta: RequestParamsMeta | None = None,
404+
allow_input_required: bool,
405+
) -> CallToolResult | InputRequiredResult: ...
406+
407+
async def call_tool(
408+
self,
409+
name: str,
410+
arguments: dict[str, Any] | None = None,
411+
read_timeout_seconds: float | None = None,
412+
progress_callback: ProgressFnT | None = None,
413+
*,
414+
input_responses: InputResponses | None = None,
415+
request_state: str | None = None,
416+
meta: RequestParamsMeta | None = None,
417+
allow_input_required: bool = False,
418+
) -> CallToolResult | InputRequiredResult:
386419
"""Call a tool on the server.
387420
388421
Args:
389422
name: The name of the tool to call
390423
arguments: Arguments to pass to the tool
391424
read_timeout_seconds: Timeout for the tool call
392425
progress_callback: Callback for progress updates
426+
input_responses: Responses to a prior `InputRequiredResult.input_requests`
427+
request_state: Opaque state echoed from a prior `InputRequiredResult`
393428
meta: Additional metadata for the request
429+
allow_input_required: When ``False`` (default), an `InputRequiredResult`
430+
from the server raises `RuntimeError`; when ``True``, it is returned
431+
so the caller can resolve the requests and retry.
394432
395433
Returns:
396-
The tool result.
434+
The tool result. When ``allow_input_required=True``, may instead be an
435+
`InputRequiredResult` carrying the server's input requests and opaque
436+
``request_state`` for the retry.
437+
438+
Raises:
439+
RuntimeError: If the server returns an `InputRequiredResult` and
440+
``allow_input_required`` is ``False``.
397441
"""
442+
# TODO(L84): stop forwarding allow_input_required; run the MRTR auto-loop driver here (S6).
398443
return await self.session.call_tool(
399444
name=name,
400445
arguments=arguments,
401446
read_timeout_seconds=read_timeout_seconds,
402447
progress_callback=progress_callback,
448+
input_responses=input_responses,
449+
request_state=request_state,
403450
meta=meta,
451+
allow_input_required=allow_input_required,
404452
)
405453

406454
async def list_prompts(

src/mcp/client/session.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Callable, Mapping
55
from dataclasses import dataclass
66
from types import TracebackType
7-
from typing import Any, Protocol, cast
7+
from typing import Any, Literal, Protocol, cast, overload
88

99
import anyio
1010
import anyio.abc
@@ -173,6 +173,10 @@ async def _default_logging_callback(
173173

174174
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
175175

176+
_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult] = TypeAdapter(
177+
types.CallToolResult | types.InputRequiredResult
178+
)
179+
176180

177181
class ClientSession:
178182
"""Client half of an MCP connection, running on a `Dispatcher`.
@@ -269,7 +273,7 @@ async def __aexit__(
269273
async def send_request(
270274
self,
271275
request: types.ClientRequest,
272-
result_type: type[ReceiveResultT],
276+
result_type: type[ReceiveResultT] | TypeAdapter[ReceiveResultT],
273277
request_read_timeout_seconds: float | None = None,
274278
metadata: ClientMessageMetadata | None = None,
275279
progress_callback: ProgressFnT | None = None,
@@ -308,6 +312,8 @@ async def send_request(
308312
_methods.validate_server_result(method, version, raw)
309313
except KeyError:
310314
pass
315+
if isinstance(result_type, TypeAdapter):
316+
return result_type.validate_python(raw, by_name=False)
311317
return result_type.model_validate(raw, by_name=False)
312318

313319
async def send_notification(self, notification: types.ClientNotification) -> None:
@@ -596,29 +602,83 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None
596602
types.EmptyResult,
597603
)
598604

605+
@overload
606+
async def call_tool(
607+
self,
608+
name: str,
609+
arguments: dict[str, Any] | None = None,
610+
read_timeout_seconds: float | None = None,
611+
progress_callback: ProgressFnT | None = None,
612+
*,
613+
input_responses: types.InputResponses | None = None,
614+
request_state: str | None = None,
615+
meta: RequestParamsMeta | None = None,
616+
allow_input_required: Literal[False] = False,
617+
) -> types.CallToolResult: ...
618+
619+
@overload
620+
async def call_tool(
621+
self,
622+
name: str,
623+
arguments: dict[str, Any] | None = None,
624+
read_timeout_seconds: float | None = None,
625+
progress_callback: ProgressFnT | None = None,
626+
*,
627+
input_responses: types.InputResponses | None = None,
628+
request_state: str | None = None,
629+
meta: RequestParamsMeta | None = None,
630+
allow_input_required: bool,
631+
) -> types.CallToolResult | types.InputRequiredResult: ...
632+
599633
async def call_tool(
600634
self,
601635
name: str,
602636
arguments: dict[str, Any] | None = None,
603637
read_timeout_seconds: float | None = None,
604638
progress_callback: ProgressFnT | None = None,
605639
*,
640+
input_responses: types.InputResponses | None = None,
641+
request_state: str | None = None,
606642
meta: RequestParamsMeta | None = None,
607-
) -> types.CallToolResult:
608-
"""Send a tools/call request with optional progress callback support."""
643+
allow_input_required: bool = False,
644+
) -> types.CallToolResult | types.InputRequiredResult:
645+
"""Send a tools/call request with optional progress callback support.
646+
647+
Args:
648+
input_responses: Responses to a prior `InputRequiredResult.input_requests`.
649+
request_state: Opaque state echoed from a prior `InputRequiredResult`.
650+
allow_input_required: When ``False`` (default), an `InputRequiredResult`
651+
from the server raises `RuntimeError`; when ``True``, it is returned
652+
so the caller can resolve the requests and retry.
653+
654+
Raises:
655+
RuntimeError: If the server returns an `InputRequiredResult` and
656+
``allow_input_required`` is ``False``.
657+
"""
609658

610659
result = await self.send_request(
611660
types.CallToolRequest(
612-
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta),
661+
params=types.CallToolRequestParams(
662+
name=name,
663+
arguments=arguments,
664+
input_responses=input_responses,
665+
request_state=request_state,
666+
_meta=meta,
667+
),
613668
),
614-
types.CallToolResult,
669+
_CallToolResultAdapter,
615670
request_read_timeout_seconds=read_timeout_seconds,
616671
progress_callback=progress_callback,
617672
)
618673

619-
if not result.is_error:
674+
if isinstance(result, types.CallToolResult) and not result.is_error:
620675
await self._validate_tool_result(name, result)
621676

677+
if isinstance(result, types.InputRequiredResult) and not allow_input_required:
678+
raise RuntimeError(
679+
"Server returned InputRequiredResult; pass allow_input_required=True to receive it "
680+
"and retry call_tool(..., input_responses=..., request_state=result.request_state)."
681+
)
622682
return result
623683

624684
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:

src/mcp/client/session_group.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable
1212
from dataclasses import dataclass
1313
from types import TracebackType
14-
from typing import Any, TypeAlias
14+
from typing import Any, Literal, TypeAlias, overload
1515

1616
import anyio
1717
import httpx
@@ -190,24 +190,63 @@ def tools(self) -> dict[str, types.Tool]:
190190
"""Returns the tools as a dictionary of names to tools."""
191191
return self._tools
192192

193+
@overload
193194
async def call_tool(
194195
self,
195196
name: str,
196197
arguments: dict[str, Any] | None = None,
197198
read_timeout_seconds: float | None = None,
198199
progress_callback: ProgressFnT | None = None,
199200
*,
201+
input_responses: types.InputResponses | None = None,
202+
request_state: str | None = None,
200203
meta: types.RequestParamsMeta | None = None,
201-
) -> types.CallToolResult:
202-
"""Executes a tool given its name and arguments."""
204+
allow_input_required: Literal[False] = False,
205+
) -> types.CallToolResult: ...
206+
207+
@overload
208+
async def call_tool(
209+
self,
210+
name: str,
211+
arguments: dict[str, Any] | None = None,
212+
read_timeout_seconds: float | None = None,
213+
progress_callback: ProgressFnT | None = None,
214+
*,
215+
input_responses: types.InputResponses | None = None,
216+
request_state: str | None = None,
217+
meta: types.RequestParamsMeta | None = None,
218+
allow_input_required: bool,
219+
) -> types.CallToolResult | types.InputRequiredResult: ...
220+
221+
async def call_tool(
222+
self,
223+
name: str,
224+
arguments: dict[str, Any] | None = None,
225+
read_timeout_seconds: float | None = None,
226+
progress_callback: ProgressFnT | None = None,
227+
*,
228+
input_responses: types.InputResponses | None = None,
229+
request_state: str | None = None,
230+
meta: types.RequestParamsMeta | None = None,
231+
allow_input_required: bool = False,
232+
) -> types.CallToolResult | types.InputRequiredResult:
233+
"""Executes a tool given its name and arguments.
234+
235+
Raises:
236+
RuntimeError: If the server returns an `InputRequiredResult` and
237+
``allow_input_required`` is ``False``.
238+
"""
203239
session = self._tool_to_session[name]
204240
session_tool_name = self.tools[name].name
205241
return await session.call_tool(
206242
session_tool_name,
207243
arguments=arguments,
208244
read_timeout_seconds=read_timeout_seconds,
209245
progress_callback=progress_callback,
246+
input_responses=input_responses,
247+
request_state=request_state,
210248
meta=meta,
249+
allow_input_required=allow_input_required,
211250
)
212251

213252
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:

tests/client/test_session.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
from mcp import MCPError, types
1414
from mcp.client import ClientRequestContext
15+
from mcp.client.client import Client
1516
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
17+
from mcp.server import Server, ServerRequestContext
1618
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
1719
from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest
1820
from mcp.shared.message import SessionMessage
@@ -1656,3 +1658,59 @@ async def test_discover_reraises_unsupported_version_with_malformed_error_data()
16561658
await session.discover()
16571659
assert exc.value.error.code == UNSUPPORTED_PROTOCOL_VERSION
16581660
assert [m for m, _ in dispatcher.calls] == ["server/discover"]
1661+
1662+
1663+
@pytest.mark.anyio
1664+
async def test_call_tool_returns_input_required_result_when_server_requests_input() -> None:
1665+
# `on_call_tool` is still typed `-> CallToolResult` on this branch (#2967 widens it later);
1666+
# `add_request_handler` is `HandlerResult`-typed and accepts `InputRequiredResult` cleanly.
1667+
async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult:
1668+
return types.InputRequiredResult(request_state="s")
1669+
1670+
server = Server("test")
1671+
server.add_request_handler("tools/call", types.CallToolRequestParams, handler)
1672+
with anyio.fail_after(5):
1673+
async with Client(server, mode="2026-07-28") as client:
1674+
result = await client.call_tool("ask", allow_input_required=True)
1675+
assert isinstance(result, types.InputRequiredResult)
1676+
assert result.request_state == "s"
1677+
1678+
1679+
@pytest.mark.anyio
1680+
async def test_call_tool_threads_input_responses_and_request_state_into_params() -> None:
1681+
captured: list[types.CallToolRequestParams] = []
1682+
1683+
async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
1684+
captured.append(params)
1685+
return CallToolResult(content=[])
1686+
1687+
async def on_list_tools(
1688+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
1689+
) -> types.ListToolsResult:
1690+
return types.ListToolsResult(tools=[])
1691+
1692+
server = Server("test", on_call_tool=on_call_tool, on_list_tools=on_list_tools)
1693+
with anyio.fail_after(5):
1694+
async with Client(server, mode="2026-07-28") as client:
1695+
await client.call_tool(
1696+
"ask",
1697+
input_responses={"k": types.ElicitResult(action="decline")},
1698+
request_state="s",
1699+
)
1700+
assert captured[0].input_responses == {"k": types.ElicitResult(action="decline")}
1701+
assert captured[0].request_state == "s"
1702+
1703+
1704+
@pytest.mark.anyio
1705+
async def test_client_call_tool_raises_on_input_required_without_opt_in() -> None:
1706+
async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult:
1707+
return types.InputRequiredResult(request_state="s")
1708+
1709+
server = Server("test")
1710+
server.add_request_handler("tools/call", types.CallToolRequestParams, handler)
1711+
with anyio.fail_after(5):
1712+
async with Client(server, mode="2026-07-28") as client:
1713+
with pytest.raises(RuntimeError):
1714+
await client.call_tool("t")
1715+
result = await client.call_tool("t", allow_input_required=True)
1716+
assert isinstance(result, types.InputRequiredResult)

0 commit comments

Comments
 (0)