From 8de5aa8c3f4fd5619d50d45cdfaf00c0b5904588 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=AF=E5=9F=BA=E9=AD=81?= <1412414664@qq.com> Date: Tue, 23 Jun 2026 23:00:58 +0800 Subject: [PATCH] fix: gate session group lists on capabilities --- src/mcp/client/session_group.py | 54 ++++++++++-------- tests/client/test_session_group.py | 92 ++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 25 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..be8dc4b59d 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -343,37 +343,41 @@ async def _aggregate_components(self, server_info: types.Implementation, session resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} + capabilities = session.initialize_result.capabilities if session.initialize_result is not None else None # Query the server for its prompts and aggregate to list. - try: - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - prompts_temp[name] = prompt - component_names.prompts.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch prompts: {err}") + if capabilities is not None and capabilities.prompts is not None: + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - try: - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - resources_temp[name] = resource - component_names.resources.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch resources: {err}") + if capabilities is not None and capabilities.resources is not None: + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - try: - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch tools: {err}") + if capabilities is not None and capabilities.tools is not None: + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch tools: {err}") # Clean up exit stack for session if we couldn't retrieve anything # from the server. diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..12fc726d7a 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -25,6 +25,14 @@ def mock_exit_stack(): return mock.MagicMock(spec=contextlib.AsyncExitStack) +def _initialize_result(capabilities: types.ServerCapabilities) -> types.InitializeResult: + return types.InitializeResult( + protocol_version=types.LATEST_PROTOCOL_VERSION, + capabilities=capabilities, + server_info=types.Implementation(name="TestServer", version="1.0"), + ) + + def test_client_session_group_init(): mcp_session_group = ClientSessionGroup() assert not mcp_session_group._tools @@ -99,6 +107,13 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli mock_resource1.name = "resource_b" mock_prompt1 = mock.Mock(spec=types.Prompt) mock_prompt1.name = "prompt_c" + mock_session.initialize_result = _initialize_result( + types.ServerCapabilities( + prompts=types.PromptsCapability(), + resources=types.ResourcesCapability(), + tools=types.ToolsCapability(), + ) + ) mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1]) mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) @@ -125,6 +140,69 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli mock_session.list_prompts.assert_awaited_once() +@pytest.mark.anyio +async def test_client_session_group_skips_unadvertised_capabilities( + mock_exit_stack: contextlib.AsyncExitStack, + caplog: pytest.LogCaptureFixture, +): + """Only query component lists for capabilities advertised by initialize.""" + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "ToolsOnlyServer" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool = mock.Mock(spec=types.Tool) + mock_tool.name = "tool_a" + mock_session.initialize_result = _initialize_result( + types.ServerCapabilities( + prompts=None, + resources=None, + tools=types.ToolsCapability(), + ) + ) + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) + + group = ClientSessionGroup(exit_stack=mock_exit_stack) + with caplog.at_level("WARNING"): + with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): + await group.connect_to_server(StdioServerParameters(command="test")) + + assert group.prompts == {} + assert group.resources == {} + assert group.tools == {"tool_a": mock_tool} + mock_session.list_tools.assert_awaited_once() + mock_session.list_prompts.assert_not_awaited() + mock_session.list_resources.assert_not_awaited() + assert "Could not fetch prompts" not in caplog.text + assert "Could not fetch resources" not in caplog.text + + +@pytest.mark.anyio +async def test_client_session_group_skips_unadvertised_tools(mock_exit_stack: contextlib.AsyncExitStack): + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "PromptsOnlyServer" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_prompt = mock.Mock(spec=types.Prompt) + mock_prompt.name = "prompt_a" + mock_session.initialize_result = _initialize_result( + types.ServerCapabilities( + prompts=types.PromptsCapability(), + resources=None, + tools=None, + ) + ) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt]) + + group = ClientSessionGroup(exit_stack=mock_exit_stack) + with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): + await group.connect_to_server(StdioServerParameters(command="test")) + + assert group.prompts == {"prompt_a": mock_prompt} + assert group.resources == {} + assert group.tools == {} + mock_session.list_prompts.assert_awaited_once() + mock_session.list_resources.assert_not_awaited() + mock_session.list_tools.assert_not_awaited() + + @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook.""" @@ -134,6 +212,13 @@ async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_s mock_session = mock.AsyncMock(spec=mcp.ClientSession) mock_tool = mock.Mock(spec=types.Tool) mock_tool.name = "base_tool" + mock_session.initialize_result = _initialize_result( + types.ServerCapabilities( + prompts=types.PromptsCapability(), + resources=types.ResourcesCapability(), + tools=types.ToolsCapability(), + ) + ) mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) @@ -245,6 +330,13 @@ async def test_client_session_group_connect_to_server_duplicate_tool_raises_erro # Configure the new session to return a tool with the *same name* duplicate_tool = mock.Mock(spec=types.Tool) duplicate_tool.name = existing_tool_name + mock_session_new.initialize_result = _initialize_result( + types.ServerCapabilities( + prompts=types.PromptsCapability(), + resources=types.ResourcesCapability(), + tools=types.ToolsCapability(), + ) + ) mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool]) # Keep other lists empty for simplicity mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])