From 8c24a423b2abad461b0f106593f2948af7f0bb35 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Thu, 9 Apr 2026 19:51:37 -0300 Subject: [PATCH] feat: support coroutine functions as invoke targets (#610) Coroutine functions passed as invoke targets were silently broken: the coroutine was never awaited and the coroutine object was passed as data in the done.invoke event. Now the async engine detects coroutine callbacks and IInvoke handlers with async run() and awaits them directly on the event loop instead of routing through run_in_executor. Also adds InvalidDefinition when async IInvoke handlers are used with the sync engine, and ignores the local specs/ directory. Signed-off-by: Fernando Macedo --- .gitignore | 3 + statemachine/invoke.py | 36 ++++++++- tests/test_invoke.py | 178 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index ba144f94..01a3b318 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ docs/sg_execution_times.rst # Temporary files tmp/ + +# Local specs +specs/ diff --git a/statemachine/invoke.py b/statemachine/invoke.py index 5d09c14d..05e43dbc 100644 --- a/statemachine/invoke.py +++ b/statemachine/invoke.py @@ -13,6 +13,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from dataclasses import field +from inspect import iscoroutinefunction from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -134,6 +135,16 @@ def _needs_wrapping(item: Any) -> bool: return False +def _has_async_run(handler: Any) -> bool: + """Check if a handler (or its wrapped inner) has an async ``run()`` method.""" + if iscoroutinefunction(getattr(handler, "run", None)): + return True + if isinstance(handler, _InvokeCallableWrapper): + inner = handler._invoke_handler + return iscoroutinefunction(getattr(inner, "run", None)) + return False + + @dataclass class InvokeContext: """Context passed to invoke handlers.""" @@ -357,6 +368,15 @@ def _spawn_one_sync(self, callback: "CallbackWrapper", **kwargs): # Use meta.func to find the original (unwrapped) handler; the callback # system wraps everything in a signature_adapter closure. handler = self._resolve_handler(callback.meta.func) + + if handler is not None and _has_async_run(handler): + from .exceptions import InvalidDefinition + + raise InvalidDefinition( + "Cannot use IInvoke with async run() on the sync engine. " + "Add an async callback or listener to enable the async engine." + ) + ctx = self._make_context(state, event_kwargs, handler=handler) invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx) @@ -448,12 +468,22 @@ async def _run_async_handler( invocation: Invocation, ): try: - loop = asyncio.get_running_loop() - if handler is not None: - # Run handler.run(ctx) in a thread executor so blocking I/O + if handler is not None and _has_async_run(handler): + # Async IInvoke: call run() and await the coroutine directly + # on the event loop (no executor needed). + result = await handler.run(ctx) + elif handler is not None: + # Sync IInvoke: run in a thread executor so blocking I/O # doesn't freeze the event loop. + loop = asyncio.get_running_loop() result = await loop.run_in_executor(None, handler.run, ctx) + elif callback._iscoro: + # Coroutine callback: await directly on the event loop. + result = await callback(ctx=ctx, machine=ctx.machine, **ctx.kwargs) else: + # Sync callback: run in a thread executor so blocking I/O + # doesn't freeze the event loop. + loop = asyncio.get_running_loop() result = await loop.run_in_executor( None, lambda: callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs) ) diff --git a/tests/test_invoke.py b/tests/test_invoke.py index 54d35249..1e420030 100644 --- a/tests/test_invoke.py +++ b/tests/test_invoke.py @@ -829,6 +829,184 @@ def test_on_cancel_before_run(self): group.on_cancel() +class TestCoroutineFunctionAsInvokeTarget: + """Coroutine functions should work as invoke targets on the async engine.""" + + async def test_coroutine_invoke_returns_awaited_result(self): + """An async function used as invoke target should be awaited and return its value.""" + from tests.conftest import SMRunner + + async def async_loader(): + return 42 + + class SM(StateChart): + loading = State(initial=True, invoke=async_loader) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + self.result = data + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert sm.result == 42 + + async def test_coroutine_invoke_error_sends_error_execution(self): + """An async invoke that raises should send error.execution.""" + from tests.conftest import SMRunner + + async def failing_loader(): + raise ValueError("async boom") + + class SM(StateChart): + loading = State(initial=True, invoke=failing_loader) + error_state = State(final=True) + error_execution = loading.to(error_state) + + def on_enter_error_state(self, error=None, **kwargs): + self.caught_error = error + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "error_state" in sm.configuration_values + assert isinstance(sm.caught_error, ValueError) + assert str(sm.caught_error) == "async boom" + + async def test_coroutine_invoke_cancelled_on_state_exit(self): + """An async invoke should be cancelled when the owning state is exited.""" + from tests.conftest import SMRunner + + cancel_observed = [] + + async def slow_loader(): + import asyncio + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + cancel_observed.append(True) + raise + return "should not reach" + + class SM(StateChart): + loading = State(initial=True, invoke=slow_loader) + stopped = State(final=True) + cancel = loading.to(stopped) + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.05) + await sm_runner.send(sm, "cancel") + await sm_runner.sleep(0.05) + + assert "stopped" in sm.configuration_values + + +class TestAsyncIInvokeInstance: + """IInvoke instances with async def run() should be awaited on the async engine.""" + + async def test_async_iinvoke_instance(self): + """An IInvoke instance with async run() should be awaited.""" + from tests.conftest import SMRunner + + class AsyncHandler: + async def run(self, ctx): + return "async_result" + + handler = AsyncHandler() + + class SM(StateChart): + loading = State(initial=True, invoke=handler) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + self.result = data + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert sm.result == "async_result" + + +class TestAsyncIInvokeClass: + """IInvoke classes with async def run() should be instantiated and awaited.""" + + async def test_async_iinvoke_class(self): + """An IInvoke class with async run() should be instantiated and its run() awaited.""" + from tests.conftest import SMRunner + + class AsyncHandler: + async def run(self, ctx): + return "class_async_result" + + class SM(StateChart): + loading = State(initial=True, invoke=AsyncHandler) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + self.result = data + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert sm.result == "class_async_result" + + +class TestAsyncIInvokeOnSyncEngine: + """IInvoke with async run() on the sync engine should raise InvalidDefinition.""" + + def test_async_iinvoke_instance_on_sync_engine_raises(self): + """An IInvoke instance with async run() should fail clearly on the sync engine.""" + import pytest + from statemachine.exceptions import InvalidDefinition + + class AsyncHandler: + async def run(self, ctx): + return "unreachable" + + handler = AsyncHandler() + + class SM(StateChart): + loading = State(initial=True, invoke=handler) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + with pytest.raises(InvalidDefinition): + SM() + + def test_async_iinvoke_class_on_sync_engine_raises(self): + """An IInvoke class with async run() should fail clearly on the sync engine.""" + import pytest + from statemachine.exceptions import InvalidDefinition + + class AsyncHandler: + async def run(self, ctx): + return "unreachable" + + class SM(StateChart): + loading = State(initial=True, invoke=AsyncHandler) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + with pytest.raises(InvalidDefinition): + SM() + + class TestDoneInvokeEventFactory: """done_invoke_ prefix works with both TransitionList and Event."""