Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from __future__ import annotations

import asyncio
import contextlib
import copy
import logging
from collections.abc import AsyncGenerator, Sequence
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from inspect import isawaitable
from typing import Any

Expand All @@ -29,6 +31,72 @@

logger = logging.getLogger(__name__)

# Default idle interval, in seconds, after which a transport-level SSE keepalive comment
# is emitted when no real event has been produced.
DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 15.0

# An SSE comment line (a line beginning with ``:``) is a protocol no-op: clients and
# parsers ignore it, but it forces a write on the wire so idle-timeout proxies in front
# of the endpoint keep the connection open during long silent gaps between events.
# The trailing blank line ends the SSE block so the comment is flushed as its own chunk.
_SSE_KEEPALIVE_COMMENT = ": keepalive\n\n"


class _StreamEnd:
"""Sentinel marking normal completion of the upstream event stream."""


async def _with_sse_keepalive(
events: AsyncIterator[str],
interval_seconds: float,
) -> AsyncGenerator[str]:
"""Yield upstream SSE strings, inserting keepalive comments during idle gaps.

Real events flush immediately and in order; a keepalive comment is emitted only when no
upstream event arrives within ``interval_seconds``, so the heartbeat is limited to idle
periods and never delays or reorders genuine events.

The upstream generator is drained by a single dedicated task that owns its entire lifecycle
(creation, iteration, and cleanup). This keeps every ``__anext__`` and the terminal cleanup
hooks in one ``contextvars`` context, which the agent run/telemetry pipeline requires:
those hooks reset ``ContextVar`` tokens that must be reset in the same context that created
them. Racing each individual pull with ``asyncio.wait_for`` would instead scatter the pulls
across contexts and break that cleanup.
"""
queue: asyncio.Queue[str | type[_StreamEnd] | Exception] = asyncio.Queue()

async def _drain() -> None:
try:
async for event in events:
await queue.put(event)
except asyncio.CancelledError:
# Consumer went away; let cancellation propagate so the upstream generator closes.
raise
except Exception as exc: # noqa: BLE001 - surfaced to the consumer via the queue
await queue.put(exc)
else:
await queue.put(_StreamEnd)

producer = asyncio.ensure_future(_drain())
try:
while True:
try:
item = await asyncio.wait_for(queue.get(), interval_seconds)
except asyncio.TimeoutError:
# Upstream produced nothing within the interval; emit a transport-level heartbeat.
yield _SSE_KEEPALIVE_COMMENT
continue
if item is _StreamEnd:
return
if isinstance(item, BaseException):
raise item
yield item # type: ignore[misc]
finally:
# Stop draining if the consumer goes away (e.g. client disconnect) and let the
# producer task observe cancellation so the upstream generator is closed cleanly.
if not producer.done():
producer.cancel()
with contextlib.suppress(asyncio.CancelledError):
await producer


def _get_snapshot_store(
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow,
Expand Down Expand Up @@ -82,6 +150,7 @@ def add_agent_framework_fastapi_endpoint(
dependencies: Sequence[Depends] | None = None,
snapshot_store: AGUIThreadSnapshotStore | None = None,
snapshot_scope_resolver: SnapshotScopeResolver | None = None,
keepalive_interval_seconds: float | None = DEFAULT_KEEPALIVE_INTERVAL_SECONDS,
) -> None:
"""Add an AG-UI endpoint to a FastAPI app.

Expand All @@ -103,7 +172,14 @@ def add_agent_framework_fastapi_endpoint(
explicit Snapshot Scope resolver.
snapshot_scope_resolver: Optional resolver for the application-defined Snapshot Scope. Required whenever
a snapshot store is configured because an AG-UI Thread id is not an authorization boundary.
keepalive_interval_seconds: Idle interval (in seconds) after which a transport-level SSE keepalive
comment (``: keepalive``) is written to the stream when no agent event has been produced.
This prevents idle-timeout proxies (e.g. Azure ingress, nginx, serverless front doors) from
dropping a healthy but silent event stream during long-running tools. Real events still flush
immediately. Defaults to ``15.0``; pass ``None`` to disable keepalives entirely.
"""
if keepalive_interval_seconds is not None and keepalive_interval_seconds <= 0:
raise ValueError("keepalive_interval_seconds must be a positive number or None.")
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow
if isinstance(agent, AgentFrameworkWorkflow):
protocol_runner = agent
Expand Down Expand Up @@ -208,8 +284,12 @@ async def event_generator() -> AsyncGenerator[str]:
except Exception:
logger.exception("[%s] Failed to encode RUN_ERROR event", path)

stream: AsyncGenerator[str] = event_generator()
if keepalive_interval_seconds is not None:
stream = _with_sse_keepalive(stream, keepalive_interval_seconds)

return StreamingResponse(
event_generator(),
stream,
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
Expand Down
117 changes: 117 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

"""Tests for FastAPI endpoint creation (_endpoint.py)."""

import asyncio
import json
from collections.abc import AsyncIterator
from typing import Any, cast

import pytest
Expand All @@ -22,6 +24,7 @@

from agent_framework_ag_ui import InMemoryAGUIThreadSnapshotStore, add_agent_framework_fastapi_endpoint
from agent_framework_ag_ui._agent import AgentFrameworkAgent
from agent_framework_ag_ui._endpoint import _SSE_KEEPALIVE_COMMENT, _with_sse_keepalive
from agent_framework_ag_ui._workflow import AgentFrameworkWorkflow


Expand Down Expand Up @@ -1844,3 +1847,117 @@ def factory(thread_id: str) -> Any:

runner.clear_thread_workflow("thread-1")
assert runner._resolve_workflow("thread-1", "tenant-b") is not workflow_b


async def test_sse_keepalive_emitted_during_idle_gap_and_real_events_pass_through():
    """Keepalive comments fill an idle upstream gap while both real events still arrive in order."""
    gate = asyncio.Event()

    async def upstream() -> AsyncIterator[str]:
        # The first event is yielded right away; the second is withheld until the test opens
        # the gate, modelling a long-running tool that emits no AG-UI events for a while.
        yield "data: A\n\n"
        await gate.wait()
        yield "data: B\n\n"

    # A tiny interval makes the idle gap trip the keepalive quickly, keeping the test fast.
    stream = _with_sse_keepalive(upstream(), 0.01)

    first = await stream.__anext__()  # real event A
    # With upstream parked on the gate, the next item must be a keepalive comment.
    heartbeat = await stream.__anext__()
    assert heartbeat == _SSE_KEEPALIVE_COMMENT
    received: list[str] = [first, heartbeat]

    # Let the second real event through and consume whatever remains of the stream.
    gate.set()
    received.extend([chunk async for chunk in stream])

    payloads = [chunk for chunk in received if chunk.startswith("data: ")]
    assert payloads == ["data: A\n\n", "data: B\n\n"]
    assert _SSE_KEEPALIVE_COMMENT in received


async def test_sse_keepalive_not_emitted_when_events_flow_without_gaps():
    """Events produced back to back are forwarded verbatim, with no keepalive in between."""

    async def upstream() -> AsyncIterator[str]:
        for payload in ("data: X\n\n", "data: Y\n\n", "data: Z\n\n"):
            yield payload

    received: list[str] = []
    async for chunk in _with_sse_keepalive(upstream(), 0.05):
        received.append(chunk)

    assert received == ["data: X\n\n", "data: Y\n\n", "data: Z\n\n"]
    assert _SSE_KEEPALIVE_COMMENT not in received


async def test_sse_keepalive_wrapper_handles_empty_stream():
    """Wrapping an upstream that produces no items ends cleanly and yields nothing."""

    async def upstream() -> AsyncIterator[str]:
        return
        yield  # pragma: no cover - present only to make this an async generator

    received: list[str] = []
    async for chunk in _with_sse_keepalive(upstream(), 0.01):
        received.append(chunk)

    assert received == []


async def test_sse_keepalive_wrapper_propagates_upstream_errors():
    """An exception raised upstream reaches the consumer instead of stalling the stream."""

    async def upstream() -> AsyncIterator[str]:
        yield "data: A\n\n"
        raise RuntimeError("boom")

    stream = _with_sse_keepalive(upstream(), 0.05)

    first = await stream.__anext__()
    assert first == "data: A\n\n"

    # The next pull must surface the upstream failure.
    with pytest.raises(RuntimeError, match="boom"):
        await stream.__anext__()


async def test_endpoint_accepts_keepalive_interval_and_streams_events(build_chat_client):
    """Registering with a keepalive interval still produces the complete AG-UI event sequence."""
    app = FastAPI()
    chat_client = build_chat_client("Keepalive response")
    agent = Agent(name="test", instructions="Test agent", client=chat_client)
    add_agent_framework_fastapi_endpoint(app, agent, path="/keepalive", keepalive_interval_seconds=0.05)

    request_body = {"messages": [{"role": "user", "content": "Hello"}]}
    response = TestClient(app).post("/keepalive", json=request_body)

    assert response.status_code == 200
    assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

    seen_types = [event.get("type") for event in _decode_sse_events(response)]
    for expected in ("RUN_STARTED", "TEXT_MESSAGE_CONTENT", "RUN_FINISHED"):
        assert expected in seen_types


async def test_endpoint_keepalive_can_be_disabled(build_chat_client):
    """With keepalive_interval_seconds=None the response contains no keepalive comments."""
    app = FastAPI()
    chat_client = build_chat_client()
    agent = Agent(name="test", instructions="Test agent", client=chat_client)
    add_agent_framework_fastapi_endpoint(app, agent, path="/no-keepalive", keepalive_interval_seconds=None)

    request_body = {"messages": [{"role": "user", "content": "Hello"}]}
    response = TestClient(app).post("/no-keepalive", json=request_body)

    assert response.status_code == 200
    # A disabled keepalive must leave no comment lines anywhere in the body.
    body_text = response.content.decode("utf-8")
    assert ": keepalive" not in body_text


async def test_endpoint_rejects_non_positive_keepalive_interval(build_chat_client):
    """Registration fails with ValueError when the keepalive interval is not positive."""
    app = FastAPI()
    chat_client = build_chat_client()
    agent = Agent(name="test", instructions="Test agent", client=chat_client)

    expected_message = "keepalive_interval_seconds must be a positive number or None"
    with pytest.raises(ValueError, match=expected_message):
        add_agent_framework_fastapi_endpoint(app, agent, path="/bad-keepalive", keepalive_interval_seconds=0)