Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,6 @@ docs/sg_execution_times.rst

# Temporary files
tmp/

# Local specs
specs/
36 changes: 33 additions & 3 deletions statemachine/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from dataclasses import field
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -134,6 +135,16 @@ def _needs_wrapping(item: Any) -> bool:
return False


def _has_async_run(handler: Any) -> bool:
"""Check if a handler (or its wrapped inner) has an async ``run()`` method."""
if iscoroutinefunction(getattr(handler, "run", None)):
return True
if isinstance(handler, _InvokeCallableWrapper):
inner = handler._invoke_handler
return iscoroutinefunction(getattr(inner, "run", None))
return False


@dataclass
class InvokeContext:
"""Context passed to invoke handlers."""
Expand Down Expand Up @@ -357,6 +368,15 @@ def _spawn_one_sync(self, callback: "CallbackWrapper", **kwargs):
# Use meta.func to find the original (unwrapped) handler; the callback
# system wraps everything in a signature_adapter closure.
handler = self._resolve_handler(callback.meta.func)

if handler is not None and _has_async_run(handler):
from .exceptions import InvalidDefinition

raise InvalidDefinition(
"Cannot use IInvoke with async run() on the sync engine. "
"Add an async callback or listener to enable the async engine."
)

ctx = self._make_context(state, event_kwargs, handler=handler)
invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx)

Expand Down Expand Up @@ -448,12 +468,22 @@ async def _run_async_handler(
invocation: Invocation,
):
try:
loop = asyncio.get_running_loop()
if handler is not None:
# Run handler.run(ctx) in a thread executor so blocking I/O
if handler is not None and _has_async_run(handler):
# Async IInvoke: call run() and await the coroutine directly
# on the event loop (no executor needed).
result = await handler.run(ctx)
elif handler is not None:
# Sync IInvoke: run in a thread executor so blocking I/O
# doesn't freeze the event loop.
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, handler.run, ctx)
elif callback._iscoro:
# Coroutine callback: await directly on the event loop.
result = await callback(ctx=ctx, machine=ctx.machine, **ctx.kwargs)
else:
# Sync callback: run in a thread executor so blocking I/O
# doesn't freeze the event loop.
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None, lambda: callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs)
)
Expand Down
178 changes: 178 additions & 0 deletions tests/test_invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,184 @@
group.on_cancel()


class TestCoroutineFunctionAsInvokeTarget:
"""Coroutine functions should work as invoke targets on the async engine."""

async def test_coroutine_invoke_returns_awaited_result(self):
"""An async function used as invoke target should be awaited and return its value."""
from tests.conftest import SMRunner

async def async_loader():

Check warning on line 839 in tests/test_invoke.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=fgmacedo_python-statemachine&issues=AZ10cwZuQWQRkjYcLeB4&open=AZ10cwZuQWQRkjYcLeB4&pullRequest=611
return 42

class SM(StateChart):
loading = State(initial=True, invoke=async_loader)
ready = State(final=True)
done_invoke_loading = loading.to(ready)

def on_enter_ready(self, data=None, **kwargs):
self.result = data

sm_runner = SMRunner(is_async=True)
sm = await sm_runner.start(SM)
await sm_runner.sleep(0.2)
await sm_runner.processing_loop(sm)

assert "ready" in sm.configuration_values
assert sm.result == 42

async def test_coroutine_invoke_error_sends_error_execution(self):
"""An async invoke that raises should send error.execution."""
from tests.conftest import SMRunner

async def failing_loader():
raise ValueError("async boom")

class SM(StateChart):
loading = State(initial=True, invoke=failing_loader)
error_state = State(final=True)
error_execution = loading.to(error_state)

def on_enter_error_state(self, error=None, **kwargs):
self.caught_error = error

sm_runner = SMRunner(is_async=True)
sm = await sm_runner.start(SM)
await sm_runner.sleep(0.2)
await sm_runner.processing_loop(sm)

assert "error_state" in sm.configuration_values
assert isinstance(sm.caught_error, ValueError)
assert str(sm.caught_error) == "async boom"

async def test_coroutine_invoke_cancelled_on_state_exit(self):
"""An async invoke should be cancelled when the owning state is exited."""
from tests.conftest import SMRunner

cancel_observed = []

async def slow_loader():
import asyncio

try:
await asyncio.sleep(10)
except asyncio.CancelledError:
cancel_observed.append(True)
raise
return "should not reach"

class SM(StateChart):
loading = State(initial=True, invoke=slow_loader)
stopped = State(final=True)
cancel = loading.to(stopped)

sm_runner = SMRunner(is_async=True)
sm = await sm_runner.start(SM)
await sm_runner.sleep(0.05)
await sm_runner.send(sm, "cancel")
await sm_runner.sleep(0.05)

assert "stopped" in sm.configuration_values


class TestAsyncIInvokeInstance:
"""IInvoke instances with async def run() should be awaited on the async engine."""

async def test_async_iinvoke_instance(self):
"""An IInvoke instance with async run() should be awaited."""
from tests.conftest import SMRunner

class AsyncHandler:
async def run(self, ctx):

Check warning on line 920 in tests/test_invoke.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=fgmacedo_python-statemachine&issues=AZ10cwZuQWQRkjYcLeB5&open=AZ10cwZuQWQRkjYcLeB5&pullRequest=611
return "async_result"

handler = AsyncHandler()

class SM(StateChart):
loading = State(initial=True, invoke=handler)
ready = State(final=True)
done_invoke_loading = loading.to(ready)

def on_enter_ready(self, data=None, **kwargs):
self.result = data

sm_runner = SMRunner(is_async=True)
sm = await sm_runner.start(SM)
await sm_runner.sleep(0.2)
await sm_runner.processing_loop(sm)

assert "ready" in sm.configuration_values
assert sm.result == "async_result"


class TestAsyncIInvokeClass:
"""IInvoke classes with async def run() should be instantiated and awaited."""

async def test_async_iinvoke_class(self):
"""An IInvoke class with async run() should be instantiated and its run() awaited."""
from tests.conftest import SMRunner

class AsyncHandler:
async def run(self, ctx):

Check warning on line 950 in tests/test_invoke.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=fgmacedo_python-statemachine&issues=AZ10cwZuQWQRkjYcLeB6&open=AZ10cwZuQWQRkjYcLeB6&pullRequest=611
return "class_async_result"

class SM(StateChart):
loading = State(initial=True, invoke=AsyncHandler)
ready = State(final=True)
done_invoke_loading = loading.to(ready)

def on_enter_ready(self, data=None, **kwargs):
self.result = data

sm_runner = SMRunner(is_async=True)
sm = await sm_runner.start(SM)
await sm_runner.sleep(0.2)
await sm_runner.processing_loop(sm)

assert "ready" in sm.configuration_values
assert sm.result == "class_async_result"


class TestAsyncIInvokeOnSyncEngine:
"""IInvoke with async run() on the sync engine should raise InvalidDefinition."""

def test_async_iinvoke_instance_on_sync_engine_raises(self):
"""An IInvoke instance with async run() should fail clearly on the sync engine."""
import pytest
from statemachine.exceptions import InvalidDefinition

class AsyncHandler:
async def run(self, ctx):

Check warning on line 979 in tests/test_invoke.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=fgmacedo_python-statemachine&issues=AZ10cwZuQWQRkjYcLeB7&open=AZ10cwZuQWQRkjYcLeB7&pullRequest=611
return "unreachable"

handler = AsyncHandler()

class SM(StateChart):
loading = State(initial=True, invoke=handler)
ready = State(final=True)
done_invoke_loading = loading.to(ready)

with pytest.raises(InvalidDefinition):
SM()

def test_async_iinvoke_class_on_sync_engine_raises(self):
"""An IInvoke class with async run() should fail clearly on the sync engine."""
import pytest
from statemachine.exceptions import InvalidDefinition

class AsyncHandler:
async def run(self, ctx):

Check warning on line 998 in tests/test_invoke.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=fgmacedo_python-statemachine&issues=AZ10cwZuQWQRkjYcLeB8&open=AZ10cwZuQWQRkjYcLeB8&pullRequest=611
return "unreachable"

class SM(StateChart):
loading = State(initial=True, invoke=AsyncHandler)
ready = State(final=True)
done_invoke_loading = loading.to(ready)

with pytest.raises(InvalidDefinition):
SM()


class TestDoneInvokeEventFactory:
"""done_invoke_ prefix works with both TransitionList and Event."""

Expand Down
Loading