Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions python/packages/ag-ui/agent_framework_ag_ui/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ToolCallResultEvent,
ToolCallStartEvent,
)
from agent_framework import Workflow
from agent_framework import CheckpointStorage, Workflow

from ._message_adapters import agui_messages_to_snapshot_format
from ._run_common import (
Expand All @@ -45,6 +45,13 @@

WorkflowFactory = Callable[[str], Workflow]

# Input-payload keys used to surface workflow checkpointing through ``run(input_data)``
# without changing the positional call convention used by the FastAPI endpoint. The
# corresponding ``run()`` keyword arguments take precedence over these when both are
# supplied.
_CHECKPOINT_ID_INPUT_KEY = "__ag_ui_checkpoint_id"
_CHECKPOINT_STORAGE_INPUT_KEY = "__ag_ui_checkpoint_storage"


def _event_messages_to_snapshot_dicts(messages: list[Any]) -> list[dict[str, Any]]:
"""Convert AG-UI message event models to plain snapshot dictionaries."""
Expand Down Expand Up @@ -276,10 +283,38 @@ def clear_workflow_cache(self) -> None:
"""Drop all cached thread workflow instances."""
self._workflow_by_thread.clear()

async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[BaseEvent]:
async def run(
self,
input_data: dict[str, Any],
*,
checkpoint_storage: CheckpointStorage | None = None,
checkpoint_id: str | None = None,
) -> AsyncGenerator[BaseEvent]:
"""Run the wrapped workflow and yield AG-UI events.

Args:
input_data: The AG-UI request payload (a ``RunAgentInput`` dump).
checkpoint_storage: Optional checkpoint storage to enable workflow
checkpointing for this run. When provided, the underlying core
workflow creates a checkpoint at the end of each superstep, matching
``agent_framework.Workflow.run(checkpoint_storage=...)``. May also be
supplied via the ``input_data`` key ``__ag_ui_checkpoint_storage``;
the keyword argument takes precedence.
checkpoint_id: Optional checkpoint id to resume the workflow from. When
provided, execution restores the persisted workflow state instead of
starting a fresh turn, matching
``agent_framework.Workflow.run(checkpoint_id=...)``. May also be
supplied via the ``input_data`` key ``__ag_ui_checkpoint_id``; the
keyword argument takes precedence.

Subclasses may override this to provide custom AG-UI streams.

Note:
Checkpointing (the ``agent_framework`` workflow checkpoint mechanism) is
independent from AG-UI Thread Snapshot persistence (``snapshot_store``).
The two can be used together, but they persist different things: snapshots
capture replayable protocol output for a thread, while checkpoints capture
executor/runtime state for resumable execution.
"""
thread_id = self._thread_id_from_input(input_data)
run_id = str(input_data.get("run_id") or input_data.get("runId") or uuid.uuid4())
Expand All @@ -288,7 +323,23 @@ async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[BaseEvent]:
resume_payload = _extract_resume_payload(input_data)
snapshot_store = self.snapshot_store

if snapshot_store is not None and snapshot_scope is not None and not raw_messages and resume_payload is None:
# Explicit keyword arguments win over values smuggled through input_data so the
# FastAPI endpoint (which calls ``run(input_data)`` positionally) can still opt
# into checkpointing without changing its call site.
if checkpoint_id is None:
checkpoint_id = cast(str | None, input_data.get(_CHECKPOINT_ID_INPUT_KEY))
if checkpoint_storage is None:
checkpoint_storage = cast(CheckpointStorage | None, input_data.get(_CHECKPOINT_STORAGE_INPUT_KEY))

# A checkpoint resume legitimately carries no new messages; it must reach the
# core workflow's restore path rather than replaying a stored thread snapshot.
if (
checkpoint_id is None
and snapshot_store is not None
and snapshot_scope is not None
and not raw_messages
and resume_payload is None
):
async for event in _hydrate_workflow_thread_snapshot(
snapshot_store=snapshot_store,
scope=snapshot_scope,
Expand Down Expand Up @@ -346,8 +397,17 @@ async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[BaseEvent]:
state_snapshot = make_json_safe(effective_state)
if isinstance(state_snapshot, dict):
snapshot_builder.state = cast(dict[str, Any], state_snapshot)
# Only forward checkpoint kwargs when checkpointing is requested so the
# non-checkpoint path keeps calling ``run_workflow_stream(input_data, workflow)``
# exactly as before (preserves the established two-argument call convention).
stream_kwargs: dict[str, Any] = {}
if checkpoint_storage is not None:
stream_kwargs["checkpoint_storage"] = checkpoint_storage
if checkpoint_id is not None:
stream_kwargs["checkpoint_id"] = checkpoint_id

run_error_emitted = False
async for event in run_workflow_stream(input_data, workflow):
async for event in run_workflow_stream(input_data, workflow, **stream_kwargs):
if snapshot_builder is not None:
snapshot_builder.observe(event)
if isinstance(event, RunErrorEvent):
Expand Down
58 changes: 51 additions & 7 deletions python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@
ToolCallEndEvent,
ToolCallStartEvent,
)
from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message, Workflow, WorkflowRunState
from agent_framework import (
AgentResponse,
AgentResponseUpdate,
CheckpointStorage,
Content,
Message,
Workflow,
WorkflowRunState,
)

from ._message_adapters import normalize_agui_input_messages
from ._run_common import (
Expand Down Expand Up @@ -558,8 +566,24 @@ def _details_code(details: Any) -> str | None:
async def run_workflow_stream(
input_data: dict[str, Any],
workflow: Workflow,
*,
checkpoint_storage: CheckpointStorage | None = None,
checkpoint_id: str | None = None,
) -> AsyncGenerator[BaseEvent]:
"""Run a Workflow and emit AG-UI protocol events."""
"""Run a Workflow and emit AG-UI protocol events.

Args:
input_data: Normalized AG-UI request payload (a ``RunAgentInput`` dump).
workflow: The core ``Workflow`` instance to execute.
checkpoint_storage: Optional checkpoint storage forwarded to the core
workflow. When provided, the workflow creates a checkpoint at the end
of each superstep, mirroring ``Workflow.run(checkpoint_storage=...)``.
checkpoint_id: Optional checkpoint id to resume from. When provided the run
restores the persisted workflow state instead of starting a fresh turn,
mirroring ``Workflow.run(checkpoint_id=...)``. Any incoming messages are
treated as request-info responses (or ignored) rather than a new
start-executor message, so resume stays consistent with the core API.
"""
thread_id = input_data.get("thread_id") or input_data.get("threadId") or str(uuid.uuid4())
run_id = input_data.get("run_id") or input_data.get("runId") or str(uuid.uuid4())
available_interrupts = input_data.get("available_interrupts") or input_data.get("availableInterrupts")
Expand Down Expand Up @@ -587,7 +611,11 @@ async def run_workflow_stream(
if not responses and pending_before_run:
responses.update(_single_pending_response_from_value(pending_before_run, _latest_user_text(messages)))

if not responses and pending_before_run:
# A checkpoint resume must always reach ``workflow.run(checkpoint_id=...)`` so the
# core restores persisted state and re-emits any pending requests from the
# checkpoint. ``pending_before_run`` reflects the live (pre-restore) instance, so
# short-circuiting on it here would skip the restore entirely.
if checkpoint_id is None and not responses and pending_before_run:
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
for request_event in pending_before_run.values():
request_payload = _request_payload_from_request_event(request_event)
Expand All @@ -604,7 +632,7 @@ async def run_workflow_stream(
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=pending_interrupts)
return

if not responses and not messages:
if checkpoint_id is None and not responses and not messages:
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=pending_interrupts)
return
Expand Down Expand Up @@ -640,11 +668,27 @@ def _drain_open_message() -> list[TextMessageEndEvent]:
logger.debug("workflow.run() does not accept function_invocation_kwargs; dropping forwarded_props")
fwd_kwargs = {}

# Forward checkpoint storage so the core workflow creates a checkpoint at the end
# of each superstep (parity with ``Workflow.run(checkpoint_storage=...)``).
checkpoint_kwargs: dict[str, Any] = {}
if checkpoint_storage is not None:
checkpoint_kwargs["checkpoint_storage"] = checkpoint_storage

try:
if responses:
event_stream = workflow.run(responses=responses, stream=True, **fwd_kwargs)
if checkpoint_id is not None:
# Resume from a checkpoint. ``message`` is mutually exclusive with
# ``checkpoint_id`` in the core API, so incoming messages are only
# forwarded as request-info responses (``responses``), never as a new
# start-executor message. ``responses`` + ``checkpoint_id`` performs a
# restore-then-send in a single call.
run_kwargs: dict[str, Any] = {"checkpoint_id": checkpoint_id, **checkpoint_kwargs}
if responses:
run_kwargs["responses"] = responses
event_stream = workflow.run(stream=True, **run_kwargs, **fwd_kwargs)
elif responses:
event_stream = workflow.run(responses=responses, stream=True, **checkpoint_kwargs, **fwd_kwargs)
else:
event_stream = workflow.run(message=messages, stream=True, **fwd_kwargs)
event_stream = workflow.run(message=messages, stream=True, **checkpoint_kwargs, **fwd_kwargs)

async for event in event_stream:
event_type = getattr(event, "type", None)
Expand Down
147 changes: 144 additions & 3 deletions python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,25 @@
from typing import Any, cast

import pytest
from agent_framework import Workflow, WorkflowBuilder, WorkflowContext, executor
from agent_framework import (
InMemoryCheckpointStorage,
Workflow,
WorkflowBuilder,
WorkflowContext,
executor,
handler,
)
from agent_framework._workflows._executor import Executor

Comment on lines +10 to 19
from agent_framework_ag_ui import AgentFrameworkWorkflow


async def _run(agent: AgentFrameworkWorkflow, payload: dict[str, Any]) -> list[Any]:
return [event async for event in agent.run(payload)]
async def _run(
    agent: AgentFrameworkWorkflow,
    payload: dict[str, Any],
    **run_kwargs: Any,
) -> list[Any]:
    """Drain ``agent.run(payload, **run_kwargs)`` and return the yielded events in order.

    Args:
        agent: The workflow wrapper whose ``run`` async generator is consumed.
        payload: The input dictionary passed positionally to ``agent.run``.
        **run_kwargs: Extra keyword arguments forwarded unchanged to ``agent.run``
            (the checkpoint tests use this for ``checkpoint_storage`` / ``checkpoint_id``).

    Returns:
        A list holding every event the stream produced.
    """
    collected: list[Any] = []
    async for event in agent.run(payload, **run_kwargs):
        collected.append(event)
    return collected


async def test_workflow_wrapper_rejects_workflow_and_factory_at_once() -> None:
Expand Down Expand Up @@ -110,3 +122,132 @@ async def test_workflow_wrapper_factory_return_type_is_validated() -> None:

with pytest.raises(TypeError, match="workflow_factory must return a Workflow instance"):
_ = [event async for event in agent.run({"thread_id": "thread-a", "messages": []})]


# region checkpointing


class _StartExecutor(Executor):
    """Entry executor: discards its input and sends ``"hello"`` to the ``middle`` executor."""

    @handler
    async def run(self, message: Any, ctx: WorkflowContext[str]) -> None:
        """Ignore ``message`` and forward a fixed greeting to target id ``middle``."""
        # The incoming payload is irrelevant for these tests; drop the reference explicitly.
        del message
        greeting = "hello"
        await ctx.send_message(greeting, target_id="middle")


class _MiddleExecutor(Executor):
    """Intermediate executor: tags the incoming string and passes it to ``finish``."""

    @handler
    async def process(self, message: str, ctx: WorkflowContext[str]) -> None:
        """Append ``-processed`` to ``message`` and send it to target id ``finish``."""
        tagged = f"{message}-processed"
        await ctx.send_message(tagged, target_id="finish")


class _FinishExecutor(Executor):
    """Terminal executor: yields the incoming string with a ``-done`` suffix as output."""

    @handler
    async def finish(self, message: str, ctx: WorkflowContext[Any, str]) -> None:
        """Yield ``<message>-done`` through ``ctx.yield_output``."""
        final_output = f"{message}-done"
        await ctx.yield_output(final_output)


def _build_multi_superstep_workflow(storage: InMemoryCheckpointStorage | None = None) -> Workflow:
    """Build a linear ``start -> middle -> finish`` workflow for the checkpoint tests.

    Args:
        storage: Optional checkpoint storage to configure on the builder. When ``None``
            the builder is created without a ``checkpoint_storage`` argument, so any
            checkpointing has to be requested by the caller at run time.

    Returns:
        The built ``Workflow`` wiring the three test executors in sequence.
    """
    start = _StartExecutor(id="start")
    middle = _MiddleExecutor(id="middle")
    finish = _FinishExecutor(id="finish")

    # Assemble the constructor arguments once instead of creating a builder and then
    # discarding it for a second one when storage is supplied.
    builder_kwargs: dict[str, Any] = {"max_iterations": 10, "start_executor": start}
    if storage is not None:
        builder_kwargs["checkpoint_storage"] = storage

    builder = WorkflowBuilder(**builder_kwargs)
    return builder.add_edge(start, middle).add_edge(middle, finish).build()


async def test_workflow_run_creates_checkpoints_via_storage_kwarg() -> None:
    """``run(checkpoint_storage=...)`` should make the wrapped workflow write checkpoints."""
    checkpoint_store = InMemoryCheckpointStorage()
    # No storage is handed to the builder here; the run() keyword is the only source.
    workflow = _build_multi_superstep_workflow()
    wrapper = AgentFrameworkWorkflow(workflow=workflow)
    payload = {"thread_id": "thread-cp", "messages": [{"role": "user", "content": "start"}]}

    emitted = await _run(wrapper, payload, checkpoint_storage=checkpoint_store)

    seen_types = [event.type for event in emitted]
    assert "RUN_STARTED" in seen_types
    assert "RUN_FINISHED" in seen_types
    assert "RUN_ERROR" not in seen_types

    saved = await checkpoint_store.list_checkpoints(workflow_name=workflow.name)
    # The workflow spans several supersteps, so more than one checkpoint is expected.
    assert len(saved) >= 2


async def test_workflow_run_resumes_from_checkpoint_id() -> None:
    """``run(checkpoint_id=...)`` should resume from a stored checkpoint and finish the run."""
    checkpoint_store = InMemoryCheckpointStorage()
    workflow = _build_multi_superstep_workflow(checkpoint_store)
    wrapper = AgentFrameworkWorkflow(workflow=workflow)

    # Initial pass: run with a user message so the builder-configured storage fills up.
    initial_events = await _run(
        wrapper,
        {"thread_id": "thread-cp", "messages": [{"role": "user", "content": "start"}]},
    )
    initial_types = [event.type for event in initial_events]
    assert "RUN_ERROR" not in initial_types

    # Order the stored checkpoints by timestamp so index 0 is the earliest one.
    ordered = list(await checkpoint_store.list_checkpoints(workflow_name=workflow.name))
    ordered.sort(key=lambda entry: entry.timestamp)
    assert ordered, "expected the run to create at least one checkpoint"
    earliest_id = ordered[0].checkpoint_id

    # Second pass: same thread id, no new messages, resume from the earliest checkpoint.
    resumed_events = await _run(
        wrapper,
        {"thread_id": "thread-cp", "messages": []},
        checkpoint_id=earliest_id,
        checkpoint_storage=checkpoint_store,
    )

    resumed_types = [event.type for event in resumed_events]
    assert "RUN_STARTED" in resumed_types
    assert "RUN_FINISHED" in resumed_types
    assert "RUN_ERROR" not in resumed_types

    # Concatenate the text deltas; the finish executor's "-done" suffix should reappear.
    deltas = [
        getattr(event, "delta", "")
        for event in resumed_events
        if event.type == "TEXT_MESSAGE_CONTENT"
    ]
    resumed_text = "".join(deltas)
    assert "done" in resumed_text


async def test_workflow_run_reads_checkpoint_params_from_input_data() -> None:
    """Checkpoint storage supplied under the reserved ``input_data`` key should be honored."""
    checkpoint_store = InMemoryCheckpointStorage()
    workflow = _build_multi_superstep_workflow()
    wrapper = AgentFrameworkWorkflow(workflow=workflow)

    # The storage travels inside the payload itself; no keyword arguments are passed.
    payload = {
        "thread_id": "thread-cp-input",
        "messages": [{"role": "user", "content": "start"}],
        "__ag_ui_checkpoint_storage": checkpoint_store,
    }
    emitted = await _run(wrapper, payload)

    emitted_types = [event.type for event in emitted]
    assert "RUN_ERROR" not in emitted_types
    saved = await checkpoint_store.list_checkpoints(workflow_name=workflow.name)
    assert len(saved) >= 1


async def test_workflow_run_without_checkpointing_is_unchanged() -> None:
    """A plain ``run(input_data)`` call without checkpoint arguments still completes cleanly."""
    wrapper = AgentFrameworkWorkflow(workflow=_build_multi_superstep_workflow())
    payload = {"thread_id": "thread-plain", "messages": [{"role": "user", "content": "start"}]}

    emitted = await _run(wrapper, payload)

    seen_types = [event.type for event in emitted]
    assert "RUN_STARTED" in seen_types
    assert "RUN_FINISHED" in seen_types
    assert "RUN_ERROR" not in seen_types