From 8ff1dfc67190a2f5ae475834f3e9a279ba29f2a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Wed, 18 Mar 2026 14:25:29 -0300 Subject: [PATCH 1/3] core: add flexible_api decorator to fix mypy override errors on class-based actions Adds a `flexible_api` decorator that users can apply to `run`, `stream_run`, or `run_and_update` overrides that use explicit parameters instead of `**run_kwargs`. This prevents mypy [override] errors caused by narrowing the base-class signature. Closes #457 --- burr/core/__init__.py | 3 ++- burr/core/action.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/burr/core/__init__.py b/burr/core/__init__.py index c4da5a48e..cfa00f034 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from burr.core.action import Action, Condition, Result, action, default, expr, when +from burr.core.action import Action, Condition, Result, action, default, expr, flexible_api, when from burr.core.application import ( Application, ApplicationBuilder, @@ -35,6 +35,7 @@ "Condition", "default", "expr", + "flexible_api", "Result", "State", "when", diff --git a/burr/core/action.py b/burr/core/action.py index b2e7c16d3..12d1644f4 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -100,8 +100,49 @@ def visit_Subscript(self, node): ) +from functools import wraps + from burr.core.typing import ActionSchema + +def flexible_api(func: Callable[..., Any]) -> Callable[..., Any]: + """Decorator for ``run``, ``stream_run``, and ``run_and_update`` overrides + that declare explicit parameters instead of ``**run_kwargs``. + + Applying this decorator prevents mypy ``[override]`` errors caused by + narrowing the base-class signature (which uses ``**run_kwargs``). + + Example usage:: + + from burr.core import Action, State, flexible_api + + class Counter(Action): + @property + def reads(self) -> list[str]: + return ["counter"] + + @flexible_api + def run(self, state: State, increment_by: int) -> dict: + return {"counter": state["counter"] + increment_by} + + @property + def writes(self) -> list[str]: + return ["counter"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["increment_by"] + """ + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) + + return wrapper + # This is here to make accessing the pydantic actions easier # we just attach them to action so you can call `@action.pyddantic...` # The IDE will like it better and thus be able to auto-complete/type-check From b4ab46195a388b71c4c122f5a481ca1921763f99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Wed, 18 Mar 2026 14:33:11 -0300 Subject: [PATCH 2/3] style: apply black formatting --- burr/core/action.py | 1 + 1 file changed, 1 insertion(+) diff --git a/burr/core/action.py b/burr/core/action.py index 12d1644f4..a2c3714e3 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -143,6 +143,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper + # This is here to make accessing the pydantic actions easier # we just attach them to action so you can call `@action.pyddantic...` # The IDE will like it better and thus be able to auto-complete/type-check From 88f669ca06d5838b1313a24a09ae9d6041e94daa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Tue, 24 Mar 2026 10:31:04 -0300 Subject: [PATCH 3/3] Rename flexible_api to type_eraser and fix async support Address review feedback: rename decorator per maintainer suggestion. Fix critical bug where @wraps on async/generator functions broke is_async() detection. Add tests for all function types. --- burr/core/__init__.py | 4 +- burr/core/action.py | 31 ++++++++- tests/core/test_action.py | 141 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+), 5 deletions(-) diff --git a/burr/core/__init__.py b/burr/core/__init__.py index cfa00f034..aa2f75a46 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from burr.core.action import Action, Condition, Result, action, default, expr, flexible_api, when +from burr.core.action import Action, Condition, Result, action, default, expr, type_eraser, when from burr.core.application import ( Application, ApplicationBuilder, @@ -35,7 +35,7 @@ "Condition", "default", "expr", - "flexible_api", + "type_eraser", "Result", "State", "when", diff --git a/burr/core/action.py b/burr/core/action.py index a2c3714e3..b88a051fb 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -105,7 +105,7 @@ def visit_Subscript(self, node): from burr.core.typing import ActionSchema -def flexible_api(func: Callable[..., Any]) -> Callable[..., Any]: +def type_eraser(func: Callable[..., Any]) -> Callable[..., Any]: """Decorator for ``run``, ``stream_run``, and ``run_and_update`` overrides that declare explicit parameters instead of ``**run_kwargs``. @@ -114,14 +114,14 @@ def flexible_api(func: Callable[..., Any]) -> Callable[..., Any]: Example usage:: - from burr.core import Action, State, flexible_api + from burr.core import Action, State, type_eraser class Counter(Action): @property def reads(self) -> list[str]: return ["counter"] - @flexible_api + @type_eraser def run(self, state: State, increment_by: int) -> dict: return {"counter": state["counter"] + increment_by} @@ -137,6 +137,31 @@ def inputs(self) -> list[str]: return ["increment_by"] """ + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + return await func(*args, **kwargs) + + return async_wrapper + + if inspect.isasyncgenfunction(func): + + @wraps(func) + async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: + async for item in func(*args, **kwargs): + yield item + + return async_gen_wrapper + + if inspect.isgeneratorfunction(func): + + @wraps(func) + def gen_wrapper(*args: Any, **kwargs: Any) -> Any: + yield from func(*args, **kwargs) + + return gen_wrapper + @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) diff --git a/tests/core/test_action.py b/tests/core/test_action.py index 83fecf3b3..2366b1d04 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -38,6 +38,7 @@ default, derive_inputs_from_fn, streaming_action, + type_eraser, ) @@ -878,3 +879,143 @@ def good_action(state: MyState): from burr.core.action import create_action create_action(good_action, name="test") + + +class TestTypeEraser: + def test_sync_run(self): + class MyAction(Action): + @property + def reads(self) -> list[str]: + return ["counter"] + + @type_eraser + def run(self, state: State, increment_by: int) -> dict: + return {"counter": state["counter"] + increment_by} + + @property + def writes(self) -> list[str]: + return ["counter"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["increment_by"] + + a = MyAction() + result = a.run(State({"counter": 0}), increment_by=5) + assert result == {"counter": 5} + assert a.is_async() is False + + def test_async_run(self): + class MyAsyncAction(Action): + @property + def reads(self) -> list[str]: + return ["counter"] + + @type_eraser + async def run(self, state: State, increment_by: int) -> dict: + return {"counter": state["counter"] + increment_by} + + @property + def writes(self) -> list[str]: + return ["counter"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["increment_by"] + + a = MyAsyncAction() + assert a.is_async() is True + result = asyncio.run(a.run(State({"counter": 0}), increment_by=5)) + assert result == {"counter": 5} + + def test_sync_stream_run(self): + class MyStreamingAction(StreamingAction): + @property + def reads(self) -> list[str]: + return ["items"] + + @type_eraser + def stream_run(self, state: State, prefix: str) -> Generator[dict, None, None]: + for item in state["items"]: + yield {"val": f"{prefix}_{item}"} + + @property + def writes(self) -> list[str]: + return ["result"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["prefix"] + + a = MyStreamingAction() + results = list(a.stream_run(State({"items": ["a", "b"]}), prefix="x")) + assert results == [{"val": "x_a"}, {"val": "x_b"}] + + def test_async_stream_run(self): + class MyAsyncStreamingAction(AsyncStreamingAction): + @property + def reads(self) -> list[str]: + return ["items"] + + @type_eraser + async def stream_run(self, state: State, prefix: str) -> AsyncGenerator[dict, None]: + for item in state["items"]: + yield {"val": f"{prefix}_{item}"} + + @property + def writes(self) -> list[str]: + return ["result"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + @property + def inputs(self) -> list[str]: + return ["prefix"] + + a = MyAsyncStreamingAction() + assert a.is_async() is True + + async def collect(): + return [item async for item in a.stream_run(State({"items": ["a", "b"]}), prefix="x")] + + results = asyncio.run(collect()) + assert results == [{"val": "x_a"}, {"val": "x_b"}] + + def test_preserves_wrapped_name(self): + class MyAction(Action): + @property + def reads(self) -> list[str]: + return [] + + @type_eraser + def run(self, state: State, custom_param: str) -> dict: + return {} + + @property + def writes(self) -> list[str]: + return [] + + def update(self, result: dict, state: State) -> State: + return state + + @property + def inputs(self) -> list[str]: + return [] + + assert MyAction().run.__name__ == "run" + assert MyAction().run.__wrapped__.__name__ == "run" + + def test_exported_from_burr_core(self): + from burr.core import type_eraser as te + + assert te is type_eraser