Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 29 additions & 25 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,37 +343,41 @@ async def _aggregate_components(self, server_info: types.Implementation, session
resources_temp: dict[str, types.Resource] = {}
tools_temp: dict[str, types.Tool] = {}
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
capabilities = session.initialize_result.capabilities if session.initialize_result is not None else None

# Query the server for its prompts and aggregate to list.
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch prompts: {err}")
if capabilities is not None and capabilities.prompts is not None:
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch prompts: {err}")

# Query the server for its resources and aggregate to list.
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch resources: {err}")
if capabilities is not None and capabilities.resources is not None:
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch resources: {err}")

# Query the server for its tools and aggregate to list.
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch tools: {err}")
if capabilities is not None and capabilities.tools is not None:
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch tools: {err}")

# Clean up exit stack for session if we couldn't retrieve anything
# from the server.
Expand Down
92 changes: 92 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ def mock_exit_stack():
return mock.MagicMock(spec=contextlib.AsyncExitStack)


def _initialize_result(capabilities: types.ServerCapabilities) -> types.InitializeResult:
return types.InitializeResult(
protocol_version=types.LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
server_info=types.Implementation(name="TestServer", version="1.0"),
)


def test_client_session_group_init():
mcp_session_group = ClientSessionGroup()
assert not mcp_session_group._tools
Expand Down Expand Up @@ -99,6 +107,13 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
mock_resource1.name = "resource_b"
mock_prompt1 = mock.Mock(spec=types.Prompt)
mock_prompt1.name = "prompt_c"
mock_session.initialize_result = _initialize_result(
types.ServerCapabilities(
prompts=types.PromptsCapability(),
resources=types.ResourcesCapability(),
tools=types.ToolsCapability(),
)
)
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
Expand All @@ -125,6 +140,69 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
mock_session.list_prompts.assert_awaited_once()


@pytest.mark.anyio
async def test_client_session_group_skips_unadvertised_capabilities(
mock_exit_stack: contextlib.AsyncExitStack,
caplog: pytest.LogCaptureFixture,
):
"""Only query component lists for capabilities advertised by initialize."""
mock_server_info = mock.Mock(spec=types.Implementation)
mock_server_info.name = "ToolsOnlyServer"
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
mock_tool = mock.Mock(spec=types.Tool)
mock_tool.name = "tool_a"
mock_session.initialize_result = _initialize_result(
types.ServerCapabilities(
prompts=None,
resources=None,
tools=types.ToolsCapability(),
)
)
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])

group = ClientSessionGroup(exit_stack=mock_exit_stack)
with caplog.at_level("WARNING"):
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
await group.connect_to_server(StdioServerParameters(command="test"))

assert group.prompts == {}
assert group.resources == {}
assert group.tools == {"tool_a": mock_tool}
mock_session.list_tools.assert_awaited_once()
mock_session.list_prompts.assert_not_awaited()
mock_session.list_resources.assert_not_awaited()
assert "Could not fetch prompts" not in caplog.text
assert "Could not fetch resources" not in caplog.text


@pytest.mark.anyio
async def test_client_session_group_skips_unadvertised_tools(mock_exit_stack: contextlib.AsyncExitStack):
mock_server_info = mock.Mock(spec=types.Implementation)
mock_server_info.name = "PromptsOnlyServer"
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
mock_prompt = mock.Mock(spec=types.Prompt)
mock_prompt.name = "prompt_a"
mock_session.initialize_result = _initialize_result(
types.ServerCapabilities(
prompts=types.PromptsCapability(),
resources=None,
tools=None,
)
)
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt])

group = ClientSessionGroup(exit_stack=mock_exit_stack)
with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
await group.connect_to_server(StdioServerParameters(command="test"))

assert group.prompts == {"prompt_a": mock_prompt}
assert group.resources == {}
assert group.tools == {}
mock_session.list_prompts.assert_awaited_once()
mock_session.list_resources.assert_not_awaited()
mock_session.list_tools.assert_not_awaited()


@pytest.mark.anyio
async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack):
"""Test connecting with a component name hook."""
Expand All @@ -134,6 +212,13 @@ async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_s
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
mock_tool = mock.Mock(spec=types.Tool)
mock_tool.name = "base_tool"
mock_session.initialize_result = _initialize_result(
types.ServerCapabilities(
prompts=types.PromptsCapability(),
resources=types.ResourcesCapability(),
tools=types.ToolsCapability(),
)
)
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])
mock_session.list_resources.return_value = mock.AsyncMock(resources=[])
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[])
Expand Down Expand Up @@ -245,6 +330,13 @@ async def test_client_session_group_connect_to_server_duplicate_tool_raises_erro
# Configure the new session to return a tool with the *same name*
duplicate_tool = mock.Mock(spec=types.Tool)
duplicate_tool.name = existing_tool_name
mock_session_new.initialize_result = _initialize_result(
types.ServerCapabilities(
prompts=types.PromptsCapability(),
resources=types.ResourcesCapability(),
tools=types.ToolsCapability(),
)
)
mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
# Keep other lists empty for simplicity
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
Expand Down
Loading