|
4 | 4 | from collections.abc import Callable, Mapping |
5 | 5 | from dataclasses import dataclass |
6 | 6 | from types import TracebackType |
7 | | -from typing import Any, Protocol, cast |
| 7 | +from typing import Any, Literal, Protocol, cast, overload |
8 | 8 |
|
9 | 9 | import anyio |
10 | 10 | import anyio.abc |
@@ -173,6 +173,10 @@ async def _default_logging_callback( |
173 | 173 |
|
174 | 174 | ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) |
175 | 175 |
|
| 176 | +_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult] = TypeAdapter( |
| 177 | + types.CallToolResult | types.InputRequiredResult |
| 178 | +) |
| 179 | + |
176 | 180 |
|
177 | 181 | class ClientSession: |
178 | 182 | """Client half of an MCP connection, running on a `Dispatcher`. |
@@ -269,7 +273,7 @@ async def __aexit__( |
269 | 273 | async def send_request( |
270 | 274 | self, |
271 | 275 | request: types.ClientRequest, |
272 | | - result_type: type[ReceiveResultT], |
| 276 | + result_type: type[ReceiveResultT] | TypeAdapter[ReceiveResultT], |
273 | 277 | request_read_timeout_seconds: float | None = None, |
274 | 278 | metadata: ClientMessageMetadata | None = None, |
275 | 279 | progress_callback: ProgressFnT | None = None, |
@@ -308,6 +312,8 @@ async def send_request( |
308 | 312 | _methods.validate_server_result(method, version, raw) |
309 | 313 | except KeyError: |
310 | 314 | pass |
| 315 | + if isinstance(result_type, TypeAdapter): |
| 316 | + return result_type.validate_python(raw, by_name=False) |
311 | 317 | return result_type.model_validate(raw, by_name=False) |
312 | 318 |
|
313 | 319 | async def send_notification(self, notification: types.ClientNotification) -> None: |
@@ -596,29 +602,83 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None |
596 | 602 | types.EmptyResult, |
597 | 603 | ) |
598 | 604 |
|
| 605 | + @overload |
| 606 | + async def call_tool( |
| 607 | + self, |
| 608 | + name: str, |
| 609 | + arguments: dict[str, Any] | None = None, |
| 610 | + read_timeout_seconds: float | None = None, |
| 611 | + progress_callback: ProgressFnT | None = None, |
| 612 | + *, |
| 613 | + input_responses: types.InputResponses | None = None, |
| 614 | + request_state: str | None = None, |
| 615 | + meta: RequestParamsMeta | None = None, |
| 616 | + allow_input_required: Literal[False] = False, |
| 617 | + ) -> types.CallToolResult: ... |
| 618 | + |
| 619 | + @overload |
| 620 | + async def call_tool( |
| 621 | + self, |
| 622 | + name: str, |
| 623 | + arguments: dict[str, Any] | None = None, |
| 624 | + read_timeout_seconds: float | None = None, |
| 625 | + progress_callback: ProgressFnT | None = None, |
| 626 | + *, |
| 627 | + input_responses: types.InputResponses | None = None, |
| 628 | + request_state: str | None = None, |
| 629 | + meta: RequestParamsMeta | None = None, |
| 630 | + allow_input_required: bool, |
| 631 | + ) -> types.CallToolResult | types.InputRequiredResult: ... |
| 632 | + |
599 | 633 | async def call_tool( |
600 | 634 | self, |
601 | 635 | name: str, |
602 | 636 | arguments: dict[str, Any] | None = None, |
603 | 637 | read_timeout_seconds: float | None = None, |
604 | 638 | progress_callback: ProgressFnT | None = None, |
605 | 639 | *, |
| 640 | + input_responses: types.InputResponses | None = None, |
| 641 | + request_state: str | None = None, |
606 | 642 | meta: RequestParamsMeta | None = None, |
607 | | - ) -> types.CallToolResult: |
608 | | - """Send a tools/call request with optional progress callback support.""" |
| 643 | + allow_input_required: bool = False, |
| 644 | + ) -> types.CallToolResult | types.InputRequiredResult: |
| 645 | + """Send a tools/call request with optional progress callback support. |
| 646 | +
|
| 647 | + Args: |
| 648 | + input_responses: Responses to a prior `InputRequiredResult.input_requests`. |
| 649 | + request_state: Opaque state echoed from a prior `InputRequiredResult`. |
| 650 | + allow_input_required: When ``False`` (default), an `InputRequiredResult` |
| 651 | + from the server raises `RuntimeError`; when ``True``, it is returned |
| 652 | + so the caller can resolve the requests and retry. |
| 653 | +
|
| 654 | + Raises: |
| 655 | + RuntimeError: If the server returns an `InputRequiredResult` and |
| 656 | + ``allow_input_required`` is ``False``. |
| 657 | + """ |
609 | 658 |
|
610 | 659 | result = await self.send_request( |
611 | 660 | types.CallToolRequest( |
612 | | - params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta), |
| 661 | + params=types.CallToolRequestParams( |
| 662 | + name=name, |
| 663 | + arguments=arguments, |
| 664 | + input_responses=input_responses, |
| 665 | + request_state=request_state, |
| 666 | + _meta=meta, |
| 667 | + ), |
613 | 668 | ), |
614 | | - types.CallToolResult, |
| 669 | + _CallToolResultAdapter, |
615 | 670 | request_read_timeout_seconds=read_timeout_seconds, |
616 | 671 | progress_callback=progress_callback, |
617 | 672 | ) |
618 | 673 |
|
619 | | - if not result.is_error: |
| 674 | + if isinstance(result, types.CallToolResult) and not result.is_error: |
620 | 675 | await self._validate_tool_result(name, result) |
621 | 676 |
|
| 677 | + if isinstance(result, types.InputRequiredResult) and not allow_input_required: |
| 678 | + raise RuntimeError( |
| 679 | + "Server returned InputRequiredResult; pass allow_input_required=True to receive it " |
| 680 | + "and retry call_tool(..., input_responses=..., request_state=result.request_state)." |
| 681 | + ) |
622 | 682 | return result |
623 | 683 |
|
624 | 684 | async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None: |
|
0 commit comments