From e53fe677fe16bc58d735a0dd32fa814e902c682c Mon Sep 17 00:00:00 2001
From: vaibhav-patel
Date: Sat, 20 Jun 2026 15:46:07 +0530
Subject: [PATCH] Python: Add SSE keepalive interval to AG-UI FastAPI endpoint

The AG-UI FastAPI endpoint served the event stream as a bare
StreamingResponse with no application-level heartbeat. During long silent
gaps between agent events (slow server-side tools such as OCR or
retrieval), idle-timeout proxies in front of the endpoint (Azure ingress,
nginx, serverless front doors) drop the otherwise-healthy connection and
the client sees a spurious HTTP 500.

Wrap the event-stream generator so that when no upstream event is produced
within a configurable interval, a transport-level SSE keepalive comment
(": keepalive") is written to the wire. SSE comment lines are a protocol
no-op that clients and parsers ignore, so real events still flush
immediately and in order while only idle gaps trigger a ping.

The interval is exposed via the new keepalive_interval_seconds parameter on
add_agent_framework_fastapi_endpoint (default 15.0; pass None to disable).
The upstream generator is drained by a single dedicated task so its entire
lifecycle stays in one contextvars context, which the agent run/telemetry
pipeline requires for its ContextVar cleanup hooks.

Fixes #6611.
# SSE keepalive helper for the AG-UI FastAPI endpoint, reassembled from the patch hunks that
# target python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py.
import asyncio
import contextlib
from collections.abc import AsyncGenerator, AsyncIterator

# Default idle interval after which a transport-level SSE keepalive comment is emitted.
DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 15.0

# An SSE comment line (a line beginning with ``:``) is a protocol no-op: clients and
# parsers ignore it, but it forces a write on the wire so idle-timeout proxies in front
# of the endpoint keep the connection open during long silent gaps between events.
_SSE_KEEPALIVE_COMMENT = ": keepalive\n\n"


class _StreamEnd:
    """Sentinel marking the end of the upstream event stream.

    The class object itself is enqueued (never an instance) once the producer task stops for
    any reason: normal exhaustion, an upstream error, or cancellation.
    """


async def _with_sse_keepalive(
    events: AsyncIterator[str],
    interval_seconds: float,
) -> AsyncGenerator[str]:
    """Yield upstream SSE strings, inserting keepalive comments during idle gaps.

    Real events are forwarded in order as soon as they are available; a keepalive comment is
    emitted only when no upstream item arrives within ``interval_seconds``.

    The upstream iterator is consumed by one dedicated task, so every ``__anext__`` call and
    the iterator's cleanup run in that task's single ``contextvars`` context. The original
    change notes that the agent run/telemetry pipeline needs this because its cleanup hooks
    reset ``ContextVar`` tokens that must be reset in the context that created them; racing
    each pull with ``asyncio.wait_for`` would spread the pulls over several contexts.

    Args:
        events: Upstream asynchronous iterator producing already-encoded SSE strings.
        interval_seconds: Idle time, in seconds, after which one keepalive comment is yielded.

    Yields:
        Each upstream string unchanged, interleaved with ``_SSE_KEEPALIVE_COMMENT`` while the
        upstream is idle.

    Raises:
        Exception: Any ``Exception`` raised by ``events`` is re-raised to the consumer after
            the items produced before it have been yielded.
    """
    # Unbounded queue: the producer never blocks on a slow consumer, so ``put_nowait`` cannot fail.
    queue: asyncio.Queue[str | type[_StreamEnd] | Exception] = asyncio.Queue()

    async def _drain() -> None:
        try:
            async for event in events:
                queue.put_nowait(event)
        except asyncio.CancelledError:
            # Either the consumer went away or the upstream itself was cancelled; let the
            # cancellation propagate so the task ends in the cancelled state.
            raise
        except Exception as exc:  # noqa: BLE001 - surfaced to the consumer via the queue
            queue.put_nowait(exc)
        finally:
            # Always signal termination. Without this, a producer that stops without the
            # consumer having asked for it (e.g. the upstream awaited a cancelled future)
            # would leave the consumer emitting keepalives forever on a dead stream.
            queue.put_nowait(_StreamEnd)

    producer = asyncio.create_task(_drain())
    try:
        while True:
            try:
                item = await asyncio.wait_for(queue.get(), interval_seconds)
            except asyncio.TimeoutError:
                # Upstream produced nothing within the interval; emit a transport-level heartbeat.
                yield _SSE_KEEPALIVE_COMMENT
                continue
            if item is _StreamEnd:
                return
            if isinstance(item, BaseException):
                raise item
            yield item  # type: ignore[misc]
    finally:
        # Stop draining if the consumer goes away (e.g. client disconnect) and wait for the
        # producer task to observe the cancellation so the upstream generator is closed cleanly.
        if not producer.done():
            producer.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await producer
@@ -103,7 +172,14 @@ def add_agent_framework_fastapi_endpoint( explicit Snapshot Scope resolver. snapshot_scope_resolver: Optional resolver for the application-defined Snapshot Scope. Required whenever a snapshot store is configured because an AG-UI Thread id is not an authorization boundary. + keepalive_interval_seconds: Idle interval (in seconds) after which a transport-level SSE keepalive + comment (``: keepalive``) is written to the stream when no agent event has been produced. + This prevents idle-timeout proxies (e.g. Azure ingress, nginx, serverless front doors) from + dropping a healthy but silent event stream during long-running tools. Real events still flush + immediately. Defaults to ``15.0``; pass ``None`` to disable keepalives entirely. """ + if keepalive_interval_seconds is not None and keepalive_interval_seconds <= 0: + raise ValueError("keepalive_interval_seconds must be a positive number or None.") protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow if isinstance(agent, AgentFrameworkWorkflow): protocol_runner = agent @@ -208,8 +284,12 @@ async def event_generator() -> AsyncGenerator[str]: except Exception: logger.exception("[%s] Failed to encode RUN_ERROR event", path) + stream: AsyncGenerator[str] = event_generator() + if keepalive_interval_seconds is not None: + stream = _with_sse_keepalive(stream, keepalive_interval_seconds) + return StreamingResponse( - event_generator(), + stream, media_type="text/event-stream", headers={ "Cache-Control": "no-cache", diff --git a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py index f675992e1b1..a5d7dd07f7b 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py +++ b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py @@ -2,7 +2,9 @@ """Tests for FastAPI endpoint creation (_endpoint.py).""" +import asyncio import json +from collections.abc import AsyncIterator from typing import Any, cast import pytest @@ -22,6 +24,7 @@ from 
# Keepalive tests that the patch appends to python/packages/ag-ui/tests/ag_ui/test_endpoint.py.
# They rely on that module's imports (asyncio, AsyncIterator, pytest, FastAPI, TestClient, Agent,
# add_agent_framework_fastapi_endpoint, _SSE_KEEPALIVE_COMMENT, _with_sse_keepalive) and on its
# ``build_chat_client`` fixture and ``_decode_sse_events`` helper.


async def test_sse_keepalive_emitted_during_idle_gap_and_real_events_pass_through():
    """An idle upstream produces keepalive comments while every real event is still delivered."""
    release_second = asyncio.Event()

    async def upstream() -> AsyncIterator[str]:
        # One event right away, then silence until the test releases the second event; this
        # stands in for a long-running tool that emits nothing for a while.
        yield "data: A\n\n"
        await release_second.wait()
        yield "data: B\n\n"

    # A very short interval makes the idle gap trip keepalives without slowing the test down.
    stream = _with_sse_keepalive(upstream(), 0.01)

    # The first real event is delivered immediately.
    first_chunk = await stream.__anext__()

    # With the upstream blocked, the next item has to be a keepalive comment.
    heartbeat = await stream.__anext__()
    assert heartbeat == _SSE_KEEPALIVE_COMMENT

    # Unblock the upstream and read the stream to its end.
    release_second.set()
    received = [first_chunk, heartbeat]
    received.extend([chunk async for chunk in stream])

    assert [chunk for chunk in received if chunk.startswith("data: ")] == ["data: A\n\n", "data: B\n\n"]
    assert _SSE_KEEPALIVE_COMMENT in received


async def test_sse_keepalive_not_emitted_when_events_flow_without_gaps():
    """Events that arrive back to back are forwarded unchanged and no keepalive is inserted."""

    async def upstream() -> AsyncIterator[str]:
        for letter in ("X", "Y", "Z"):
            yield f"data: {letter}\n\n"

    received: list[str] = []
    async for chunk in _with_sse_keepalive(upstream(), 0.05):
        received.append(chunk)

    assert received == ["data: X\n\n", "data: Y\n\n", "data: Z\n\n"]
    assert _SSE_KEEPALIVE_COMMENT not in received


async def test_sse_keepalive_wrapper_handles_empty_stream():
    """An upstream without any events ends the wrapped stream cleanly and without keepalives."""

    async def upstream() -> AsyncIterator[str]:
        if False:  # pragma: no cover - the unreachable yield only makes this an async generator
            yield ""

    received: list[str] = []
    async for chunk in _with_sse_keepalive(upstream(), 0.01):
        received.append(chunk)

    assert received == []


async def test_sse_keepalive_wrapper_propagates_upstream_errors():
    """An exception raised upstream reaches the consumer instead of leaving it waiting."""

    async def upstream() -> AsyncIterator[str]:
        yield "data: A\n\n"
        raise RuntimeError("boom")

    stream = _with_sse_keepalive(upstream(), 0.05)
    first_chunk = await stream.__anext__()
    assert first_chunk == "data: A\n\n"

    with pytest.raises(RuntimeError, match="boom"):
        await stream.__anext__()


async def test_endpoint_accepts_keepalive_interval_and_streams_events(build_chat_client):
    """With a keepalive interval configured, the endpoint still streams the whole event sequence."""
    app = FastAPI()
    chat_client = build_chat_client("Keepalive response")
    agent = Agent(name="test", instructions="Test agent", client=chat_client)
    add_agent_framework_fastapi_endpoint(app, agent, path="/keepalive", keepalive_interval_seconds=0.05)

    payload = {"messages": [{"role": "user", "content": "Hello"}]}
    response = TestClient(app).post("/keepalive", json=payload)

    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

    event_types = [event.get("type") for event in _decode_sse_events(response)]
    for expected_type in ("RUN_STARTED", "TEXT_MESSAGE_CONTENT", "RUN_FINISHED"):
        assert expected_type in event_types


async def test_endpoint_keepalive_can_be_disabled(build_chat_client):
    """``keepalive_interval_seconds=None`` leaves the stream free of keepalive comments."""
    app = FastAPI()
    chat_client = build_chat_client()
    agent = Agent(name="test", instructions="Test agent", client=chat_client)
    add_agent_framework_fastapi_endpoint(app, agent, path="/no-keepalive", keepalive_interval_seconds=None)

    payload = {"messages": [{"role": "user", "content": "Hello"}]}
    response = TestClient(app).post("/no-keepalive", json=payload)

    assert response.status_code == 200
    # With the feature switched off, the body must not contain any keepalive comment line.
    body_text = response.content.decode("utf-8")
    assert ": keepalive" not in body_text


async def test_endpoint_rejects_non_positive_keepalive_interval(build_chat_client):
    """Registering the endpoint with a non-positive keepalive interval raises ``ValueError``."""
    app = FastAPI()
    chat_client = build_chat_client()
    agent = Agent(name="test", instructions="Test agent", client=chat_client)

    expected_message = "keepalive_interval_seconds must be a positive number or None"
    with pytest.raises(ValueError, match=expected_message):
        add_agent_framework_fastapi_endpoint(app, agent, path="/bad-keepalive", keepalive_interval_seconds=0)