Skip to content

Commit 40a3915

Browse files
committed
[v1.x] Buffer per-request StreamableHTTP streams; store priming event before dispatch
Backport of #2934. Gives the per-request _request_streams[EventMessage] sites a small bounded buffer (REQUEST_STREAM_BUFFER_SIZE = 16) so the serial message_router can deposit a response and move on instead of head-of-line blocking the session on a lazily-started sse_writer. Replaces _maybe_send_priming_event with _mint_priming_event (stores + returns the wire dict) and extracts the inline sse_writer closure to _run_sse_writer. The POST handler now mints priming before any per-request state is registered and before the request is dispatched, so the priming row precedes anything message_router can store for that stream by data dependency. Adds a finally to replay_sender; uses a generic 500 body in the outer except. Fixes #1764.
1 parent 32d3290 commit 40a3915

3 files changed

Lines changed: 229 additions & 141 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 99 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from collections.abc import AsyncGenerator, Awaitable, Callable
1515
from contextlib import asynccontextmanager
1616
from dataclasses import dataclass
17+
from functools import partial
1718
from http import HTTPStatus
18-
from typing import Any
19+
from typing import Any, Final
1920

2021
import anyio
2122
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -60,13 +61,20 @@
6061
# Special key for the standalone GET stream
6162
GET_STREAM_KEY = "_GET_stream"
6263

64+
# Buffer for the per-request `_request_streams` so the serial `message_router`
65+
# can deposit a response and move on instead of head-of-line blocking the
66+
# whole session on a lazily-started `sse_writer`. See #1764.
67+
REQUEST_STREAM_BUFFER_SIZE: Final = 16
68+
6369
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
6470
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
6571
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
6672

6773
# Type aliases
6874
StreamId = str
6975
EventId = str
76+
# An SSE event-dict as accepted by sse-starlette (`event`, `data`, `id`, `retry`).
77+
SSEEvent = dict[str, Any]
7078

7179

7280
@dataclass
@@ -178,7 +186,7 @@ def __init__(
178186
MemoryObjectReceiveStream[EventMessage],
179187
],
180188
] = {}
181-
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
189+
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[SSEEvent]] = {}
182190
self._terminated = False
183191
# Idle timeout cancel scope; managed by the session manager.
184192
self.idle_scope: anyio.CancelScope | None = None
@@ -267,31 +275,48 @@ async def close_standalone_stream_callback() -> None:
267275

268276
return SessionMessage(message, metadata=metadata)
269277

270-
async def _maybe_send_priming_event(
271-
self,
272-
request_id: RequestId,
273-
sse_stream_writer: MemoryObjectSendStream[dict[str, Any]],
274-
protocol_version: str,
275-
) -> None:
276-
"""Send priming event for SSE resumability if event_store is configured.
278+
async def _mint_priming_event(self, stream_id: StreamId, protocol_version: str) -> SSEEvent | None:
279+
"""Store the priming cursor for `stream_id` and return its SSE wire form.
277280
278-
Only sends priming events to clients with protocol version >= 2025-11-25,
279-
which includes the fix for handling empty SSE data. Older clients would
280-
crash trying to parse empty data as JSON.
281+
Called before the request is dispatched so the priming row precedes
282+
anything `message_router` can store for this stream. Returns `None`
283+
when no event store is configured or the client predates 2025-11-25
284+
(older clients cannot parse the empty-data event).
281285
"""
282286
if not self._event_store:
283-
return
284-
# Priming events have empty data which older clients cannot handle.
287+
return None
285288
if protocol_version < "2025-11-25":
286-
return
287-
priming_event_id = await self._event_store.store_event(
288-
str(request_id), # Convert RequestId to StreamId (str)
289-
None, # Priming event has no payload
290-
)
291-
priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""}
289+
return None
290+
priming_event_id = await self._event_store.store_event(stream_id, None)
291+
priming_event: SSEEvent = {"id": priming_event_id, "data": ""}
292292
if self._retry_interval is not None:
293293
priming_event["retry"] = self._retry_interval
294-
await sse_stream_writer.send(priming_event)
294+
return priming_event
295+
296+
async def _run_sse_writer( # pragma: no cover
297+
self,
298+
request_id: RequestId,
299+
sse_stream_writer: MemoryObjectSendStream[SSEEvent],
300+
request_stream_reader: MemoryObjectReceiveStream[EventMessage],
301+
priming_event: SSEEvent | None,
302+
) -> None:
303+
"""Forward `_request_streams[request_id]` onto the SSE wire for one POST."""
304+
try:
305+
async with sse_stream_writer, request_stream_reader:
306+
if priming_event is not None:
307+
await sse_stream_writer.send(priming_event)
308+
async for event_message in request_stream_reader:
309+
await sse_stream_writer.send(self._create_event_data(event_message))
310+
if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError):
311+
break
312+
except anyio.ClosedResourceError:
313+
logger.debug("SSE stream closed by close_sse_stream()")
314+
except Exception:
315+
logger.exception("Error in SSE writer")
316+
finally:
317+
logger.debug("Closing SSE writer")
318+
self._sse_stream_writers.pop(request_id, None)
319+
await self._clean_up_memory_streams(request_id)
295320

296321
def _create_error_response(
297322
self,
@@ -348,7 +373,7 @@ def _get_session_id(self, request: Request) -> str | None: # pragma: no cover
348373
"""Extract the session ID from request headers."""
349374
return request.headers.get(MCP_SESSION_ID_HEADER)
350375

351-
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover
376+
def _create_event_data(self, event_message: EventMessage) -> SSEEvent: # pragma: no cover
352377
"""Create event data dictionary from an EventMessage."""
353378
event_data = {
354379
"event": "message",
@@ -530,13 +555,13 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
530555
else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
531556
)
532557

533-
# Extract the request ID outside the try block for proper scope
534-
request_id = str(message.root.id) # pragma: no cover
535-
# Register this stream for the request ID
536-
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover
537-
request_stream_reader = self._request_streams[request_id][1] # pragma: no cover
558+
request_id = str(message.root.id)
538559

539560
if self.is_json_response_enabled: # pragma: no cover
561+
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
562+
REQUEST_STREAM_BUFFER_SIZE
563+
)
564+
request_stream_reader = self._request_streams[request_id][1]
540565
# Process the message
541566
metadata = ServerMessageMetadata(request_context=request)
542567
session_message = SessionMessage(message, metadata=metadata)
@@ -580,44 +605,19 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
580605
finally:
581606
await self._clean_up_memory_streams(request_id)
582607
else: # pragma: no cover
583-
# Create SSE stream
584-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
608+
# Mint the priming event before any per-request state exists:
609+
# `EventStore.store_event` is user code and may raise, in which
610+
# case the outer handler returns a 500 with nothing to clean up.
611+
# Still strictly precedes dispatch, so storage order == wire order.
612+
priming_event = await self._mint_priming_event(request_id, protocol_version)
585613

586-
# Store writer reference so close_sse_stream() can close it
614+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
587615
self._sse_stream_writers[request_id] = sse_stream_writer
616+
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
617+
REQUEST_STREAM_BUFFER_SIZE
618+
)
619+
request_stream_reader = self._request_streams[request_id][1]
588620

589-
async def sse_writer():
590-
# Get the request ID from the incoming request message
591-
try:
592-
async with sse_stream_writer, request_stream_reader:
593-
# Send priming event for SSE resumability
594-
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)
595-
596-
# Process messages from the request-specific stream
597-
async for event_message in request_stream_reader:
598-
# Build the event data
599-
event_data = self._create_event_data(event_message)
600-
await sse_stream_writer.send(event_data)
601-
602-
# If response, remove from pending streams and close
603-
if isinstance(
604-
event_message.message.root,
605-
JSONRPCResponse | JSONRPCError,
606-
):
607-
break
608-
except anyio.ClosedResourceError:
609-
# Expected when close_sse_stream() is called
610-
logger.debug("SSE stream closed by close_sse_stream()")
611-
except Exception:
612-
logger.exception("Error in SSE writer")
613-
finally:
614-
logger.debug("Closing SSE writer")
615-
self._sse_stream_writers.pop(request_id, None)
616-
await self._clean_up_memory_streams(request_id)
617-
618-
# Create and start EventSourceResponse
619-
# SSE stream mode (original behavior)
620-
# Set up headers
621621
headers = {
622622
"Cache-Control": "no-cache, no-transform",
623623
"Connection": "keep-alive",
@@ -626,7 +626,9 @@ async def sse_writer():
626626
}
627627
response = EventSourceResponse(
628628
content=sse_stream_reader,
629-
data_sender_callable=sse_writer,
629+
data_sender_callable=partial(
630+
self._run_sse_writer, request_id, sse_stream_writer, request_stream_reader, priming_event
631+
),
630632
headers=headers,
631633
)
632634

@@ -644,16 +646,15 @@ async def sse_writer():
644646
await sse_stream_reader.aclose()
645647
await self._clean_up_memory_streams(request_id)
646648

647-
except Exception as err: # pragma: no cover
649+
except Exception as err:
648650
logger.exception("Error handling POST request")
649651
response = self._create_error_response(
650-
f"Error handling POST request: {err}",
652+
"Error handling POST request",
651653
HTTPStatus.INTERNAL_SERVER_ERROR,
652654
INTERNAL_ERROR,
653655
)
654656
await response(scope, receive, send)
655-
if writer:
656-
await writer.send(Exception(err))
657+
await writer.send(Exception(err))
657658
return
658659

659660
async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover
@@ -706,13 +707,15 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr
706707
return
707708

708709
# Create SSE stream
709-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
710+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
710711

711712
async def standalone_sse_writer():
712713
try:
713714
# Create a standalone message stream for server-initiated messages
714715

715-
self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0)
716+
self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](
717+
REQUEST_STREAM_BUFFER_SIZE
718+
)
716719
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
717720

718721
async with sse_stream_writer, standalone_stream_reader:
@@ -903,7 +906,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
903906
replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
904907

905908
# Create SSE stream for replay
906-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
909+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
907910

908911
async def replay_sender():
909912
try:
@@ -918,22 +921,32 @@ async def send_event(event_message: EventMessage) -> None:
918921

919922
# If stream ID not in mapping, create it
920923
if stream_id and stream_id not in self._request_streams:
921-
# Register SSE writer so close_sse_stream() can close it
922-
self._sse_stream_writers[stream_id] = sse_stream_writer
923-
924-
# Send priming event for this new connection
925-
await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version)
926-
927-
# Create new request streams for this connection
928-
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
929-
msg_reader = self._request_streams[stream_id][1]
930-
931-
# Forward messages to SSE
932-
async with msg_reader:
933-
async for event_message in msg_reader:
934-
event_data = self._create_event_data(event_message)
935-
936-
await sse_stream_writer.send(event_data)
924+
try:
925+
# Register SSE writer so close_sse_stream() can close it
926+
self._sse_stream_writers[stream_id] = sse_stream_writer
927+
928+
# Prime the resumed connection so the client sees the stream
929+
# is re-registered. The replay→live-tail ordering window here
930+
# is pre-existing and tracked separately.
931+
priming_event = await self._mint_priming_event(stream_id, replay_protocol_version)
932+
if priming_event is not None:
933+
await sse_stream_writer.send(priming_event)
934+
935+
# Create new request streams for this connection
936+
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](
937+
REQUEST_STREAM_BUFFER_SIZE
938+
)
939+
msg_reader = self._request_streams[stream_id][1]
940+
941+
# Forward messages to SSE
942+
async with msg_reader:
943+
async for event_message in msg_reader:
944+
event_data = self._create_event_data(event_message)
945+
946+
await sse_stream_writer.send(event_data)
947+
finally:
948+
self._sse_stream_writers.pop(stream_id, None)
949+
await self._clean_up_memory_streams(stream_id)
937950
except anyio.ClosedResourceError:
938951
# Expected when close_sse_stream() is called
939952
logger.debug("Replay SSE stream closed by close_sse_stream()")

0 commit comments

Comments
 (0)