Skip to content

Commit e58ff02

Browse files
committed
Add resolver dependency injection for MCPServer tools
A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the resolver `fn` before the tool body, instead of by the calling LLM. Resolvers form a dependency graph: a resolver may declare its own `Resolve(...)` dependencies, read the `Context` (including the new `Context.headers`), and receive the tool's own arguments by name. A resolver may return `Elicit[T]` to ask the client; the SDK runs the elicitation and injects the answer. Each resolver runs at most once per `tools/call`. The injected type follows the consumer's annotation: the unwrapped model aborts the call on decline/cancel, while the elicitation result union lets the consumer branch on the outcome. Resolved parameters are omitted from the tool's input schema; unclassifiable resolver parameters and cyclic resolver dependencies raise at registration time.
1 parent ae13ede commit e58ff02

7 files changed

Lines changed: 663 additions & 8 deletions

File tree

docs/migration.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,6 +1396,64 @@ app = server.streamable_http_app(
13961396

13971397
The lowlevel `Server` also now exposes a `session_manager` property to access the `StreamableHTTPSessionManager` after calling `streamable_http_app()`.
13981398

1399+
### Resolver dependency injection for tools (`Resolve` / `Elicit`)
1400+
1401+
A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the resolver `fn` before the tool body, instead of by the calling LLM. Resolvers form a dependency graph: a resolver may declare its own `Resolve(...)` dependencies, read the `Context` (including `ctx.headers`), and receive the tool's own arguments by name. A resolver may return `Elicit[T]` to ask the client; the SDK runs the elicitation and injects the answer. Each resolver runs at most once per `tools/call`.
1402+
1403+
```python
1404+
from typing import Annotated
1405+
1406+
from pydantic import BaseModel
1407+
1408+
from mcp.server.mcpserver import AcceptedElicitation, Context, Elicit, MCPServer, Resolve
1409+
1410+
mcp = MCPServer(name="github")
1411+
1412+
1413+
class Login(BaseModel):
1414+
username: str
1415+
1416+
1417+
class Confirm(BaseModel):
1418+
ok: bool
1419+
1420+
1421+
async def login(ctx: Context) -> Login | Elicit[Login]:
1422+
if username := (ctx.headers or {}).get("x-github-user"):
1423+
return Login(username=username) # resolved from context, no question
1424+
return Elicit("GitHub username?", Login) # must ask
1425+
1426+
1427+
async def confirm(repo: str, login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]:
1428+
return Elicit(f"Star {repo} as {login.username}?", Confirm)
1429+
1430+
1431+
@mcp.tool()
1432+
async def star_repo(
1433+
repo: str,
1434+
login: Annotated[Login, Resolve(login)],
1435+
confirm: Annotated[Confirm, Resolve(confirm)],
1436+
) -> str:
1437+
"""Star a GitHub repo."""
1438+
return f"starred {repo} as {login.username}" if confirm.ok else "cancelled"
1439+
```
1440+
1441+
The injected type follows the consumer's annotation. Annotating the unwrapped model (`Annotated[Login, Resolve(login)]`) injects the model on accept and aborts the call with an error result on decline or cancel. To branch on the outcome instead, annotate the elicitation result union:
1442+
1443+
```python
1444+
@mcp.tool()
1445+
async def whoami(
1446+
login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)],
1447+
) -> str:
1448+
match login:
1449+
case AcceptedElicitation(data=data):
1450+
return f"hi {data.username}"
1451+
case _:
1452+
return "no username provided"
1453+
```
1454+
1455+
Resolved parameters are omitted from the tool's input schema, so the client never supplies them. Resolver parameters that cannot be classified, and cyclic resolver dependencies, raise at registration time.
1456+
13991457
## Need Help?
14001458

14011459
If you encounter issues during migration:

src/mcp/server/mcpserver/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,27 @@
33
from mcp.types import Icon
44

55
from .context import Context
6+
from .resolve import (
7+
AcceptedElicitation,
8+
CancelledElicitation,
9+
DeclinedElicitation,
10+
Elicit,
11+
ElicitationResult,
12+
Resolve,
13+
)
614
from .server import MCPServer
715
from .utilities.types import Audio, Image
816

9-
__all__ = ["MCPServer", "Context", "Image", "Audio", "Icon"]
17+
__all__ = [
18+
"MCPServer",
19+
"Context",
20+
"Image",
21+
"Audio",
22+
"Icon",
23+
"Resolve",
24+
"Elicit",
25+
"ElicitationResult",
26+
"AcceptedElicitation",
27+
"DeclinedElicitation",
28+
"CancelledElicitation",
29+
]

src/mcp/server/mcpserver/context.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable
4-
from typing import TYPE_CHECKING, Any, Generic
3+
from collections.abc import Iterable, Mapping
4+
from typing import TYPE_CHECKING, Any, Generic, Protocol, cast
55

66
from pydantic import AnyUrl, BaseModel
77
from typing_extensions import deprecated
@@ -22,6 +22,11 @@
2222
from mcp.server.mcpserver.server import MCPServer
2323

2424

25+
class _HasHeaders(Protocol):
26+
@property
27+
def headers(self) -> Mapping[str, str]: ...
28+
29+
2530
class Context(BaseModel, Generic[LifespanContextT, RequestT]):
2631
"""Context object providing access to MCP capabilities.
2732
@@ -225,6 +230,17 @@ def client_id(self) -> str | None:
225230
"""
226231
return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover
227232

233+
@property
234+
def headers(self) -> Mapping[str, str] | None:
235+
"""Request headers carried by this message, when the transport has them.
236+
237+
Populated by HTTP-based transports; `None` on stdio.
238+
"""
239+
request = self.request_context.request
240+
if request is None:
241+
return None
242+
return cast("_HasHeaders", request).headers
243+
228244
@property
229245
def request_id(self) -> str:
230246
"""Get the unique ID for this request."""
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
"""Resolver dependency injection for MCPServer tools.
2+
3+
A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the
4+
resolver `fn` before the tool body, instead of from the LLM-supplied arguments.
5+
Resolvers form a DAG: a resolver may declare its own `Resolve(...)` dependencies,
6+
take tool arguments by name, and take the `Context`. A resolver may return
7+
`Elicit[T]` to ask the client; the framework runs the elicitation and injects the
8+
answer.
9+
10+
Whether the consumer receives the unwrapped model or the full
11+
`ElicitationResult` union is decided by the consumer's annotation:
12+
13+
- `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call.
14+
- `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the
15+
full outcome; the consumer branches on accept/decline/cancel.
16+
17+
Each resolver runs at most once per `tools/call` (memoized by function identity).
18+
"""
19+
20+
from __future__ import annotations
21+
22+
import inspect
23+
import typing
24+
from collections.abc import Callable, Mapping
25+
from typing import Annotated, Any, Generic, cast, get_args, get_origin
26+
27+
import anyio.to_thread
28+
from pydantic import BaseModel
29+
from typing_extensions import TypeVar
30+
31+
from mcp.server.elicitation import (
32+
AcceptedElicitation,
33+
CancelledElicitation,
34+
DeclinedElicitation,
35+
ElicitationResult,
36+
)
37+
from mcp.server.mcpserver.context import Context
38+
from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError
39+
from mcp.shared._callable_inspection import is_async_callable
40+
41+
T = TypeVar("T", bound=BaseModel)
42+
43+
# The union members the framework injects when a consumer opts into the outcome.
44+
_ELICITATION_RESULT_MEMBERS = (AcceptedElicitation, DeclinedElicitation, CancelledElicitation)
45+
46+
47+
class Resolve:
48+
"""Marker for `Annotated[T, Resolve(fn)]`: fill the parameter by running `fn`."""
49+
50+
def __init__(self, fn: Callable[..., Any]) -> None:
51+
self.fn = fn
52+
53+
54+
class Elicit(Generic[T]):
55+
"""A resolver's request to ask the client.
56+
57+
Returned from a resolver to signal that the value must be elicited. The
58+
framework runs `ctx.elicit(message, schema)` and injects the outcome.
59+
"""
60+
61+
def __init__(self, message: str, schema: type[T]) -> None:
62+
self.message = message
63+
self.schema = schema
64+
65+
66+
class _ParamPlan:
67+
"""How to fill one resolver parameter, decided once at registration."""
68+
69+
kind: str # "context" | "resolve" | "by_name"
70+
resolve: Resolve | None
71+
wants_union: bool
72+
73+
def __init__(self, kind: str, resolve: Resolve | None = None, wants_union: bool = False) -> None:
74+
self.kind = kind
75+
self.resolve = resolve
76+
self.wants_union = wants_union
77+
78+
79+
class _ResolverPlan:
80+
"""A resolver's parameters and whether it is async, analyzed once."""
81+
82+
def __init__(self, fn: Callable[..., Any], params: dict[str, _ParamPlan], is_async: bool) -> None:
83+
self.fn = fn
84+
self.params = params
85+
self.is_async = is_async
86+
87+
88+
def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, bool]]:
89+
"""Find parameters of `fn` annotated `Annotated[_, Resolve(...)]`.
90+
91+
Returns a mapping of parameter name to `(Resolve, wants_union)`, where
92+
`wants_union` is True when the annotated type is an `ElicitationResult` member
93+
(the consumer wants the full outcome rather than the unwrapped model).
94+
"""
95+
hints = typing.get_type_hints(fn, include_extras=True)
96+
resolved: dict[str, tuple[Resolve, bool]] = {}
97+
for name, annotation in hints.items():
98+
if get_origin(annotation) is not Annotated:
99+
continue
100+
type_arg, *metadata = get_args(annotation)
101+
marker = next((m for m in metadata if isinstance(m, Resolve)), None)
102+
if marker is not None:
103+
resolved[name] = (marker, _wants_union(type_arg))
104+
return resolved
105+
106+
107+
def _wants_union(type_arg: Any) -> bool:
108+
"""True when `type_arg` is an `ElicitationResult` member (or a union of them)."""
109+
members = get_args(type_arg) if get_origin(type_arg) is not None else (type_arg,)
110+
return any(isinstance(m, type) and issubclass(m, _ELICITATION_RESULT_MEMBERS) for m in members)
111+
112+
113+
def build_resolver_plans(
114+
resolved_params: Mapping[str, tuple[Resolve, bool]],
115+
tool_arg_names: set[str],
116+
) -> dict[int, _ResolverPlan]:
117+
"""Statically analyze the resolver DAG rooted at a tool's resolved parameters.
118+
119+
Raises:
120+
InvalidSignature: If a resolver has a cyclic dependency, or a resolver
121+
parameter cannot be classified (not a `Context`, a nested `Resolve`,
122+
or a tool argument by name).
123+
"""
124+
plans: dict[int, _ResolverPlan] = {}
125+
126+
def analyze(fn: Callable[..., Any], stack: tuple[int, ...]) -> None:
127+
key = id(fn)
128+
if key in stack:
129+
raise InvalidSignature(f"Resolver {fn.__name__!r} has a cyclic dependency")
130+
if key in plans:
131+
return
132+
133+
hints = typing.get_type_hints(fn, include_extras=True)
134+
sig = inspect.signature(fn)
135+
params: dict[str, _ParamPlan] = {}
136+
nested: list[Callable[..., Any]] = []
137+
for param_name in sig.parameters:
138+
annotation = hints.get(param_name)
139+
if annotation is not None and _is_context_annotation(annotation):
140+
params[param_name] = _ParamPlan("context")
141+
continue
142+
marker, wants_union = _resolve_marker(annotation)
143+
if marker is not None:
144+
params[param_name] = _ParamPlan("resolve", marker, wants_union)
145+
nested.append(marker.fn)
146+
continue
147+
if param_name in tool_arg_names:
148+
params[param_name] = _ParamPlan("by_name")
149+
continue
150+
raise InvalidSignature(
151+
f"Resolver {fn.__name__!r} parameter {param_name!r} cannot be resolved: "
152+
"expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name"
153+
)
154+
155+
plans[key] = _ResolverPlan(fn, params, is_async_callable(fn))
156+
for dep in nested:
157+
analyze(dep, stack + (key,))
158+
159+
for marker, _ in resolved_params.values():
160+
analyze(marker.fn, ())
161+
return plans
162+
163+
164+
def _resolve_marker(annotation: Any) -> tuple[Resolve | None, bool]:
165+
if get_origin(annotation) is not Annotated:
166+
return None, False
167+
type_arg, *metadata = get_args(annotation)
168+
marker = next((m for m in metadata if isinstance(m, Resolve)), None)
169+
return marker, (_wants_union(type_arg) if marker is not None else False)
170+
171+
172+
def _is_context_annotation(annotation: Any) -> bool:
173+
if get_origin(annotation) is Annotated:
174+
annotation = get_args(annotation)[0]
175+
return isinstance(annotation, type) and issubclass(annotation, Context)
176+
177+
178+
async def resolve_arguments(
179+
resolved_params: Mapping[str, tuple[Resolve, bool]],
180+
plans: Mapping[int, _ResolverPlan],
181+
tool_args: Mapping[str, Any],
182+
context: Context[Any, Any],
183+
) -> dict[str, Any]:
184+
"""Resolve every `Resolve`-marked tool parameter into a concrete value.
185+
186+
Each resolver runs at most once (memoized by function identity). Returns a
187+
mapping of tool parameter name to the value to inject.
188+
189+
Raises:
190+
ToolError: If an elicited value is declined or cancelled and the consumer
191+
asked for the unwrapped model (rather than the result union).
192+
"""
193+
cache: dict[int, ElicitationResult[BaseModel]] = {}
194+
injected: dict[str, Any] = {}
195+
for name, (marker, wants_union) in resolved_params.items():
196+
outcome = await _resolve(marker.fn, plans, tool_args, context, cache)
197+
injected[name] = outcome if wants_union else _unwrap(outcome, name)
198+
return injected
199+
200+
201+
async def _resolve(
202+
fn: Callable[..., Any],
203+
plans: Mapping[int, _ResolverPlan],
204+
tool_args: Mapping[str, Any],
205+
context: Context[Any, Any],
206+
cache: dict[int, ElicitationResult[BaseModel]],
207+
) -> ElicitationResult[BaseModel]:
208+
key = id(fn)
209+
if key in cache:
210+
return cache[key]
211+
212+
plan = plans[key]
213+
kwargs: dict[str, Any] = {}
214+
for param_name, param_plan in plan.params.items():
215+
if param_plan.kind == "context":
216+
kwargs[param_name] = context
217+
elif param_plan.kind == "by_name":
218+
kwargs[param_name] = tool_args[param_name]
219+
else:
220+
assert param_plan.resolve is not None
221+
dep_outcome = await _resolve(param_plan.resolve.fn, plans, tool_args, context, cache)
222+
kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name)
223+
224+
if plan.is_async:
225+
result = await fn(**kwargs)
226+
else:
227+
result = await anyio.to_thread.run_sync(lambda: fn(**kwargs))
228+
229+
outcome: ElicitationResult[BaseModel]
230+
if isinstance(result, Elicit):
231+
elicit = cast("Elicit[BaseModel]", result)
232+
outcome = await context.elicit(elicit.message, elicit.schema)
233+
else:
234+
outcome = AcceptedElicitation(data=result)
235+
236+
cache[key] = outcome
237+
return outcome
238+
239+
240+
def _unwrap(outcome: ElicitationResult[BaseModel], name: str) -> BaseModel:
241+
if isinstance(outcome, AcceptedElicitation):
242+
return outcome.data
243+
raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}")
244+
245+
246+
__all__ = [
247+
"Resolve",
248+
"Elicit",
249+
"ElicitationResult",
250+
"AcceptedElicitation",
251+
"DeclinedElicitation",
252+
"CancelledElicitation",
253+
"find_resolved_parameters",
254+
"build_resolver_plans",
255+
"resolve_arguments",
256+
]

0 commit comments

Comments
 (0)