diff --git a/.gitignore b/.gitignore index ba144f94..01a3b318 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ docs/sg_execution_times.rst # Temporary files tmp/ + +# Local specs +specs/ diff --git a/statemachine/invoke.py b/statemachine/invoke.py index 5d09c14d..05e43dbc 100644 --- a/statemachine/invoke.py +++ b/statemachine/invoke.py @@ -13,6 +13,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from dataclasses import field +from inspect import iscoroutinefunction from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -134,6 +135,16 @@ def _needs_wrapping(item: Any) -> bool: return False +def _has_async_run(handler: Any) -> bool: + """Check if a handler (or its wrapped inner) has an async ``run()`` method.""" + if iscoroutinefunction(getattr(handler, "run", None)): + return True + if isinstance(handler, _InvokeCallableWrapper): + inner = handler._invoke_handler + return iscoroutinefunction(getattr(inner, "run", None)) + return False + + @dataclass class InvokeContext: """Context passed to invoke handlers.""" @@ -357,6 +368,15 @@ def _spawn_one_sync(self, callback: "CallbackWrapper", **kwargs): # Use meta.func to find the original (unwrapped) handler; the callback # system wraps everything in a signature_adapter closure. handler = self._resolve_handler(callback.meta.func) + + if handler is not None and _has_async_run(handler): + from .exceptions import InvalidDefinition + + raise InvalidDefinition( + "Cannot use IInvoke with async run() on the sync engine. " + "Add an async callback or listener to enable the async engine." + ) + ctx = self._make_context(state, event_kwargs, handler=handler) invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx) @@ -448,12 +468,22 @@ async def _run_async_handler( invocation: Invocation, ): try: - loop = asyncio.get_running_loop() - if handler is not None: - # Run handler.run(ctx) in a thread executor so blocking I/O + if handler is not None and _has_async_run(handler): + # Async IInvoke: call run() and await the coroutine directly + # on the event loop (no executor needed). + result = await handler.run(ctx) + elif handler is not None: + # Sync IInvoke: run in a thread executor so blocking I/O # doesn't freeze the event loop. + loop = asyncio.get_running_loop() result = await loop.run_in_executor(None, handler.run, ctx) + elif callback._iscoro: + # Coroutine callback: await directly on the event loop. + result = await callback(ctx=ctx, machine=ctx.machine, **ctx.kwargs) else: + # Sync callback: run in a thread executor so blocking I/O + # doesn't freeze the event loop. + loop = asyncio.get_running_loop() result = await loop.run_in_executor( None, lambda: callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs) ) diff --git a/tests/test_invoke.py b/tests/test_invoke.py index 54d35249..1e420030 100644 --- a/tests/test_invoke.py +++ b/tests/test_invoke.py @@ -829,6 +829,184 @@ def test_on_cancel_before_run(self): group.on_cancel() +class TestCoroutineFunctionAsInvokeTarget: + """Coroutine functions should work as invoke targets on the async engine.""" + + async def test_coroutine_invoke_returns_awaited_result(self): + """An async function used as invoke target should be awaited and return its value.""" + from tests.conftest import SMRunner + + async def async_loader(): + return 42 + + class SM(StateChart): + loading = State(initial=True, invoke=async_loader) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + self.result = data + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert sm.result == 42 + + async def test_coroutine_invoke_error_sends_error_execution(self): + """An async invoke that raises should send error.execution.""" + from tests.conftest import SMRunner + + async def failing_loader(): + raise ValueError("async boom") + + class SM(StateChart): + loading = State(initial=True, invoke=failing_loader) + error_state = State(final=True) + error_execution = loading.to(error_state) + + def on_enter_error_state(self, error=None, **kwargs): + self.caught_error = error + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "error_state" in sm.configuration_values + assert isinstance(sm.caught_error, ValueError) + assert str(sm.caught_error) == "async boom" + + async def test_coroutine_invoke_cancelled_on_state_exit(self): + """An async invoke should be cancelled when the owning state is exited.""" + from tests.conftest import SMRunner + + cancel_observed = [] + + async def slow_loader(): + import asyncio + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + cancel_observed.append(True) + raise + return "should not reach" + + class SM(StateChart): + loading = State(initial=True, invoke=slow_loader) + stopped = State(final=True) + cancel = loading.to(stopped) + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.05) + await sm_runner.send(sm, "cancel") + await sm_runner.sleep(0.05) + + assert "stopped" in sm.configuration_values + + +class TestAsyncIInvokeInstance: + """IInvoke instances with async def run() should be awaited on the async engine.""" + + async def test_async_iinvoke_instance(self): + """An IInvoke instance with async run() should be awaited.""" + from tests.conftest import SMRunner + + class AsyncHandler: + async def run(self, ctx): + return "async_result" + + handler = AsyncHandler() + + class SM(StateChart): + loading = State(initial=True, invoke=handler) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + self.result = data + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert sm.result == "async_result" + + +class TestAsyncIInvokeClass: + """IInvoke classes with async def run() should be instantiated and awaited.""" + + async def test_async_iinvoke_class(self): + """An IInvoke class with async run() should be instantiated and its run() awaited.""" + from tests.conftest import SMRunner + + class AsyncHandler: + async def run(self, ctx): + return "class_async_result" + + class SM(StateChart): + loading = State(initial=True, invoke=AsyncHandler) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + def on_enter_ready(self, data=None, **kwargs): + self.result = data + + sm_runner = SMRunner(is_async=True) + sm = await sm_runner.start(SM) + await sm_runner.sleep(0.2) + await sm_runner.processing_loop(sm) + + assert "ready" in sm.configuration_values + assert sm.result == "class_async_result" + + +class TestAsyncIInvokeOnSyncEngine: + """IInvoke with async run() on the sync engine should raise InvalidDefinition.""" + + def test_async_iinvoke_instance_on_sync_engine_raises(self): + """An IInvoke instance with async run() should fail clearly on the sync engine.""" + import pytest + from statemachine.exceptions import InvalidDefinition + + class AsyncHandler: + async def run(self, ctx): + return "unreachable" + + handler = AsyncHandler() + + class SM(StateChart): + loading = State(initial=True, invoke=handler) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + with pytest.raises(InvalidDefinition): + SM() + + def test_async_iinvoke_class_on_sync_engine_raises(self): + """An IInvoke class with async run() should fail clearly on the sync engine.""" + import pytest + from statemachine.exceptions import InvalidDefinition + + class AsyncHandler: + async def run(self, ctx): + return "unreachable" + + class SM(StateChart): + loading = State(initial=True, invoke=AsyncHandler) + ready = State(final=True) + done_invoke_loading = loading.to(ready) + + with pytest.raises(InvalidDefinition): + SM() + + class TestDoneInvokeEventFactory: """done_invoke_ prefix works with both TransitionList and Event."""