From f2c5ee7c018d6a786f955036473ccb1efe82e869 Mon Sep 17 00:00:00 2001 From: vaibhav-patel Date: Sat, 20 Jun 2026 15:49:15 +0530 Subject: [PATCH] Python: add checkpointing support to AgentFrameworkWorkflow.run() in ag-ui The ag-ui AgentFrameworkWorkflow.run() previously accepted only a RunAgentInput payload and exposed no way to use the core workflow's checkpointing/state-persistence, unlike the core agent-framework workflow implementations. This left ag-ui workflows without resumable execution. Add optional checkpoint_storage and checkpoint_id keyword arguments to run(), threaded through run_workflow_stream() into the core Workflow.run(). This delegates to the existing core capability instead of reinventing it and keeps the public surface consistent with Workflow.run(): - checkpoint_storage enables checkpoint creation at each superstep boundary. - checkpoint_id resumes a run from a persisted checkpoint; incoming messages are forwarded only as request-info responses (never as a new start-executor message) to honor the core's message/checkpoint_id mutual exclusivity, and responses + checkpoint_id performs a restore-then-send in one call. Both can also be supplied via the input_data keys __ag_ui_checkpoint_storage and __ag_ui_checkpoint_id so the FastAPI endpoint (which calls run(input_data) positionally) can opt in without changing its call site; explicit keyword arguments take precedence. Checkpoint resume bypasses the AG-UI thread snapshot hydration early-returns so it always reaches the core restore path. Backward compatible: run(input_data) keeps working unchanged, and the non-checkpoint path still calls run_workflow_stream(input_data, workflow) with its original two-argument convention. Adds focused tests covering checkpoint creation, resume-from-checkpoint, input-data-keyed params, and the unchanged default path. Fixes #6632. --- .../ag-ui/agent_framework_ag_ui/_workflow.py | 68 +++++++- .../agent_framework_ag_ui/_workflow_run.py | 58 ++++++- .../ag-ui/tests/ag_ui/test_workflow_agent.py | 147 +++++++++++++++++- 3 files changed, 259 insertions(+), 14 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_workflow.py b/python/packages/ag-ui/agent_framework_ag_ui/_workflow.py index aa583856a65..7b41151afba 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_workflow.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_workflow.py @@ -24,7 +24,7 @@ ToolCallResultEvent, ToolCallStartEvent, ) -from agent_framework import Workflow +from agent_framework import CheckpointStorage, Workflow from ._message_adapters import agui_messages_to_snapshot_format from ._run_common import ( @@ -45,6 +45,13 @@ WorkflowFactory = Callable[[str], Workflow] +# Input-payload keys used to surface workflow checkpointing through ``run(input_data)`` +# without changing the positional call convention used by the FastAPI endpoint. The +# corresponding ``run()`` keyword arguments take precedence over these when both are +# supplied. +_CHECKPOINT_ID_INPUT_KEY = "__ag_ui_checkpoint_id" +_CHECKPOINT_STORAGE_INPUT_KEY = "__ag_ui_checkpoint_storage" + def _event_messages_to_snapshot_dicts(messages: list[Any]) -> list[dict[str, Any]]: """Convert AG-UI message event models to plain snapshot dictionaries.""" @@ -276,10 +283,38 @@ def clear_workflow_cache(self) -> None: """Drop all cached thread workflow instances.""" self._workflow_by_thread.clear() - async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[BaseEvent]: + async def run( + self, + input_data: dict[str, Any], + *, + checkpoint_storage: CheckpointStorage | None = None, + checkpoint_id: str | None = None, + ) -> AsyncGenerator[BaseEvent]: """Run the wrapped workflow and yield AG-UI events. + Args: + input_data: The AG-UI request payload (a ``RunAgentInput`` dump). + checkpoint_storage: Optional checkpoint storage to enable workflow + checkpointing for this run. When provided, the underlying core + workflow creates a checkpoint at the end of each superstep, matching + ``agent_framework.Workflow.run(checkpoint_storage=...)``. May also be + supplied via the ``input_data`` key ``__ag_ui_checkpoint_storage``; + the keyword argument takes precedence. + checkpoint_id: Optional checkpoint id to resume the workflow from. When + provided, execution restores the persisted workflow state instead of + starting a fresh turn, matching + ``agent_framework.Workflow.run(checkpoint_id=...)``. May also be + supplied via the ``input_data`` key ``__ag_ui_checkpoint_id``; the + keyword argument takes precedence. + Subclasses may override this to provide custom AG-UI streams. + + Note: + Checkpointing (the ``agent_framework`` workflow checkpoint mechanism) is + independent from AG-UI Thread Snapshot persistence (``snapshot_store``). + The two can be used together, but they persist different things: snapshots + capture replayable protocol output for a thread, while checkpoints capture + executor/runtime state for resumable execution. """ thread_id = self._thread_id_from_input(input_data) run_id = str(input_data.get("run_id") or input_data.get("runId") or uuid.uuid4()) @@ -288,7 +323,23 @@ async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[BaseEvent]: resume_payload = _extract_resume_payload(input_data) snapshot_store = self.snapshot_store - if snapshot_store is not None and snapshot_scope is not None and not raw_messages and resume_payload is None: + # Explicit keyword arguments win over values smuggled through input_data so the + # FastAPI endpoint (which calls ``run(input_data)`` positionally) can still opt + # into checkpointing without changing its call site. + if checkpoint_id is None: + checkpoint_id = cast(str | None, input_data.get(_CHECKPOINT_ID_INPUT_KEY)) + if checkpoint_storage is None: + checkpoint_storage = cast(CheckpointStorage | None, input_data.get(_CHECKPOINT_STORAGE_INPUT_KEY)) + + # A checkpoint resume legitimately carries no new messages; it must reach the + # core workflow's restore path rather than replaying a stored thread snapshot. + if ( + checkpoint_id is None + and snapshot_store is not None + and snapshot_scope is not None + and not raw_messages + and resume_payload is None + ): async for event in _hydrate_workflow_thread_snapshot( snapshot_store=snapshot_store, scope=snapshot_scope, @@ -346,8 +397,17 @@ async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[BaseEvent]: state_snapshot = make_json_safe(effective_state) if isinstance(state_snapshot, dict): snapshot_builder.state = cast(dict[str, Any], state_snapshot) + # Only forward checkpoint kwargs when checkpointing is requested so the + # non-checkpoint path keeps calling ``run_workflow_stream(input_data, workflow)`` + # exactly as before (preserves the established two-argument call convention). + stream_kwargs: dict[str, Any] = {} + if checkpoint_storage is not None: + stream_kwargs["checkpoint_storage"] = checkpoint_storage + if checkpoint_id is not None: + stream_kwargs["checkpoint_id"] = checkpoint_id + run_error_emitted = False - async for event in run_workflow_stream(input_data, workflow): + async for event in run_workflow_stream(input_data, workflow, **stream_kwargs): if snapshot_builder is not None: snapshot_builder.observe(event) if isinstance(event, RunErrorEvent): diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py index 44d571aef8d..1e3d4d0fff4 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py @@ -24,7 +24,15 @@ ToolCallEndEvent, ToolCallStartEvent, ) -from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message, Workflow, WorkflowRunState +from agent_framework import ( + AgentResponse, + AgentResponseUpdate, + CheckpointStorage, + Content, + Message, + Workflow, + WorkflowRunState, +) from ._message_adapters import normalize_agui_input_messages from ._run_common import ( @@ -558,8 +566,24 @@ def _details_code(details: Any) -> str | None: async def run_workflow_stream( input_data: dict[str, Any], workflow: Workflow, + *, + checkpoint_storage: CheckpointStorage | None = None, + checkpoint_id: str | None = None, ) -> AsyncGenerator[BaseEvent]: - """Run a Workflow and emit AG-UI protocol events.""" + """Run a Workflow and emit AG-UI protocol events. + + Args: + input_data: Normalized AG-UI request payload (a ``RunAgentInput`` dump). + workflow: The core ``Workflow`` instance to execute. + checkpoint_storage: Optional checkpoint storage forwarded to the core + workflow. When provided, the workflow creates a checkpoint at the end + of each superstep, mirroring ``Workflow.run(checkpoint_storage=...)``. + checkpoint_id: Optional checkpoint id to resume from. When provided the run + restores the persisted workflow state instead of starting a fresh turn, + mirroring ``Workflow.run(checkpoint_id=...)``. Any incoming messages are + treated as request-info responses (or ignored) rather than a new + start-executor message, so resume stays consistent with the core API. + """ thread_id = input_data.get("thread_id") or input_data.get("threadId") or str(uuid.uuid4()) run_id = input_data.get("run_id") or input_data.get("runId") or str(uuid.uuid4()) available_interrupts = input_data.get("available_interrupts") or input_data.get("availableInterrupts") @@ -587,7 +611,11 @@ async def run_workflow_stream( if not responses and pending_before_run: responses.update(_single_pending_response_from_value(pending_before_run, _latest_user_text(messages))) - if not responses and pending_before_run: + # A checkpoint resume must always reach ``workflow.run(checkpoint_id=...)`` so the + # core restores persisted state and re-emits any pending requests from the + # checkpoint. ``pending_before_run`` reflects the live (pre-restore) instance, so + # short-circuiting on it here would skip the restore entirely. + if checkpoint_id is None and not responses and pending_before_run: yield RunStartedEvent(run_id=run_id, thread_id=thread_id) for request_event in pending_before_run.values(): request_payload = _request_payload_from_request_event(request_event) @@ -604,7 +632,7 @@ async def run_workflow_stream( yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=pending_interrupts) return - if not responses and not messages: + if checkpoint_id is None and not responses and not messages: yield RunStartedEvent(run_id=run_id, thread_id=thread_id) yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=pending_interrupts) return @@ -640,11 +668,27 @@ def _drain_open_message() -> list[TextMessageEndEvent]: logger.debug("workflow.run() does not accept function_invocation_kwargs; dropping forwarded_props") fwd_kwargs = {} + # Forward checkpoint storage so the core workflow creates a checkpoint at the end + # of each superstep (parity with ``Workflow.run(checkpoint_storage=...)``). + checkpoint_kwargs: dict[str, Any] = {} + if checkpoint_storage is not None: + checkpoint_kwargs["checkpoint_storage"] = checkpoint_storage + try: - if responses: - event_stream = workflow.run(responses=responses, stream=True, **fwd_kwargs) + if checkpoint_id is not None: + # Resume from a checkpoint. ``message`` is mutually exclusive with + # ``checkpoint_id`` in the core API, so incoming messages are only + # forwarded as request-info responses (``responses``), never as a new + # start-executor message. ``responses`` + ``checkpoint_id`` performs a + # restore-then-send in a single call. + run_kwargs: dict[str, Any] = {"checkpoint_id": checkpoint_id, **checkpoint_kwargs} + if responses: + run_kwargs["responses"] = responses + event_stream = workflow.run(stream=True, **run_kwargs, **fwd_kwargs) + elif responses: + event_stream = workflow.run(responses=responses, stream=True, **checkpoint_kwargs, **fwd_kwargs) else: - event_stream = workflow.run(message=messages, stream=True, **fwd_kwargs) + event_stream = workflow.run(message=messages, stream=True, **checkpoint_kwargs, **fwd_kwargs) async for event in event_stream: event_type = getattr(event, "type", None) diff --git a/python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py b/python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py index 858d10370f0..7afbbd75efd 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py +++ b/python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py @@ -7,13 +7,25 @@ from typing import Any, cast import pytest -from agent_framework import Workflow, WorkflowBuilder, WorkflowContext, executor +from agent_framework import ( + InMemoryCheckpointStorage, + Workflow, + WorkflowBuilder, + WorkflowContext, + executor, + handler, +) +from agent_framework._workflows._executor import Executor from agent_framework_ag_ui import AgentFrameworkWorkflow -async def _run(agent: AgentFrameworkWorkflow, payload: dict[str, Any]) -> list[Any]: - return [event async for event in agent.run(payload)] +async def _run( + agent: AgentFrameworkWorkflow, + payload: dict[str, Any], + **run_kwargs: Any, +) -> list[Any]: + return [event async for event in agent.run(payload, **run_kwargs)] async def test_workflow_wrapper_rejects_workflow_and_factory_at_once() -> None: @@ -110,3 +122,132 @@ async def test_workflow_wrapper_factory_return_type_is_validated() -> None: with pytest.raises(TypeError, match="workflow_factory must return a Workflow instance"): _ = [event async for event in agent.run({"thread_id": "thread-a", "messages": []})] + + +# region checkpointing + + +class _StartExecutor(Executor): + @handler + async def run(self, message: Any, ctx: WorkflowContext[str]) -> None: + del message + await ctx.send_message("hello", target_id="middle") + + +class _MiddleExecutor(Executor): + @handler + async def process(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(f"{message}-processed", target_id="finish") + + +class _FinishExecutor(Executor): + @handler + async def finish(self, message: str, ctx: WorkflowContext[Any, str]) -> None: + await ctx.yield_output(f"{message}-done") + + +def _build_multi_superstep_workflow(storage: InMemoryCheckpointStorage | None = None) -> Workflow: + """Build a start -> middle -> finish workflow that creates a checkpoint per superstep.""" + start = _StartExecutor(id="start") + middle = _MiddleExecutor(id="middle") + finish = _FinishExecutor(id="finish") + builder = WorkflowBuilder(max_iterations=10, start_executor=start) + if storage is not None: + builder = WorkflowBuilder(max_iterations=10, start_executor=start, checkpoint_storage=storage) + return builder.add_edge(start, middle).add_edge(middle, finish).build() + + +async def test_workflow_run_creates_checkpoints_via_storage_kwarg() -> None: + """Passing checkpoint_storage to run() should create workflow checkpoints (parity with core).""" + storage = InMemoryCheckpointStorage() + workflow = _build_multi_superstep_workflow() + agent = AgentFrameworkWorkflow(workflow=workflow) + + events = await _run( + agent, + {"thread_id": "thread-cp", "messages": [{"role": "user", "content": "start"}]}, + checkpoint_storage=storage, + ) + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert "RUN_FINISHED" in event_types + assert "RUN_ERROR" not in event_types + + checkpoints = await storage.list_checkpoints(workflow_name=workflow.name) + # One checkpoint per superstep boundary: at least the initial superstep plus follow-ups. + assert len(checkpoints) >= 2 + + +async def test_workflow_run_resumes_from_checkpoint_id() -> None: + """run(checkpoint_id=...) should restore persisted state and finish the workflow.""" + storage = InMemoryCheckpointStorage() + workflow = _build_multi_superstep_workflow(storage) + agent = AgentFrameworkWorkflow(workflow=workflow) + + # First run: execute to completion while checkpoints are written. + first_events = await _run( + agent, + {"thread_id": "thread-cp", "messages": [{"role": "user", "content": "start"}]}, + ) + assert "RUN_ERROR" not in [event.type for event in first_events] + + checkpoints = sorted( + await storage.list_checkpoints(workflow_name=workflow.name), + key=lambda checkpoint: checkpoint.timestamp, + ) + assert checkpoints, "expected the run to create at least one checkpoint" + # Resume from the earliest checkpoint so middle -> finish replays and re-produces output. + resume_checkpoint_id = checkpoints[0].checkpoint_id + + # Resume on the same thread (same underlying workflow instance) from the checkpoint. + resumed_events = await _run( + agent, + {"thread_id": "thread-cp", "messages": []}, + checkpoint_id=resume_checkpoint_id, + checkpoint_storage=storage, + ) + + resumed_types = [event.type for event in resumed_events] + assert "RUN_STARTED" in resumed_types + assert "RUN_FINISHED" in resumed_types + assert "RUN_ERROR" not in resumed_types + + # The resumed run should reproduce the final assistant output ("hello-processed-done"). + resumed_text = "".join( + getattr(event, "delta", "") for event in resumed_events if event.type == "TEXT_MESSAGE_CONTENT" + ) + assert "done" in resumed_text + + +async def test_workflow_run_reads_checkpoint_params_from_input_data() -> None: + """Checkpoint params smuggled through input_data should be honored (endpoint call convention).""" + storage = InMemoryCheckpointStorage() + workflow = _build_multi_superstep_workflow() + agent = AgentFrameworkWorkflow(workflow=workflow) + + events = await _run( + agent, + { + "thread_id": "thread-cp-input", + "messages": [{"role": "user", "content": "start"}], + "__ag_ui_checkpoint_storage": storage, + }, + ) + + assert "RUN_ERROR" not in [event.type for event in events] + checkpoints = await storage.list_checkpoints(workflow_name=workflow.name) + assert len(checkpoints) >= 1 + + +async def test_workflow_run_without_checkpointing_is_unchanged() -> None: + """Existing run(input_data) calls keep working unchanged when no checkpoint args are given.""" + workflow = _build_multi_superstep_workflow() + agent = AgentFrameworkWorkflow(workflow=workflow) + + events = await _run(agent, {"thread_id": "thread-plain", "messages": [{"role": "user", "content": "start"}]}) + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert "RUN_FINISHED" in event_types + assert "RUN_ERROR" not in event_types