Skip to content

Commit b6e2843

Browse files
committed
Store priming event before request dispatch so storage order == wire order
Splits the old _maybe_send_priming_event into _mint_priming_event (store + return wire dict) and _run_sse_writer (forward request_stream onto the wire). The POST handler now awaits _mint_priming_event before writer.send(), so the priming row is in the event store before the server can produce any message for that request id — ordering by data dependency, not scheduler timing. The replay path keeps its priming event (test_streamable_http_multiple_reconnections relies on it as a stream-re-registered signal); its replay→live-tail ordering window is pre-existing and orthogonal. Also extracts the inline sse_writer closure to a method (drops _handle_post_request below the C901 threshold) and widens the SSE-dict stream type to SSEEvent (dict[str, Any]) — the previous dict[str, str] was a lie masked by the old helper's Any parameter, since priming events carry retry: int.
1 parent fee7012 commit b6e2843

3 files changed

Lines changed: 122 additions & 134 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 58 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from collections.abc import AsyncGenerator, Awaitable, Callable
1313
from contextlib import asynccontextmanager
1414
from dataclasses import dataclass
15+
from functools import partial
1516
from http import HTTPStatus
1617
from typing import Any, Final
1718

@@ -71,6 +72,8 @@
7172
# Type aliases
7273
StreamId = str
7374
EventId = str
75+
# An SSE event-dict as accepted by sse-starlette (`event`, `data`, `id`, `retry`).
76+
SSEEvent = dict[str, Any]
7477

7578

7679
@dataclass
@@ -174,7 +177,7 @@ def __init__(
174177
MemoryObjectReceiveStream[EventMessage],
175178
],
176179
] = {}
177-
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
180+
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[SSEEvent]] = {}
178181
self._terminated = False
179182
# Idle timeout cancel scope; managed by the session manager.
180183
self.idle_scope: anyio.CancelScope | None = None
@@ -262,31 +265,48 @@ async def close_standalone_stream_callback() -> None:
262265

263266
return SessionMessage(message, metadata=metadata)
264267

265-
async def _maybe_send_priming_event(
266-
self,
267-
request_id: RequestId,
268-
sse_stream_writer: MemoryObjectSendStream[dict[str, Any]],
269-
protocol_version: str,
270-
) -> None:
271-
"""Send priming event for SSE resumability if event_store is configured.
268+
async def _mint_priming_event(self, stream_id: StreamId, protocol_version: str) -> SSEEvent | None:
269+
"""Store the priming cursor for `stream_id` and return its SSE wire form.
272270
273-
Only sends priming events to clients with protocol version >= 2025-11-25,
274-
which includes the fix for handling empty SSE data. Older clients would
275-
crash trying to parse empty data as JSON.
271+
Called before the request is dispatched so the priming row precedes
272+
anything `message_router` can store for this stream. Returns `None`
273+
when no event store is configured or the client predates 2025-11-25
274+
(older clients cannot parse the empty-data event).
276275
"""
277276
if not self._event_store:
278-
return
279-
# Priming events have empty data which older clients cannot handle.
277+
return None
280278
if not is_version_at_least(protocol_version, "2025-11-25"):
281-
return
282-
priming_event_id = await self._event_store.store_event(
283-
str(request_id), # Convert RequestId to StreamId (str)
284-
None, # Priming event has no payload
285-
)
286-
priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""}
279+
return None
280+
priming_event_id = await self._event_store.store_event(stream_id, None)
281+
priming_event: SSEEvent = {"id": priming_event_id, "data": ""}
287282
if self._retry_interval is not None:
288283
priming_event["retry"] = self._retry_interval
289-
await sse_stream_writer.send(priming_event)
284+
return priming_event
285+
286+
async def _run_sse_writer(
287+
self,
288+
request_id: RequestId,
289+
sse_stream_writer: MemoryObjectSendStream[SSEEvent],
290+
request_stream_reader: MemoryObjectReceiveStream[EventMessage],
291+
priming_event: SSEEvent | None,
292+
) -> None:
293+
"""Forward `_request_streams[request_id]` onto the SSE wire for one POST."""
294+
try:
295+
async with sse_stream_writer, request_stream_reader:
296+
if priming_event is not None:
297+
await sse_stream_writer.send(priming_event)
298+
async for event_message in request_stream_reader:
299+
await sse_stream_writer.send(self._create_event_data(event_message))
300+
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
301+
break
302+
except anyio.ClosedResourceError: # pragma: lax no cover
303+
logger.debug("SSE stream closed by close_sse_stream()")
304+
except Exception: # pragma: lax no cover
305+
logger.exception("Error in SSE writer")
306+
finally:
307+
logger.debug("Closing SSE writer")
308+
self._sse_stream_writers.pop(request_id, None)
309+
await self._clean_up_memory_streams(request_id)
290310

291311
def _create_error_response(
292312
self,
@@ -340,7 +360,7 @@ def _get_session_id(self, request: Request) -> str | None:
340360
"""Extract the session ID from request headers."""
341361
return request.headers.get(MCP_SESSION_ID_HEADER)
342362

343-
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
363+
def _create_event_data(self, event_message: EventMessage) -> SSEEvent:
344364
"""Create event data dictionary from an EventMessage."""
345365
event_data = {
346366
"event": "message",
@@ -583,40 +603,16 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
583603
await self._clean_up_memory_streams(request_id)
584604
else:
585605
# Create SSE stream
586-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
606+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
587607

588608
# Store writer reference so close_sse_stream() can close it
589609
self._sse_stream_writers[request_id] = sse_stream_writer
590610

591-
async def sse_writer():
592-
# Get the request ID from the incoming request message
593-
try:
594-
async with sse_stream_writer, request_stream_reader:
595-
# Send priming event for SSE resumability
596-
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)
597-
598-
# Process messages from the request-specific stream
599-
async for event_message in request_stream_reader:
600-
# Build the event data
601-
event_data = self._create_event_data(event_message)
602-
await sse_stream_writer.send(event_data)
603-
604-
# If response, remove from pending streams and close
605-
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
606-
break
607-
except anyio.ClosedResourceError: # pragma: lax no cover
608-
# Expected when close_sse_stream() is called
609-
logger.debug("SSE stream closed by close_sse_stream()")
610-
except Exception: # pragma: lax no cover
611-
logger.exception("Error in SSE writer")
612-
finally:
613-
logger.debug("Closing SSE writer")
614-
self._sse_stream_writers.pop(request_id, None)
615-
await self._clean_up_memory_streams(request_id)
616-
617-
# Create and start EventSourceResponse
618-
# SSE stream mode (original behavior)
619-
# Set up headers
611+
# Store the priming event before the request is dispatched so its
612+
# event-store position precedes anything message_router can store
613+
# for this id (storage order == wire order by construction).
614+
priming_event = await self._mint_priming_event(request_id, protocol_version)
615+
620616
headers = {
621617
"Cache-Control": "no-cache, no-transform",
622618
"Connection": "keep-alive",
@@ -625,7 +621,9 @@ async def sse_writer():
625621
}
626622
response = EventSourceResponse(
627623
content=sse_stream_reader,
628-
data_sender_callable=sse_writer,
624+
data_sender_callable=partial(
625+
self._run_sse_writer, request_id, sse_stream_writer, request_stream_reader, priming_event
626+
),
629627
headers=headers,
630628
)
631629

@@ -708,7 +706,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
708706
return
709707

710708
# Create SSE stream
711-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
709+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
712710

713711
async def standalone_sse_writer():
714712
try:
@@ -903,11 +901,10 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
903901
if self.mcp_session_id: # pragma: no branch
904902
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
905903

906-
# Get protocol version from header (already validated in _validate_protocol_version)
907904
replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
908905

909906
# Create SSE stream for replay
910-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
907+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
911908

912909
async def replay_sender():
913910
try:
@@ -925,8 +922,12 @@ async def send_event(event_message: EventMessage) -> None:
925922
# Register SSE writer so close_sse_stream() can close it
926923
self._sse_stream_writers[stream_id] = sse_stream_writer
927924

928-
# Send priming event for this new connection
929-
await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version)
925+
# Prime the resumed connection so the client sees the stream
926+
# is re-registered. The replay→live-tail ordering window here
927+
# is pre-existing and tracked separately.
928+
priming_event = await self._mint_priming_event(stream_id, replay_protocol_version)
929+
if priming_event is not None: # pragma: no branch
930+
await sse_stream_writer.send(priming_event)
930931

931932
# Create new request streams for this connection
932933
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](

tests/server/test_streamable_http_router.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,27 @@
55

66
from mcp.server.streamable_http import (
77
REQUEST_STREAM_BUFFER_SIZE,
8+
EventCallback,
9+
EventId,
810
EventMessage,
11+
EventStore,
912
StreamableHTTPServerTransport,
13+
StreamId,
1014
)
1115
from mcp.shared.message import SessionMessage
12-
from mcp.types import JSONRPCResponse
16+
from mcp.types import JSONRPCMessage, JSONRPCResponse
17+
18+
19+
class _OrderTrackingStore(EventStore):
20+
def __init__(self) -> None:
21+
self.stored: list[tuple[StreamId, JSONRPCMessage | None]] = []
22+
23+
async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId:
24+
self.stored.append((stream_id, message))
25+
return str(len(self.stored))
26+
27+
async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None:
28+
raise NotImplementedError
1329

1430

1531
@pytest.mark.anyio
@@ -49,3 +65,37 @@ async def server_writes() -> None:
4965
assert a_send.statistics().current_buffer_used == 1
5066
await a_recv.aclose()
5167
await a_send.aclose()
68+
69+
70+
@pytest.mark.anyio
71+
async def test_priming_event_is_stored_before_any_routed_message() -> None:
72+
"""`_mint_priming_event` is awaited before the request is dispatched, so the
73+
priming row precedes every `message_router` store for that stream regardless
74+
of when `sse_writer` is scheduled.
75+
"""
76+
store = _OrderTrackingStore()
77+
transport = StreamableHTTPServerTransport(mcp_session_id="sid", is_json_response_enabled=False, event_store=store)
78+
streams = transport._request_streams
79+
80+
async with transport.connect() as (_read_stream, write_stream):
81+
# POST handler step: mint priming for "A" before dispatch.
82+
priming = await transport._mint_priming_event("A", "2025-11-25")
83+
assert priming is not None
84+
streams["A"] = anyio.create_memory_object_stream[EventMessage](REQUEST_STREAM_BUFFER_SIZE)
85+
a_send, a_recv = streams["A"]
86+
87+
# Server emits 5 messages for "A" with no sse_writer scheduled. Each
88+
# write_stream.send() rendezvous-hands to message_router, which stores
89+
# then deposits into A's buffer; reading them back proves the router
90+
# has finished storing.
91+
for i in range(5):
92+
await write_stream.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id="A", result={"n": i})))
93+
with anyio.fail_after(5):
94+
for _ in range(5):
95+
await a_recv.receive()
96+
await a_recv.aclose()
97+
await a_send.aclose()
98+
99+
assert store.stored[0] == ("A", None)
100+
assert [sid for sid, _ in store.stored] == ["A"] * 6
101+
assert all(msg is not None for _, msg in store.stored[1:])

tests/shared/test_streamable_http.py

Lines changed: 13 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,80 +1643,30 @@ async def test_handle_sse_event_skips_empty_data() -> None:
16431643

16441644
@pytest.mark.anyio
16451645
async def test_priming_event_not_sent_for_old_protocol_version() -> None:
1646-
"""_maybe_send_priming_event skips for old protocol versions (backwards compat)."""
1647-
# Create a transport with an event store
1648-
transport = StreamableHTTPServerTransport(
1649-
"/mcp",
1650-
event_store=SimpleEventStore(),
1651-
)
1652-
1653-
# Create a mock stream writer
1654-
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)
1655-
1656-
try:
1657-
# Call _maybe_send_priming_event with OLD protocol version - should NOT send
1658-
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-06-18")
1659-
1660-
# Nothing should have been written to the stream
1661-
assert write_stream.statistics().current_buffer_used == 0
1662-
1663-
# Now test with NEW protocol version - should send
1664-
await transport._maybe_send_priming_event("test-request-id-2", write_stream, "2025-11-25")
1665-
1666-
# Should have written a priming event
1667-
assert write_stream.statistics().current_buffer_used == 1
1668-
finally:
1669-
await write_stream.aclose()
1670-
await read_stream.aclose()
1646+
"""`_mint_priming_event` skips for old protocol versions (backwards compat)."""
1647+
transport = StreamableHTTPServerTransport("/mcp", event_store=SimpleEventStore())
1648+
assert await transport._mint_priming_event("test-request-id", "2025-06-18") is None
1649+
assert await transport._mint_priming_event("test-request-id-2", "2025-11-25") is not None
16711650

16721651

16731652
@pytest.mark.anyio
16741653
async def test_priming_event_not_sent_without_event_store() -> None:
1675-
"""_maybe_send_priming_event returns early when no event_store is configured."""
1676-
# Create a transport WITHOUT an event store
1654+
"""`_mint_priming_event` returns `None` when no event_store is configured."""
16771655
transport = StreamableHTTPServerTransport("/mcp")
1678-
1679-
# Create a mock stream writer
1680-
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)
1681-
1682-
try:
1683-
# Call _maybe_send_priming_event - should return early without sending
1684-
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25")
1685-
1686-
# Nothing should have been written to the stream
1687-
assert write_stream.statistics().current_buffer_used == 0
1688-
finally:
1689-
await write_stream.aclose()
1690-
await read_stream.aclose()
1656+
assert await transport._mint_priming_event("test-request-id", "2025-11-25") is None
16911657

16921658

16931659
@pytest.mark.anyio
16941660
async def test_priming_event_includes_retry_interval() -> None:
1695-
"""_maybe_send_priming_event includes the retry field when retry_interval is set."""
1696-
# Create a transport with an event store AND retry_interval
1661+
"""`_mint_priming_event` includes the retry field when `retry_interval` is set."""
16971662
transport = StreamableHTTPServerTransport(
16981663
"/mcp",
16991664
event_store=SimpleEventStore(),
17001665
retry_interval=5000,
17011666
)
1702-
1703-
# Create a mock stream writer
1704-
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)
1705-
1706-
try:
1707-
# Call _maybe_send_priming_event with new protocol version
1708-
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25")
1709-
1710-
# Should have written a priming event with retry field
1711-
assert write_stream.statistics().current_buffer_used == 1
1712-
1713-
# Read the event and verify it has retry field
1714-
event = await read_stream.receive()
1715-
assert "retry" in event
1716-
assert event["retry"] == 5000
1717-
finally:
1718-
await write_stream.aclose()
1719-
await read_stream.aclose()
1667+
event = await transport._mint_priming_event("test-request-id", "2025-11-25")
1668+
assert event is not None
1669+
assert event["retry"] == 5000
17201670

17211671

17221672
@pytest.mark.anyio
@@ -1753,26 +1703,13 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version()
17531703

17541704
@pytest.mark.anyio
17551705
async def test_priming_event_not_sent_for_unknown_protocol_version() -> None:
1756-
"""_maybe_send_priming_event treats unrecognized version strings conservatively.
1706+
"""`_mint_priming_event` treats unrecognized version strings conservatively.
17571707
17581708
A garbage version must not be mistaken for a future one (lexicographically
17591709
"zzz" sorts after every date-shaped revision).
17601710
"""
1761-
transport = StreamableHTTPServerTransport(
1762-
"/mcp",
1763-
event_store=SimpleEventStore(),
1764-
)
1765-
1766-
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)
1767-
1768-
try:
1769-
await transport._maybe_send_priming_event("test-request-id", write_stream, "zzz")
1770-
1771-
# Nothing should have been written to the stream
1772-
assert write_stream.statistics().current_buffer_used == 0
1773-
finally:
1774-
await write_stream.aclose()
1775-
await read_stream.aclose()
1711+
transport = StreamableHTTPServerTransport("/mcp", event_store=SimpleEventStore())
1712+
assert await transport._mint_priming_event("test-request-id", "zzz") is None
17761713

17771714

17781715
@pytest.mark.anyio

0 commit comments

Comments
 (0)