Skip to content

Commit a527142

Browse files
authored
Buffer per-request StreamableHTTP streams to avoid serial-router head-of-line block (#2934)
1 parent 44ce901 commit a527142

4 files changed

Lines changed: 292 additions & 188 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 99 additions & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,9 @@
1212
from collections.abc import AsyncGenerator, Awaitable, Callable
1313
from contextlib import asynccontextmanager
1414
from dataclasses import dataclass
15+
from functools import partial
1516
from http import HTTPStatus
16-
from typing import Any
17+
from typing import Any, Final
1718

1819
import anyio
1920
import pydantic_core
@@ -59,13 +60,20 @@
5960
# Special key for the standalone GET stream
6061
GET_STREAM_KEY = "_GET_stream"
6162

63+
# Buffer for the per-request `_request_streams` so the serial `message_router`
64+
# can deposit a response and move on instead of head-of-line blocking the
65+
# whole session on a lazily-started `sse_writer`. See #1764.
66+
REQUEST_STREAM_BUFFER_SIZE: Final = 16
67+
6268
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
6369
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
6470
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
6571

6672
# Type aliases
6773
StreamId = str
6874
EventId = str
75+
# An SSE event-dict as accepted by sse-starlette (`event`, `data`, `id`, `retry`).
76+
SSEEvent = dict[str, Any]
6977

7078

7179
@dataclass
@@ -169,7 +177,7 @@ def __init__(
169177
MemoryObjectReceiveStream[EventMessage],
170178
],
171179
] = {}
172-
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
180+
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[SSEEvent]] = {}
173181
self._terminated = False
174182
# Idle timeout cancel scope; managed by the session manager.
175183
self.idle_scope: anyio.CancelScope | None = None
@@ -256,31 +264,48 @@ async def close_standalone_stream_callback() -> None:
256264

257265
return SessionMessage(message, metadata=metadata)
258266

259-
async def _maybe_send_priming_event(
260-
self,
261-
request_id: RequestId,
262-
sse_stream_writer: MemoryObjectSendStream[dict[str, Any]],
263-
protocol_version: str,
264-
) -> None:
265-
"""Send priming event for SSE resumability if event_store is configured.
267+
async def _mint_priming_event(self, stream_id: StreamId, protocol_version: str) -> SSEEvent | None:
268+
"""Store the priming cursor for `stream_id` and return its SSE wire form.
266269
267-
Only sends priming events to clients with protocol version >= 2025-11-25,
268-
which includes the fix for handling empty SSE data. Older clients would
269-
crash trying to parse empty data as JSON.
270+
Called before the request is dispatched so the priming row precedes
271+
anything `message_router` can store for this stream. Returns `None`
272+
when no event store is configured or the client predates 2025-11-25
273+
(older clients cannot parse the empty-data event).
270274
"""
271275
if not self._event_store:
272-
return
273-
# Priming events have empty data which older clients cannot handle.
276+
return None
274277
if not is_version_at_least(protocol_version, "2025-11-25"):
275-
return
276-
priming_event_id = await self._event_store.store_event(
277-
str(request_id), # Convert RequestId to StreamId (str)
278-
None, # Priming event has no payload
279-
)
280-
priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""}
278+
return None
279+
priming_event_id = await self._event_store.store_event(stream_id, None)
280+
priming_event: SSEEvent = {"id": priming_event_id, "data": ""}
281281
if self._retry_interval is not None:
282282
priming_event["retry"] = self._retry_interval
283-
await sse_stream_writer.send(priming_event)
283+
return priming_event
284+
285+
async def _run_sse_writer(
286+
self,
287+
request_id: RequestId,
288+
sse_stream_writer: MemoryObjectSendStream[SSEEvent],
289+
request_stream_reader: MemoryObjectReceiveStream[EventMessage],
290+
priming_event: SSEEvent | None,
291+
) -> None:
292+
"""Forward `_request_streams[request_id]` onto the SSE wire for one POST."""
293+
try:
294+
async with sse_stream_writer, request_stream_reader:
295+
if priming_event is not None:
296+
await sse_stream_writer.send(priming_event)
297+
async for event_message in request_stream_reader:
298+
await sse_stream_writer.send(self._create_event_data(event_message))
299+
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
300+
break
301+
except anyio.ClosedResourceError: # pragma: lax no cover
302+
logger.debug("SSE stream closed by close_sse_stream()")
303+
except Exception: # pragma: lax no cover
304+
logger.exception("Error in SSE writer")
305+
finally:
306+
logger.debug("Closing SSE writer")
307+
self._sse_stream_writers.pop(request_id, None)
308+
await self._clean_up_memory_streams(request_id)
284309

285310
def _create_error_response(
286311
self,
@@ -334,7 +359,7 @@ def _get_session_id(self, request: Request) -> str | None:
334359
"""Extract the session ID from request headers."""
335360
return request.headers.get(MCP_SESSION_ID_HEADER)
336361

337-
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
362+
def _create_event_data(self, event_message: EventMessage) -> SSEEvent:
338363
"""Create event data dictionary from an EventMessage."""
339364
event_data = {
340365
"event": "message",
@@ -521,13 +546,13 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
521546
else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
522547
)
523548

524-
# Extract the request ID outside the try block for proper scope
525549
request_id = str(message.id)
526-
# Register this stream for the request ID
527-
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
528-
request_stream_reader = self._request_streams[request_id][1]
529550

530551
if self.is_json_response_enabled:
552+
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
553+
REQUEST_STREAM_BUFFER_SIZE
554+
)
555+
request_stream_reader = self._request_streams[request_id][1]
531556
# Process the message
532557
metadata = ServerMessageMetadata(request_context=request)
533558
session_message = SessionMessage(message, metadata=metadata)
@@ -571,41 +596,19 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
571596
finally:
572597
await self._clean_up_memory_streams(request_id)
573598
else:
574-
# Create SSE stream
575-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
599+
# Mint the priming event before any per-request state exists:
600+
# `EventStore.store_event` is user code and may raise, in which
601+
# case the outer handler returns a 500 with nothing to clean up.
602+
# Still strictly precedes dispatch, so storage order == wire order.
603+
priming_event = await self._mint_priming_event(request_id, protocol_version)
576604

577-
# Store writer reference so close_sse_stream() can close it
605+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
578606
self._sse_stream_writers[request_id] = sse_stream_writer
607+
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
608+
REQUEST_STREAM_BUFFER_SIZE
609+
)
610+
request_stream_reader = self._request_streams[request_id][1]
579611

580-
async def sse_writer():
581-
# Get the request ID from the incoming request message
582-
try:
583-
async with sse_stream_writer, request_stream_reader:
584-
# Send priming event for SSE resumability
585-
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)
586-
587-
# Process messages from the request-specific stream
588-
async for event_message in request_stream_reader:
589-
# Build the event data
590-
event_data = self._create_event_data(event_message)
591-
await sse_stream_writer.send(event_data)
592-
593-
# If response, remove from pending streams and close
594-
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
595-
break
596-
except anyio.ClosedResourceError: # pragma: lax no cover
597-
# Expected when close_sse_stream() is called
598-
logger.debug("SSE stream closed by close_sse_stream()")
599-
except Exception: # pragma: lax no cover
600-
logger.exception("Error in SSE writer")
601-
finally:
602-
logger.debug("Closing SSE writer")
603-
self._sse_stream_writers.pop(request_id, None)
604-
await self._clean_up_memory_streams(request_id)
605-
606-
# Create and start EventSourceResponse
607-
# SSE stream mode (original behavior)
608-
# Set up headers
609612
headers = {
610613
"Cache-Control": "no-cache, no-transform",
611614
"Connection": "keep-alive",
@@ -614,7 +617,9 @@ async def sse_writer():
614617
}
615618
response = EventSourceResponse(
616619
content=sse_stream_reader,
617-
data_sender_callable=sse_writer,
620+
data_sender_callable=partial(
621+
self._run_sse_writer, request_id, sse_stream_writer, request_stream_reader, priming_event
622+
),
618623
headers=headers,
619624
)
620625

@@ -633,20 +638,16 @@ async def sse_writer():
633638
finally:
634639
await sse_stream_reader.aclose()
635640

636-
except Exception as err: # pragma: lax no cover
637-
# Reached only when something raises during POST handling outside
638-
# the per-SSE-stream guard above; whether tests reach this depends
639-
# on client teardown timing.
641+
except Exception as err:
640642
logger.exception("Error handling POST request")
641643
response = self._create_error_response(
642-
f"Error handling POST request: {err}",
644+
"Error handling POST request",
643645
HTTPStatus.INTERNAL_SERVER_ERROR,
644646
INTERNAL_ERROR,
645647
)
646648
await response(scope, receive, send)
647-
if writer:
648-
await writer.send(Exception(err))
649-
return # pragma: no cover
649+
await writer.send(Exception(err))
650+
return
650651

651652
async def _handle_get_request(self, request: Request, send: Send) -> None:
652653
"""Handle GET request to establish SSE.
@@ -697,13 +698,15 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
697698
return
698699

699700
# Create SSE stream
700-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
701+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
701702

702703
async def standalone_sse_writer():
703704
try:
704705
# Create a standalone message stream for server-initiated messages
705706

706-
self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0)
707+
self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](
708+
REQUEST_STREAM_BUFFER_SIZE
709+
)
707710
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
708711

709712
async with sse_stream_writer, standalone_stream_reader:
@@ -871,7 +874,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
871874
replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
872875

873876
# Create SSE stream for replay
874-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
877+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
875878

876879
async def replay_sender():
877880
try:
@@ -886,22 +889,32 @@ async def send_event(event_message: EventMessage) -> None:
886889

887890
# If stream ID not in mapping, create it
888891
if stream_id and stream_id not in self._request_streams: # pragma: no branch
889-
# Register SSE writer so close_sse_stream() can close it
890-
self._sse_stream_writers[stream_id] = sse_stream_writer
891-
892-
# Send priming event for this new connection
893-
await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version)
894-
895-
# Create new request streams for this connection
896-
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
897-
msg_reader = self._request_streams[stream_id][1]
898-
899-
# Forward messages to SSE
900-
async with msg_reader:
901-
async for event_message in msg_reader:
902-
event_data = self._create_event_data(event_message)
903-
904-
await sse_stream_writer.send(event_data)
892+
try:
893+
# Register SSE writer so close_sse_stream() can close it
894+
self._sse_stream_writers[stream_id] = sse_stream_writer
895+
896+
# Prime the resumed connection so the client sees the stream
897+
# is re-registered. The replay→live-tail ordering window here
898+
# is pre-existing and tracked separately.
899+
priming_event = await self._mint_priming_event(stream_id, replay_protocol_version)
900+
if priming_event is not None:
901+
await sse_stream_writer.send(priming_event)
902+
903+
# Create new request streams for this connection
904+
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](
905+
REQUEST_STREAM_BUFFER_SIZE
906+
)
907+
msg_reader = self._request_streams[stream_id][1]
908+
909+
# Forward messages to SSE
910+
async with msg_reader:
911+
async for event_message in msg_reader:
912+
event_data = self._create_event_data(event_message)
913+
914+
await sse_stream_writer.send(event_data)
915+
finally:
916+
self._sse_stream_writers.pop(stream_id, None)
917+
await self._clean_up_memory_streams(stream_id)
905918
except anyio.ClosedResourceError: # pragma: lax no cover
906919
# Expected when close_sse_stream() is called
907920
logger.debug("Replay SSE stream closed by close_sse_stream()")

tests/interaction/transports/test_hosting_resume.py

Lines changed: 77 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,43 @@ async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_ev
113113
)
114114

115115

116+
@requirement("hosting:resume:priming")
117+
async def test_the_priming_row_is_stored_before_any_handler_output_for_that_stream() -> None:
118+
"""The priming cursor is the first row the event store records for a request's stream.
119+
120+
The POST handler stores the priming row before dispatching the request, so by construction
121+
it precedes anything `message_router` can store for that stream id.
122+
"""
123+
store = SequencedEventStore()
124+
mcp = MCPServer("resumable")
125+
126+
@mcp.tool()
127+
async def burst(ctx: Context) -> str:
128+
await ctx.info("a") # pyright: ignore[reportDeprecated]
129+
await ctx.info("b") # pyright: ignore[reportDeprecated]
130+
await ctx.info("c") # pyright: ignore[reportDeprecated]
131+
return "done"
132+
133+
async with mounted_app(mcp, event_store=store) as (http, _):
134+
session_id = await initialize_via_http(http)
135+
with anyio.fail_after(5):
136+
async with http.stream( # pragma: no branch
137+
"POST", "/mcp", content=_tools_call(2, "burst", {}), headers=base_headers(session_id=session_id)
138+
) as response:
139+
await _read_events(response, 5)
140+
141+
# initialize wrote two rows (its own priming + response); everything after is this call.
142+
call_rows = store._events[2:]
143+
stream_id = call_rows[0][0]
144+
assert [(s, None if m is None else type(m).__name__) for s, m in call_rows] == [
145+
(stream_id, None),
146+
(stream_id, "JSONRPCNotification"),
147+
(stream_id, "JSONRPCNotification"),
148+
(stream_id, "JSONRPCNotification"),
149+
(stream_id, "JSONRPCResponse"),
150+
]
151+
152+
116153
@requirement("hosting:resume:replay")
117154
@requirement("hosting:resume:stream-scoped")
118155
@requirement("hosting:resume:buffered-replay")
@@ -182,6 +219,46 @@ async def count(ctx: Context) -> str:
182219
)
183220

184221

222+
@requirement("hosting:resume:priming")
223+
async def test_a_pre_2025_11_25_reconnect_replays_without_minting_a_priming_event() -> None:
224+
"""A pre-2025-11-25 client reconnecting via Last-Event-ID gets the replay with no priming row.
225+
226+
The store-length assertion is the load-bearing proof that no priming cursor was minted.
227+
"""
228+
release = anyio.Event()
229+
store = SequencedEventStore()
230+
mcp = MCPServer("resumable")
231+
232+
@mcp.tool()
233+
async def count(ctx: Context) -> str:
234+
await ctx.info("tick 1") # pyright: ignore[reportDeprecated]
235+
await release.wait()
236+
await ctx.info("tick 2") # pyright: ignore[reportDeprecated]
237+
return "counted"
238+
239+
async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _):
240+
session_id = await initialize_via_http(http)
241+
with anyio.fail_after(5):
242+
async with http.stream(
243+
"POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id)
244+
) as response:
245+
_, first = await _read_events(response, 2)
246+
release.set()
247+
await store.wait_until_stored(6)
248+
old_client_headers = base_headers(session_id=session_id) | {
249+
"mcp-protocol-version": "2025-06-18",
250+
"last-event-id": first.id,
251+
}
252+
async with http.stream("GET", "/mcp", headers=old_client_headers) as replay: # pragma: no branch
253+
assert replay.status_code == 200
254+
missed = await _read_events(replay, 2)
255+
256+
assert [(event.id, bool(event.data)) for event in missed] == snapshot([("5", True), ("6", True)])
257+
# No priming cursor was minted on reconnect: the store still holds only the six rows
258+
# written before the GET (init priming+response, POST priming, tick 1, tick 2, result).
259+
assert len(store._events) == 6
260+
261+
185262
@requirement("hosting:resume:bad-event-id")
186263
async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None:
187264
"""A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error.

0 commit comments

Comments
 (0)