Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import parse_qs, urljoin, urlparse
from urllib.parse import parse_qs, urljoin, urlparse, urlunparse

import anyio
import httpx
Expand All @@ -27,6 +27,60 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]


def _path_segments(path: str) -> list[str]:
return [segment for segment in path.split("/") if segment]


def _resolve_prefixed_path(sse_path: str, endpoint_path: str) -> str:
sse_dir = sse_path.rstrip("/").rsplit("/", 1)[0]
if not sse_dir:
return endpoint_path

sse_segments = _path_segments(sse_dir)
endpoint_segments = _path_segments(endpoint_path)

# The backend may emit its route path (for example /v1/messages) while the
# public SSE URL includes a gateway prefix (/gateway/deployment/v1/sse).
# Only infer a prefix when the paths overlap; no-overlap absolute paths are
# ambiguous and retain normal origin-rooted URL semantics.
overlap = 0
for overlap_size in range(min(len(sse_segments), len(endpoint_segments)), 0, -1):
if sse_segments[-overlap_size:] == endpoint_segments[:overlap_size]:
overlap = overlap_size
break

if overlap == 0:
return endpoint_path

prefix_segments = sse_segments[: len(sse_segments) - overlap]
if not prefix_segments:
return endpoint_path

return f"/{'/'.join(prefix_segments)}{endpoint_path}"


def _resolve_endpoint_url(url: str, endpoint: str) -> str:
endpoint_parsed = urlparse(endpoint)
if endpoint_parsed.scheme or endpoint_parsed.netloc:
return urljoin(url, endpoint)

if not endpoint_parsed.path.startswith("/"):
return urljoin(url, endpoint)

url_parsed = urlparse(url)
endpoint_path = _resolve_prefixed_path(url_parsed.path, endpoint_parsed.path)
return urlunparse(
(
url_parsed.scheme,
url_parsed.netloc,
endpoint_path,
endpoint_parsed.params,
endpoint_parsed.query,
endpoint_parsed.fragment,
)
)


@asynccontextmanager
async def sse_client(
url: str,
Expand Down Expand Up @@ -68,7 +122,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
endpoint_url = _resolve_endpoint_url(url, sse.data)
logger.debug(f"Received endpoint URL: {endpoint_url}")

url_parsed = urlparse(url)
Expand Down
99 changes: 97 additions & 2 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from collections.abc import AsyncGenerator
from typing import Any
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from urllib.parse import urlparse

Expand All @@ -19,12 +19,13 @@
import mcp.client.sse
from mcp import types
from mcp.client.session import ClientSession
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client
from mcp.server import Server, ServerRequestContext
from mcp.server.sse import SseServerTransport
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._httpx_utils import McpHttpClientFactory
from mcp.shared.exceptions import MCPError
from mcp.shared.message import SessionMessage
from mcp.types import (
CallToolRequestParams,
CallToolResult,
Expand Down Expand Up @@ -173,6 +174,100 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
assert _extract_session_id_from_endpoint(endpoint_url) == expected


@pytest.mark.parametrize(
("sse_url", "endpoint", "expected_url"),
[
(
"https://example.com/gateway/deployment/v1/sse",
"/v1/messages/?session_id=abc123",
"https://example.com/gateway/deployment/v1/messages/?session_id=abc123",
),
(
"https://example.com/gateway/deployment/v1/sse",
"/gateway/deployment/v1/messages/?session_id=abc123",
"https://example.com/gateway/deployment/v1/messages/?session_id=abc123",
),
(
"https://example.com/gateway/deployment/sse",
"/messages/?session_id=abc123",
"https://example.com/messages/?session_id=abc123",
),
(
"https://example.com/sse",
"/messages/?session_id=abc123",
"https://example.com/messages/?session_id=abc123",
),
(
"https://example.com/gateway/sse",
"/",
"https://example.com/",
),
(
"https://example.com/gateway/deployment/v1/sse",
"messages/?session_id=abc123",
"https://example.com/gateway/deployment/v1/messages/?session_id=abc123",
),
(
"https://example.com/gateway/deployment/v1/sse",
"https://example.com/messages/?session_id=abc123",
"https://example.com/messages/?session_id=abc123",
),
],
)
def test_resolve_endpoint_url_preserves_gateway_path_prefix(sse_url: str, endpoint: str, expected_url: str) -> None:
assert _resolve_endpoint_url(sse_url, endpoint) == expected_url


@pytest.mark.anyio
async def test_sse_client_posts_to_endpoint_with_gateway_path_prefix() -> None:
"""A gateway prefix on the public SSE URL is preserved for absolute-path endpoint events."""
posted = anyio.Event()
posted_urls: list[str] = []

async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
yield ServerSentEvent(event="endpoint", data="/v1/messages/?session_id=abc123")
await anyio.sleep_forever()

mock_event_source = MagicMock()
mock_event_source.aiter_sse.return_value = mock_aiter_sse()
mock_event_source.response = MagicMock()
mock_event_source.response.raise_for_status = MagicMock()

mock_aconnect_sse = MagicMock()
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)

async def mock_post(url: str, **kwargs: Any) -> MagicMock:
posted_urls.append(url)
posted.set()
return MagicMock(status_code=200, raise_for_status=MagicMock())

mock_client = MagicMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client.post = AsyncMock(side_effect=mock_post)

def mock_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
return cast(httpx.AsyncClient, mock_client)

with patch("mcp.client.sse.aconnect_sse", return_value=mock_aconnect_sse):
async with sse_client(
"http://test/gateway/deployment/v1/sse",
httpx_client_factory=mock_factory,
) as (_, write_stream):
request = types.JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
await write_stream.send(SessionMessage(request))

with anyio.fail_after(5): # pragma: no branch
await posted.wait()

assert posted_urls == ["http://test/gateway/deployment/v1/messages/?session_id=abc123"]


@pytest.mark.anyio
async def test_sse_client_on_session_created_not_called_when_no_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
"""No session-created callback fires when the endpoint URL carries no session ID."""
Expand Down
Loading