Skip to content

Commit c0dfcb3

Browse files
committed
fix: reject initialize protocol version conflicts
1 parent 2397319 commit c0dfcb3

2 files changed

Lines changed: 82 additions & 17 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -476,23 +476,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
476476
await response(scope, receive, send)
477477
return
478478

479-
# Check if this is an initialization request
480-
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"
481-
482-
if is_initialization_request:
483-
# Check if the server already has an established session
484-
if self.mcp_session_id:
485-
# Check if request has a session ID
486-
request_session_id = self._get_session_id(request)
487-
488-
# If request has a session ID but doesn't match, return 404
489-
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
490-
response = self._create_error_response(
491-
"Not Found: Invalid or expired session ID",
492-
HTTPStatus.NOT_FOUND,
493-
)
494-
await response(scope, receive, send)
495-
return
479+
is_initialization_request = False
480+
if isinstance(message, JSONRPCRequest) and message.method == "initialize":
481+
is_initialization_request = True
482+
if not await self._validate_initialization_request(message, request, send):
483+
return
496484
elif not await self._validate_request_headers(request, send):
497485
return
498486

@@ -848,6 +836,44 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
848836

849837
return True
850838

839+
async def _validate_initialization_request(self, message: JSONRPCRequest, request: Request, send: Send) -> bool:
840+
if not await self._validate_initialization_protocol_version(message, request, send):
841+
return False
842+
843+
if not self.mcp_session_id:
844+
return True
845+
846+
request_session_id = self._get_session_id(request)
847+
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
848+
response = self._create_error_response(
849+
"Not Found: Invalid or expired session ID",
850+
HTTPStatus.NOT_FOUND,
851+
)
852+
await response(request.scope, request.receive, send)
853+
return False
854+
855+
return True
856+
857+
async def _validate_initialization_protocol_version(
858+
self, message: JSONRPCRequest, request: Request, send: Send
859+
) -> bool:
860+
header_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
861+
body_protocol_version = str(message.params.get("protocolVersion")) if message.params else None
862+
if (
863+
header_protocol_version is not None
864+
and body_protocol_version is not None
865+
and header_protocol_version != body_protocol_version
866+
):
867+
response = self._create_error_response(
868+
f"Bad Request: {MCP_PROTOCOL_VERSION_HEADER} header does not match initialize.params.protocolVersion",
869+
HTTPStatus.BAD_REQUEST,
870+
INVALID_REQUEST,
871+
)
872+
await response(request.scope, request.receive, send)
873+
return False
874+
875+
return True
876+
851877
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
852878
"""Replays events that would have been sent after the specified event ID.
853879

tests/shared/test_streamable_http.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,45 @@ async def test_server_validates_protocol_version_header(basic_app: Starlette) ->
15401540
assert response.status_code == 200
15411541

15421542

1543+
@pytest.mark.anyio
1544+
@pytest.mark.parametrize(
1545+
("header_version", "body_version"),
1546+
[
1547+
("2025-03-26", "2025-06-18"),
1548+
("2025-06-18", "2025-03-26"),
1549+
],
1550+
)
1551+
async def test_server_rejects_initialize_protocol_version_mismatch(
1552+
basic_app: Starlette, header_version: str, body_version: str
1553+
) -> None:
1554+
"""initialize is rejected with 400 when the header and body protocol versions disagree."""
1555+
init_request: dict[str, Any] = {
1556+
"jsonrpc": "2.0",
1557+
"method": "initialize",
1558+
"params": {
1559+
"clientInfo": {"name": "test-client", "version": "1.0"},
1560+
"protocolVersion": body_version,
1561+
"capabilities": {},
1562+
},
1563+
"id": "init-1",
1564+
}
1565+
1566+
async with make_client(basic_app) as client:
1567+
response = await client.post(
1568+
"/mcp",
1569+
headers={
1570+
"Accept": "application/json, text/event-stream",
1571+
"Content-Type": "application/json",
1572+
MCP_PROTOCOL_VERSION_HEADER: header_version,
1573+
},
1574+
json=init_request,
1575+
)
1576+
1577+
assert response.status_code == 400
1578+
assert MCP_PROTOCOL_VERSION_HEADER in response.text
1579+
assert "protocolVersion" in response.text
1580+
1581+
15431582
@pytest.mark.anyio
15441583
async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None:
15451584
"""A request without a protocol version header is accepted for backwards compatibility."""

0 commit comments

Comments
 (0)