From 5f0974f7be910d5490256c11ecf96e280df019f7 Mon Sep 17 00:00:00 2001 From: Sergey Volkov Date: Thu, 25 Jun 2026 04:55:43 +0800 Subject: [PATCH] Fix SSE gateway endpoint resolution --- src/mcp/client/sse.py | 58 ++++++++++++++++++++++- tests/shared/test_sse.py | 99 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 153 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 6a2579f4c0..1f5c4b56ab 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -2,7 +2,7 @@ from collections.abc import Callable from contextlib import asynccontextmanager from typing import Any -from urllib.parse import parse_qs, urljoin, urlparse +from urllib.parse import parse_qs, urljoin, urlparse, urlunparse import anyio import httpx @@ -27,6 +27,60 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None: return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0] +def _path_segments(path: str) -> list[str]: + return [segment for segment in path.split("/") if segment] + + +def _resolve_prefixed_path(sse_path: str, endpoint_path: str) -> str: + sse_dir = sse_path.rstrip("/").rsplit("/", 1)[0] + if not sse_dir: + return endpoint_path + + sse_segments = _path_segments(sse_dir) + endpoint_segments = _path_segments(endpoint_path) + + # The backend may emit its route path (for example /v1/messages) while the + # public SSE URL includes a gateway prefix (/gateway/deployment/v1/sse). + # Only infer a prefix when the paths overlap; no-overlap absolute paths are + # ambiguous and retain normal origin-rooted URL semantics. + overlap = 0 + for overlap_size in range(min(len(sse_segments), len(endpoint_segments)), 0, -1): + if sse_segments[-overlap_size:] == endpoint_segments[:overlap_size]: + overlap = overlap_size + break + + if overlap == 0: + return endpoint_path + + prefix_segments = sse_segments[: len(sse_segments) - overlap] + if not prefix_segments: + return endpoint_path + + return f"/{'/'.join(prefix_segments)}{endpoint_path}" + + +def _resolve_endpoint_url(url: str, endpoint: str) -> str: + endpoint_parsed = urlparse(endpoint) + if endpoint_parsed.scheme or endpoint_parsed.netloc: + return urljoin(url, endpoint) + + if not endpoint_parsed.path.startswith("/"): + return urljoin(url, endpoint) + + url_parsed = urlparse(url) + endpoint_path = _resolve_prefixed_path(url_parsed.path, endpoint_parsed.path) + return urlunparse( + ( + url_parsed.scheme, + url_parsed.netloc, + endpoint_path, + endpoint_parsed.params, + endpoint_parsed.query, + endpoint_parsed.fragment, + ) + ) + + @asynccontextmanager async def sse_client( url: str, @@ -68,7 +122,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": - endpoint_url = urljoin(url, sse.data) + endpoint_url = _resolve_endpoint_url(url, sse.data) logger.debug(f"Received endpoint URL: {endpoint_url}") url_parsed = urlparse(url) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 675a4acb16..746e0f3a16 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -2,7 +2,7 @@ import json from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib.parse import urlparse @@ -19,12 +19,13 @@ import mcp.client.sse from mcp import types from mcp.client.session import ClientSession -from mcp.client.sse import _extract_session_id_from_endpoint, sse_client +from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import MCPError +from mcp.shared.message import SessionMessage from mcp.types import ( CallToolRequestParams, CallToolResult, @@ -173,6 +174,100 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non assert _extract_session_id_from_endpoint(endpoint_url) == expected +@pytest.mark.parametrize( + ("sse_url", "endpoint", "expected_url"), + [ + ( + "https://example.com/gateway/deployment/v1/sse", + "/v1/messages/?session_id=abc123", + "https://example.com/gateway/deployment/v1/messages/?session_id=abc123", + ), + ( + "https://example.com/gateway/deployment/v1/sse", + "/gateway/deployment/v1/messages/?session_id=abc123", + "https://example.com/gateway/deployment/v1/messages/?session_id=abc123", + ), + ( + "https://example.com/gateway/deployment/sse", + "/messages/?session_id=abc123", + "https://example.com/messages/?session_id=abc123", + ), + ( + "https://example.com/sse", + "/messages/?session_id=abc123", + "https://example.com/messages/?session_id=abc123", + ), + ( + "https://example.com/gateway/sse", + "/", + "https://example.com/", + ), + ( + "https://example.com/gateway/deployment/v1/sse", + "messages/?session_id=abc123", + "https://example.com/gateway/deployment/v1/messages/?session_id=abc123", + ), + ( + "https://example.com/gateway/deployment/v1/sse", + "https://example.com/messages/?session_id=abc123", + "https://example.com/messages/?session_id=abc123", + ), + ], +) +def test_resolve_endpoint_url_preserves_gateway_path_prefix(sse_url: str, endpoint: str, expected_url: str) -> None: + assert _resolve_endpoint_url(sse_url, endpoint) == expected_url + + +@pytest.mark.anyio +async def test_sse_client_posts_to_endpoint_with_gateway_path_prefix() -> None: + """A gateway prefix on the public SSE URL is preserved for absolute-path endpoint events.""" + posted = anyio.Event() + posted_urls: list[str] = [] + + async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: + yield ServerSentEvent(event="endpoint", data="/v1/messages/?session_id=abc123") + await anyio.sleep_forever() + + mock_event_source = MagicMock() + mock_event_source.aiter_sse.return_value = mock_aiter_sse() + mock_event_source.response = MagicMock() + mock_event_source.response.raise_for_status = MagicMock() + + mock_aconnect_sse = MagicMock() + mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source) + mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None) + + async def mock_post(url: str, **kwargs: Any) -> MagicMock: + posted_urls.append(url) + posted.set() + return MagicMock(status_code=200, raise_for_status=MagicMock()) + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.post = AsyncMock(side_effect=mock_post) + + def mock_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return cast(httpx.AsyncClient, mock_client) + + with patch("mcp.client.sse.aconnect_sse", return_value=mock_aconnect_sse): + async with sse_client( + "http://test/gateway/deployment/v1/sse", + httpx_client_factory=mock_factory, + ) as (_, write_stream): + request = types.JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + await write_stream.send(SessionMessage(request)) + + with anyio.fail_after(5): # pragma: no branch + await posted.wait() + + assert posted_urls == ["http://test/gateway/deployment/v1/messages/?session_id=abc123"] + + @pytest.mark.anyio async def test_sse_client_on_session_created_not_called_when_no_session_id(monkeypatch: pytest.MonkeyPatch) -> None: """No session-created callback fires when the endpoint URL carries no session ID."""