Skip to content

Commit c3ea531

Browse files
committed
Address cubic review: by-name aliasing, return-annotation, callable-resolver naming
- tools/base.py: build tool_arg_names as 'alias or field_name' to match the runtime kwarg keys, so a by-name resolver param on an aliased field resolves instead of raising KeyError at call time. - resolve.py: iterate inspect.signature params (not get_type_hints items, which include 'return') so a Resolve marker on a return annotation is ignored; add _resolver_name so callable-object resolvers raise InvalidSignature instead of AttributeError in error messages. - migration.md: import DeclinedElicitation/CancelledElicitation used in the branching example so the snippet is runnable. Add regression tests for each.
1 parent 9e9282a commit c3ea531

4 files changed

Lines changed: 64 additions & 7 deletions

File tree

docs/migration.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,15 @@ from typing import Annotated
14351435

14361436
from pydantic import BaseModel
14371437

1438-
from mcp.server.mcpserver import AcceptedElicitation, Context, Elicit, MCPServer, Resolve
1438+
from mcp.server.mcpserver import (
1439+
AcceptedElicitation,
1440+
CancelledElicitation,
1441+
Context,
1442+
DeclinedElicitation,
1443+
Elicit,
1444+
MCPServer,
1445+
Resolve,
1446+
)
14391447

14401448
mcp = MCPServer(name="github")
14411449

src/mcp/server/mcpserver/resolve.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,22 @@ def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]:
100100
return {}
101101

102102

103+
def _resolver_name(fn: Callable[..., Any]) -> str:
104+
"""Best-effort display name for error messages (callable objects lack `__name__`)."""
105+
return getattr(fn, "__name__", None) or type(fn).__name__
106+
107+
103108
def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, bool]]:
104109
"""Find parameters of `fn` annotated `Annotated[_, Resolve(...)]`.
105110
106111
Returns a mapping of parameter name to `(Resolve, wants_union)`, where
107112
`wants_union` is True when the annotated type is an `ElicitationResult` member
108113
(the consumer wants the full outcome rather than the unwrapped model).
109114
"""
115+
hints = _type_hints(fn)
110116
resolved: dict[str, tuple[Resolve, bool]] = {}
111-
for name, annotation in _type_hints(fn).items():
117+
for name in inspect.signature(fn).parameters:
118+
annotation = hints.get(name)
112119
if get_origin(annotation) is not Annotated:
113120
continue
114121
type_arg, *metadata = get_args(annotation)
@@ -140,7 +147,7 @@ def build_resolver_plans(
140147
def analyze(fn: Callable[..., Any], stack: tuple[int, ...]) -> None:
141148
key = id(fn)
142149
if key in stack:
143-
raise InvalidSignature(f"Resolver {fn.__name__!r} has a cyclic dependency")
150+
raise InvalidSignature(f"Resolver {_resolver_name(fn)!r} has a cyclic dependency")
144151
if key in plans:
145152
return
146153

@@ -162,7 +169,7 @@ def analyze(fn: Callable[..., Any], stack: tuple[int, ...]) -> None:
162169
params[param_name] = _ParamPlan("by_name")
163170
continue
164171
raise InvalidSignature(
165-
f"Resolver {fn.__name__!r} parameter {param_name!r} cannot be resolved: "
172+
f"Resolver {_resolver_name(fn)!r} parameter {param_name!r} cannot be resolved: "
166173
"expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name"
167174
)
168175

src/mcp/server/mcpserver/tools/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ def from_function(
9292
)
9393
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
9494

95-
tool_arg_names = set(func_arg_metadata.arg_model.model_fields) | {
96-
field.alias for field in func_arg_metadata.arg_model.model_fields.values() if field.alias
97-
}
95+
# Match `model_dump_one_level`'s kwarg keys (alias when present, else field name)
96+
# so a by-name resolver param resolves to a key that exists at call time.
97+
tool_arg_names = {field.alias or name for name, field in func_arg_metadata.arg_model.model_fields.items()}
9898
resolver_plans = build_resolver_plans(resolved_params, tool_arg_names)
9999

100100
return cls(

tests/server/mcpserver/test_resolve.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,45 @@ async def tool(login: Annotated[Login, Resolve(login)]) -> str:
272272

273273
with pytest.raises(InvalidSignature, match="cannot be resolved"):
274274
Tool.from_function(tool)
275+
276+
277+
def test_resolve_marker_on_return_annotation_is_ignored():
278+
async def login(ctx: Context) -> Login:
279+
return Login(username="x") # pragma: no cover
280+
281+
async def tool(repo: str) -> Annotated[str, Resolve(login)]:
282+
return repo # pragma: no cover
283+
284+
assert find_resolved_parameters(tool) == {}
285+
286+
287+
def test_callable_object_resolver_error_uses_type_name():
288+
class BadResolver:
289+
async def __call__(self, mystery: int) -> Login:
290+
return Login(username="x") # pragma: no cover
291+
292+
async def tool(login: Annotated[Login, Resolve(BadResolver())]) -> str:
293+
return login.username # pragma: no cover
294+
295+
with pytest.raises(InvalidSignature, match="'BadResolver'"):
296+
Tool.from_function(tool)
297+
298+
299+
@pytest.mark.anyio
300+
async def test_by_name_resolver_param_uses_aliased_tool_arg():
301+
mcp = MCPServer(name="Aliased")
302+
303+
# `schema` collides with a BaseModel attribute, so func_metadata aliases the field;
304+
# the runtime kwarg key is the alias, which is what a by-name resolver must match.
305+
async def upper(schema: str) -> Login:
306+
return Login(username=schema.upper())
307+
308+
@mcp.tool()
309+
async def run(schema: str, shouted: Annotated[Login, Resolve(upper)]) -> str:
310+
return shouted.username
311+
312+
async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover
313+
raise AssertionError("should not elicit")
314+
315+
async with Client(mcp, mode="legacy", elicitation_callback=never) as client:
316+
assert await _text(client, "run", {"schema": "gpt"}) == "GPT"

0 commit comments

Comments
 (0)