Skip to content

Commit dde2df9

Browse files
committed
Store priming event before request dispatch so storage order == wire order
Splits the old _maybe_send_priming_event into _mint_priming_event (store + return wire dict) and _run_sse_writer (forward request_stream onto the wire). The POST handler now awaits _mint_priming_event before writer.send(), so the priming row is in the event store before the server can produce any message for that request id — ordering by data dependency, not scheduler timing. The replay path keeps its priming event (test_streamable_http_multiple_reconnections relies on it as a stream-re-registered signal); its replay→live-tail ordering window is pre-existing and orthogonal. Also extracts the inline sse_writer closure to a method (drops _handle_post_request below the C901 threshold) and widens the SSE-dict stream type to SSEEvent (dict[str, Any]) — the previous dict[str, str] was a lie masked by the old helper's Any parameter, since priming events carry retry: int.
1 parent b0b398c commit dde2df9

3 files changed

Lines changed: 122 additions & 133 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 58 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from collections.abc import AsyncGenerator, Awaitable, Callable
1313
from contextlib import asynccontextmanager
1414
from dataclasses import dataclass
15+
from functools import partial
1516
from http import HTTPStatus
1617
from typing import Any, Final
1718

@@ -71,6 +72,8 @@
7172
# Type aliases
7273
StreamId = str
7374
EventId = str
75+
# An SSE event-dict as accepted by sse-starlette (`event`, `data`, `id`, `retry`).
76+
SSEEvent = dict[str, Any]
7477

7578

7679
@dataclass
@@ -174,7 +177,7 @@ def __init__(
174177
MemoryObjectReceiveStream[EventMessage],
175178
],
176179
] = {}
177-
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
180+
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[SSEEvent]] = {}
178181
self._terminated = False
179182
# Idle timeout cancel scope; managed by the session manager.
180183
self.idle_scope: anyio.CancelScope | None = None
@@ -261,31 +264,48 @@ async def close_standalone_stream_callback() -> None:
261264

262265
return SessionMessage(message, metadata=metadata)
263266

264-
async def _maybe_send_priming_event(
265-
self,
266-
request_id: RequestId,
267-
sse_stream_writer: MemoryObjectSendStream[dict[str, Any]],
268-
protocol_version: str,
269-
) -> None:
270-
"""Send priming event for SSE resumability if event_store is configured.
267+
async def _mint_priming_event(self, stream_id: StreamId, protocol_version: str) -> SSEEvent | None:
268+
"""Store the priming cursor for `stream_id` and return its SSE wire form.
271269
272-
Only sends priming events to clients with protocol version >= 2025-11-25,
273-
which includes the fix for handling empty SSE data. Older clients would
274-
crash trying to parse empty data as JSON.
270+
Called before the request is dispatched so the priming row precedes
271+
anything `message_router` can store for this stream. Returns `None`
272+
when no event store is configured or the client predates 2025-11-25
273+
(older clients cannot parse the empty-data event).
275274
"""
276275
if not self._event_store:
277-
return
278-
# Priming events have empty data which older clients cannot handle.
276+
return None
279277
if not is_version_at_least(protocol_version, "2025-11-25"):
280-
return
281-
priming_event_id = await self._event_store.store_event(
282-
str(request_id), # Convert RequestId to StreamId (str)
283-
None, # Priming event has no payload
284-
)
285-
priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""}
278+
return None
279+
priming_event_id = await self._event_store.store_event(stream_id, None)
280+
priming_event: SSEEvent = {"id": priming_event_id, "data": ""}
286281
if self._retry_interval is not None:
287282
priming_event["retry"] = self._retry_interval
288-
await sse_stream_writer.send(priming_event)
283+
return priming_event
284+
285+
async def _run_sse_writer(
286+
self,
287+
request_id: RequestId,
288+
sse_stream_writer: MemoryObjectSendStream[SSEEvent],
289+
request_stream_reader: MemoryObjectReceiveStream[EventMessage],
290+
priming_event: SSEEvent | None,
291+
) -> None:
292+
"""Forward `_request_streams[request_id]` onto the SSE wire for one POST."""
293+
try:
294+
async with sse_stream_writer, request_stream_reader:
295+
if priming_event is not None:
296+
await sse_stream_writer.send(priming_event)
297+
async for event_message in request_stream_reader:
298+
await sse_stream_writer.send(self._create_event_data(event_message))
299+
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
300+
break
301+
except anyio.ClosedResourceError: # pragma: lax no cover
302+
logger.debug("SSE stream closed by close_sse_stream()")
303+
except Exception: # pragma: lax no cover
304+
logger.exception("Error in SSE writer")
305+
finally:
306+
logger.debug("Closing SSE writer")
307+
self._sse_stream_writers.pop(request_id, None)
308+
await self._clean_up_memory_streams(request_id)
289309

290310
def _create_error_response(
291311
self,
@@ -339,7 +359,7 @@ def _get_session_id(self, request: Request) -> str | None:
339359
"""Extract the session ID from request headers."""
340360
return request.headers.get(MCP_SESSION_ID_HEADER)
341361

342-
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
362+
def _create_event_data(self, event_message: EventMessage) -> SSEEvent:
343363
"""Create event data dictionary from an EventMessage."""
344364
event_data = {
345365
"event": "message",
@@ -579,40 +599,16 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
579599
await self._clean_up_memory_streams(request_id)
580600
else:
581601
# Create SSE stream
582-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
602+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
583603

584604
# Store writer reference so close_sse_stream() can close it
585605
self._sse_stream_writers[request_id] = sse_stream_writer
586606

587-
async def sse_writer():
588-
# Get the request ID from the incoming request message
589-
try:
590-
async with sse_stream_writer, request_stream_reader:
591-
# Send priming event for SSE resumability
592-
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)
593-
594-
# Process messages from the request-specific stream
595-
async for event_message in request_stream_reader:
596-
# Build the event data
597-
event_data = self._create_event_data(event_message)
598-
await sse_stream_writer.send(event_data)
599-
600-
# If response, remove from pending streams and close
601-
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
602-
break
603-
except anyio.ClosedResourceError: # pragma: lax no cover
604-
# Expected when close_sse_stream() is called
605-
logger.debug("SSE stream closed by close_sse_stream()")
606-
except Exception: # pragma: lax no cover
607-
logger.exception("Error in SSE writer")
608-
finally:
609-
logger.debug("Closing SSE writer")
610-
self._sse_stream_writers.pop(request_id, None)
611-
await self._clean_up_memory_streams(request_id)
612-
613-
# Create and start EventSourceResponse
614-
# SSE stream mode (original behavior)
615-
# Set up headers
607+
# Store the priming event before the request is dispatched so its
608+
# event-store position precedes anything message_router can store
609+
# for this id (storage order == wire order by construction).
610+
priming_event = await self._mint_priming_event(request_id, protocol_version)
611+
616612
headers = {
617613
"Cache-Control": "no-cache, no-transform",
618614
"Connection": "keep-alive",
@@ -621,7 +617,9 @@ async def sse_writer():
621617
}
622618
response = EventSourceResponse(
623619
content=sse_stream_reader,
624-
data_sender_callable=sse_writer,
620+
data_sender_callable=partial(
621+
self._run_sse_writer, request_id, sse_stream_writer, request_stream_reader, priming_event
622+
),
625623
headers=headers,
626624
)
627625

@@ -704,7 +702,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
704702
return
705703

706704
# Create SSE stream
707-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
705+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
708706

709707
async def standalone_sse_writer():
710708
try:
@@ -880,7 +878,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
880878
replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
881879

882880
# Create SSE stream for replay
883-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
881+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
884882

885883
async def replay_sender():
886884
try:
@@ -898,8 +896,12 @@ async def send_event(event_message: EventMessage) -> None:
898896
# Register SSE writer so close_sse_stream() can close it
899897
self._sse_stream_writers[stream_id] = sse_stream_writer
900898

901-
# Send priming event for this new connection
902-
await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version)
899+
# Prime the resumed connection so the client sees the stream
900+
# is re-registered. The replay→live-tail ordering window here
901+
# is pre-existing and tracked separately.
902+
priming_event = await self._mint_priming_event(stream_id, replay_protocol_version)
903+
if priming_event is not None: # pragma: no branch
904+
await sse_stream_writer.send(priming_event)
903905

904906
# Create new request streams for this connection
905907
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](

tests/server/test_streamable_http_router.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,27 @@
55

66
from mcp.server.streamable_http import (
77
REQUEST_STREAM_BUFFER_SIZE,
8+
EventCallback,
9+
EventId,
810
EventMessage,
11+
EventStore,
912
StreamableHTTPServerTransport,
13+
StreamId,
1014
)
1115
from mcp.shared.message import SessionMessage
12-
from mcp.types import JSONRPCResponse
16+
from mcp.types import JSONRPCMessage, JSONRPCResponse
17+
18+
19+
class _OrderTrackingStore(EventStore):
20+
def __init__(self) -> None:
21+
self.stored: list[tuple[StreamId, JSONRPCMessage | None]] = []
22+
23+
async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId:
24+
self.stored.append((stream_id, message))
25+
return str(len(self.stored))
26+
27+
async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None:
28+
raise NotImplementedError
1329

1430

1531
@pytest.mark.anyio
@@ -49,3 +65,37 @@ async def server_writes() -> None:
4965
assert a_send.statistics().current_buffer_used == 1
5066
await a_recv.aclose()
5167
await a_send.aclose()
68+
69+
70+
@pytest.mark.anyio
71+
async def test_priming_event_is_stored_before_any_routed_message() -> None:
72+
"""`_mint_priming_event` is awaited before the request is dispatched, so the
73+
priming row precedes every `message_router` store for that stream regardless
74+
of when `sse_writer` is scheduled.
75+
"""
76+
store = _OrderTrackingStore()
77+
transport = StreamableHTTPServerTransport(mcp_session_id="sid", is_json_response_enabled=False, event_store=store)
78+
streams = transport._request_streams
79+
80+
async with transport.connect() as (_read_stream, write_stream):
81+
# POST handler step: mint priming for "A" before dispatch.
82+
priming = await transport._mint_priming_event("A", "2025-11-25")
83+
assert priming is not None
84+
streams["A"] = anyio.create_memory_object_stream[EventMessage](REQUEST_STREAM_BUFFER_SIZE)
85+
a_send, a_recv = streams["A"]
86+
87+
# Server emits 5 messages for "A" with no sse_writer scheduled. Each
88+
# write_stream.send() rendezvous-hands to message_router, which stores
89+
# then deposits into A's buffer; reading them back proves the router
90+
# has finished storing.
91+
for i in range(5):
92+
await write_stream.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id="A", result={"n": i})))
93+
with anyio.fail_after(5):
94+
for _ in range(5):
95+
await a_recv.receive()
96+
await a_recv.aclose()
97+
await a_send.aclose()
98+
99+
assert store.stored[0] == ("A", None)
100+
assert [sid for sid, _ in store.stored] == ["A"] * 6
101+
assert all(msg is not None for _, msg in store.stored[1:])

tests/shared/test_streamable_http.py

Lines changed: 13 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,80 +1636,30 @@ async def test_handle_sse_event_skips_empty_data() -> None:
16361636

16371637
@pytest.mark.anyio
16381638
async def test_priming_event_not_sent_for_old_protocol_version() -> None:
1639-
"""_maybe_send_priming_event skips for old protocol versions (backwards compat)."""
1640-
# Create a transport with an event store
1641-
transport = StreamableHTTPServerTransport(
1642-
"/mcp",
1643-
event_store=SimpleEventStore(),
1644-
)
1645-
1646-
# Create a mock stream writer
1647-
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)
1648-
1649-
try:
1650-
# Call _maybe_send_priming_event with OLD protocol version - should NOT send
1651-
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-06-18")
1652-
1653-
# Nothing should have been written to the stream
1654-
assert write_stream.statistics().current_buffer_used == 0
1655-
1656-
# Now test with NEW protocol version - should send
1657-
await transport._maybe_send_priming_event("test-request-id-2", write_stream, "2025-11-25")
1658-
1659-
# Should have written a priming event
1660-
assert write_stream.statistics().current_buffer_used == 1
1661-
finally:
1662-
await write_stream.aclose()
1663-
await read_stream.aclose()
1639+
"""`_mint_priming_event` skips for old protocol versions (backwards compat)."""
1640+
transport = StreamableHTTPServerTransport("/mcp", event_store=SimpleEventStore())
1641+
assert await transport._mint_priming_event("test-request-id", "2025-06-18") is None
1642+
assert await transport._mint_priming_event("test-request-id-2", "2025-11-25") is not None
16641643

16651644

16661645
@pytest.mark.anyio
16671646
async def test_priming_event_not_sent_without_event_store() -> None:
1668-
"""_maybe_send_priming_event returns early when no event_store is configured."""
1669-
# Create a transport WITHOUT an event store
1647+
"""`_mint_priming_event` returns `None` when no event_store is configured."""
16701648
transport = StreamableHTTPServerTransport("/mcp")
1671-
1672-
# Create a mock stream writer
1673-
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)
1674-
1675-
try:
1676-
# Call _maybe_send_priming_event - should return early without sending
1677-
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25")
1678-
1679-
# Nothing should have been written to the stream
1680-
assert write_stream.statistics().current_buffer_used == 0
1681-
finally:
1682-
await write_stream.aclose()
1683-
await read_stream.aclose()
1649+
assert await transport._mint_priming_event("test-request-id", "2025-11-25") is None
16841650

16851651

16861652
@pytest.mark.anyio
16871653
async def test_priming_event_includes_retry_interval() -> None:
1688-
"""_maybe_send_priming_event includes the retry field when retry_interval is set."""
1689-
# Create a transport with an event store AND retry_interval
1654+
"""`_mint_priming_event` includes the retry field when `retry_interval` is set."""
16901655
transport = StreamableHTTPServerTransport(
16911656
"/mcp",
16921657
event_store=SimpleEventStore(),
16931658
retry_interval=5000,
16941659
)
1695-
1696-
# Create a mock stream writer
1697-
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)
1698-
1699-
try:
1700-
# Call _maybe_send_priming_event with new protocol version
1701-
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25")
1702-
1703-
# Should have written a priming event with retry field
1704-
assert write_stream.statistics().current_buffer_used == 1
1705-
1706-
# Read the event and verify it has retry field
1707-
event = await read_stream.receive()
1708-
assert "retry" in event
1709-
assert event["retry"] == 5000
1710-
finally:
1711-
await write_stream.aclose()
1712-
await read_stream.aclose()
1660+
event = await transport._mint_priming_event("test-request-id", "2025-11-25")
1661+
assert event is not None
1662+
assert event["retry"] == 5000
17131663

17141664

17151665
@pytest.mark.anyio
@@ -1746,26 +1696,13 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version()
17461696

17471697
@pytest.mark.anyio
17481698
async def test_priming_event_not_sent_for_unknown_protocol_version() -> None:
1749-
"""_maybe_send_priming_event treats unrecognized version strings conservatively.
1699+
"""`_mint_priming_event` treats unrecognized version strings conservatively.
17501700
17511701
A garbage version must not be mistaken for a future one (lexicographically
17521702
"zzz" sorts after every date-shaped revision).
17531703
"""
1754-
transport = StreamableHTTPServerTransport(
1755-
"/mcp",
1756-
event_store=SimpleEventStore(),
1757-
)
1758-
1759-
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)
1760-
1761-
try:
1762-
await transport._maybe_send_priming_event("test-request-id", write_stream, "zzz")
1763-
1764-
# Nothing should have been written to the stream
1765-
assert write_stream.statistics().current_buffer_used == 0
1766-
finally:
1767-
await write_stream.aclose()
1768-
await read_stream.aclose()
1704+
transport = StreamableHTTPServerTransport("/mcp", event_store=SimpleEventStore())
1705+
assert await transport._mint_priming_event("test-request-id", "zzz") is None
17691706

17701707

17711708
@pytest.mark.anyio

0 commit comments

Comments
 (0)