Skip to content

Commit 5f7cb64

Browse files
author
冯基魁
committed
fix: gate session group lists on capabilities
1 parent a527142 commit 5f7cb64

2 files changed

Lines changed: 93 additions & 25 deletions

File tree

src/mcp/client/session_group.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -343,37 +343,41 @@ async def _aggregate_components(self, server_info: types.Implementation, session
343343
resources_temp: dict[str, types.Resource] = {}
344344
tools_temp: dict[str, types.Tool] = {}
345345
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
346+
capabilities = session.initialize_result.capabilities if session.initialize_result is not None else None
346347

347348
# Query the server for its prompts and aggregate to list.
348-
try:
349-
prompts = (await session.list_prompts()).prompts
350-
for prompt in prompts:
351-
name = self._component_name(prompt.name, server_info)
352-
prompts_temp[name] = prompt
353-
component_names.prompts.add(name)
354-
except MCPError as err: # pragma: no cover
355-
logging.warning(f"Could not fetch prompts: {err}")
349+
if capabilities is not None and capabilities.prompts is not None:
350+
try:
351+
prompts = (await session.list_prompts()).prompts
352+
for prompt in prompts:
353+
name = self._component_name(prompt.name, server_info)
354+
prompts_temp[name] = prompt
355+
component_names.prompts.add(name)
356+
except MCPError as err: # pragma: no cover
357+
logging.warning(f"Could not fetch prompts: {err}")
356358

357359
# Query the server for its resources and aggregate to list.
358-
try:
359-
resources = (await session.list_resources()).resources
360-
for resource in resources:
361-
name = self._component_name(resource.name, server_info)
362-
resources_temp[name] = resource
363-
component_names.resources.add(name)
364-
except MCPError as err: # pragma: no cover
365-
logging.warning(f"Could not fetch resources: {err}")
360+
if capabilities is not None and capabilities.resources is not None:
361+
try:
362+
resources = (await session.list_resources()).resources
363+
for resource in resources:
364+
name = self._component_name(resource.name, server_info)
365+
resources_temp[name] = resource
366+
component_names.resources.add(name)
367+
except MCPError as err: # pragma: no cover
368+
logging.warning(f"Could not fetch resources: {err}")
366369

367370
# Query the server for its tools and aggregate to list.
368-
try:
369-
tools = (await session.list_tools()).tools
370-
for tool in tools:
371-
name = self._component_name(tool.name, server_info)
372-
tools_temp[name] = tool
373-
tool_to_session_temp[name] = session
374-
component_names.tools.add(name)
375-
except MCPError as err: # pragma: no cover
376-
logging.warning(f"Could not fetch tools: {err}")
371+
if capabilities is not None and capabilities.tools is not None:
372+
try:
373+
tools = (await session.list_tools()).tools
374+
for tool in tools:
375+
name = self._component_name(tool.name, server_info)
376+
tools_temp[name] = tool
377+
tool_to_session_temp[name] = session
378+
component_names.tools.add(name)
379+
except MCPError as err: # pragma: no cover
380+
logging.warning(f"Could not fetch tools: {err}")
377381

378382
# Clean up exit stack for session if we couldn't retrieve anything
379383
# from the server.

tests/client/test_session_group.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ def mock_exit_stack():
2525
return mock.MagicMock(spec=contextlib.AsyncExitStack)
2626

2727

28+
def _initialize_result(capabilities: types.ServerCapabilities) -> types.InitializeResult:
29+
return types.InitializeResult(
30+
protocol_version=types.LATEST_PROTOCOL_VERSION,
31+
capabilities=capabilities,
32+
server_info=types.Implementation(name="TestServer", version="1.0"),
33+
)
34+
35+
2836
def test_client_session_group_init():
2937
mcp_session_group = ClientSessionGroup()
3038
assert not mcp_session_group._tools
@@ -99,6 +107,13 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
99107
mock_resource1.name = "resource_b"
100108
mock_prompt1 = mock.Mock(spec=types.Prompt)
101109
mock_prompt1.name = "prompt_c"
110+
mock_session.initialize_result = _initialize_result(
111+
types.ServerCapabilities(
112+
prompts=types.PromptsCapability(),
113+
resources=types.ResourcesCapability(),
114+
tools=types.ToolsCapability(),
115+
)
116+
)
102117
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
103118
mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
104119
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
@@ -125,6 +140,41 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
125140
mock_session.list_prompts.assert_awaited_once()
126141

127142

143+
@pytest.mark.anyio
144+
async def test_client_session_group_skips_unadvertised_capabilities(
145+
mock_exit_stack: contextlib.AsyncExitStack,
146+
caplog: pytest.LogCaptureFixture,
147+
):
148+
"""Only query component lists for capabilities advertised by initialize."""
149+
mock_server_info = mock.Mock(spec=types.Implementation)
150+
mock_server_info.name = "ToolsOnlyServer"
151+
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
152+
mock_tool = mock.Mock(spec=types.Tool)
153+
mock_tool.name = "tool_a"
154+
mock_session.initialize_result = _initialize_result(
155+
types.ServerCapabilities(
156+
prompts=None,
157+
resources=None,
158+
tools=types.ToolsCapability(),
159+
)
160+
)
161+
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])
162+
163+
group = ClientSessionGroup(exit_stack=mock_exit_stack)
164+
with caplog.at_level("WARNING"):
165+
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
166+
await group.connect_to_server(StdioServerParameters(command="test"))
167+
168+
assert group.prompts == {}
169+
assert group.resources == {}
170+
assert group.tools == {"tool_a": mock_tool}
171+
mock_session.list_tools.assert_awaited_once()
172+
mock_session.list_prompts.assert_not_awaited()
173+
mock_session.list_resources.assert_not_awaited()
174+
assert "Could not fetch prompts" not in caplog.text
175+
assert "Could not fetch resources" not in caplog.text
176+
177+
128178
@pytest.mark.anyio
129179
async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack):
130180
"""Test connecting with a component name hook."""
@@ -134,6 +184,13 @@ async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_s
134184
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
135185
mock_tool = mock.Mock(spec=types.Tool)
136186
mock_tool.name = "base_tool"
187+
mock_session.initialize_result = _initialize_result(
188+
types.ServerCapabilities(
189+
prompts=types.PromptsCapability(),
190+
resources=types.ResourcesCapability(),
191+
tools=types.ToolsCapability(),
192+
)
193+
)
137194
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])
138195
mock_session.list_resources.return_value = mock.AsyncMock(resources=[])
139196
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[])
@@ -245,6 +302,13 @@ async def test_client_session_group_connect_to_server_duplicate_tool_raises_erro
245302
# Configure the new session to return a tool with the *same name*
246303
duplicate_tool = mock.Mock(spec=types.Tool)
247304
duplicate_tool.name = existing_tool_name
305+
mock_session_new.initialize_result = _initialize_result(
306+
types.ServerCapabilities(
307+
prompts=types.PromptsCapability(),
308+
resources=types.ResourcesCapability(),
309+
tools=types.ToolsCapability(),
310+
)
311+
)
248312
mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
249313
# Keep other lists empty for simplicity
250314
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])

0 commit comments

Comments
 (0)