diff --git a/burr/core/__init__.py b/burr/core/__init__.py index c4da5a48e..aa2f75a46 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from burr.core.action import Action, Condition, Result, action, default, expr, when +from burr.core.action import Action, Condition, Result, action, default, expr, type_eraser, when from burr.core.application import ( Application, ApplicationBuilder, @@ -35,6 +35,7 @@ "Condition", "default", "expr", + "type_eraser", "Result", "State", "when", diff --git a/burr/core/action.py b/burr/core/action.py index b2e7c16d3..b88a051fb 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -100,8 +100,75 @@ def visit_Subscript(self, node): ) +from functools import wraps + from burr.core.typing import ActionSchema + +def type_eraser(func: Callable[..., Any]) -> Callable[..., Any]: + """Decorator for ``run``, ``stream_run``, and ``run_and_update`` overrides + that declare explicit parameters instead of ``**run_kwargs``. + + Applying this decorator prevents mypy ``[override]`` errors caused by + narrowing the base-class signature (which uses ``**run_kwargs``). + + Example usage:: + + from burr.core import Action, State, type_eraser + + class Counter(Action): + @property + def reads(self) -> list[str]: + return ["counter"] + + @type_eraser + def run(self, state: State, increment_by: int) -> dict: + return {"counter": state["counter"] + increment_by} + + @property + def writes(self) -> list[str]: + return ["counter"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["increment_by"] + """ + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + return await func(*args, **kwargs) + + return async_wrapper + + if inspect.isasyncgenfunction(func): + + @wraps(func) + async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: + async for item in func(*args, **kwargs): + yield item + + return async_gen_wrapper + + if inspect.isgeneratorfunction(func): + + @wraps(func) + def gen_wrapper(*args: Any, **kwargs: Any) -> Any: + yield from func(*args, **kwargs) + + return gen_wrapper + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) + + return wrapper + + # This is here to make accessing the pydantic actions easier # we just attach them to action so you can call `@action.pyddantic...` # The IDE will like it better and thus be able to auto-complete/type-check diff --git a/tests/core/test_action.py b/tests/core/test_action.py index 83fecf3b3..2366b1d04 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -38,6 +38,7 @@ default, derive_inputs_from_fn, streaming_action, + type_eraser, ) @@ -878,3 +879,143 @@ def good_action(state: MyState): from burr.core.action import create_action create_action(good_action, name="test") + + +class TestTypeEraser: + def test_sync_run(self): + class MyAction(Action): + @property + def reads(self) -> list[str]: + return ["counter"] + + @type_eraser + def run(self, state: State, increment_by: int) -> dict: + return {"counter": state["counter"] + increment_by} + + @property + def writes(self) -> list[str]: + return ["counter"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["increment_by"] + + a = MyAction() + result = a.run(State({"counter": 0}), increment_by=5) + assert result == {"counter": 5} + assert a.is_async() is False + + def test_async_run(self): + class MyAsyncAction(Action): + @property + def reads(self) -> list[str]: + return ["counter"] + + @type_eraser + async def run(self, state: State, increment_by: int) -> dict: + return {"counter": state["counter"] + increment_by} + + @property + def writes(self) -> list[str]: + return ["counter"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["increment_by"] + + a = MyAsyncAction() + assert a.is_async() is True + result = asyncio.run(a.run(State({"counter": 0}), increment_by=5)) + assert result == {"counter": 5} + + def test_sync_stream_run(self): + class MyStreamingAction(StreamingAction): + @property + def reads(self) -> list[str]: + return ["items"] + + @type_eraser + def stream_run(self, state: State, prefix: str) -> Generator[dict, None, None]: + for item in state["items"]: + yield {"val": f"{prefix}_{item}"} + + @property + def writes(self) -> list[str]: + return ["result"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["prefix"] + + a = MyStreamingAction() + results = list(a.stream_run(State({"items": ["a", "b"]}), prefix="x")) + assert results == [{"val": "x_a"}, {"val": "x_b"}] + + def test_async_stream_run(self): + class MyAsyncStreamingAction(AsyncStreamingAction): + @property + def reads(self) -> list[str]: + return ["items"] + + @type_eraser + async def stream_run(self, state: State, prefix: str) -> AsyncGenerator[dict, None]: + for item in state["items"]: + yield {"val": f"{prefix}_{item}"} + + @property + def writes(self) -> list[str]: + return ["result"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["prefix"] + + a = MyAsyncStreamingAction() + assert a.is_async() is True + + async def collect(): + return [item async for item in a.stream_run(State({"items": ["a", "b"]}), prefix="x")] + + results = asyncio.run(collect()) + assert results == [{"val": "x_a"}, {"val": "x_b"}] + + def test_preserves_wrapped_name(self): + class MyAction(Action): + @property + def reads(self) -> list[str]: + return [] + + @type_eraser + def run(self, state: State, custom_param: str) -> dict: + return {} + + @property + def writes(self) -> list[str]: + return [] + + def update(self, result: dict, state: State) -> State: + return state + + @property + def inputs(self) -> list[str]: + return [] + + assert MyAction().run.__name__ == "run" + assert MyAction().run.__wrapped__.__name__ == "run" + + def test_exported_from_burr_core(self): + from burr.core import type_eraser as te + + assert te is type_eraser