Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion burr/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from burr.core.action import Action, Condition, Result, action, default, expr, when
from burr.core.action import Action, Condition, Result, action, default, expr, type_eraser, when
from burr.core.application import (
Application,
ApplicationBuilder,
Expand All @@ -35,6 +35,7 @@
"Condition",
"default",
"expr",
"type_eraser",
"Result",
"State",
"when",
Expand Down
67 changes: 67 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,75 @@ def visit_Subscript(self, node):
)


from functools import wraps

from burr.core.typing import ActionSchema


def type_eraser(func: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator for ``run``, ``stream_run``, and ``run_and_update`` overrides
that declare explicit parameters instead of ``**run_kwargs``.

Applying this decorator prevents mypy ``[override]`` errors caused by
narrowing the base-class signature (which uses ``**run_kwargs``).

Example usage::

from burr.core import Action, State, type_eraser

class Counter(Action):
@property
def reads(self) -> list[str]:
return ["counter"]

@type_eraser
def run(self, state: State, increment_by: int) -> dict:
return {"counter": state["counter"] + increment_by}

@property
def writes(self) -> list[str]:
return ["counter"]

def update(self, result: dict, state: State) -> State:
return state.update(**result)

@property
def inputs(self) -> list[str]:
return ["increment_by"]
"""

if inspect.iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
return await func(*args, **kwargs)

return async_wrapper

if inspect.isasyncgenfunction(func):

@wraps(func)
async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
async for item in func(*args, **kwargs):
yield item

return async_gen_wrapper

if inspect.isgeneratorfunction(func):

@wraps(func)
def gen_wrapper(*args: Any, **kwargs: Any) -> Any:
yield from func(*args, **kwargs)

return gen_wrapper

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

return wrapper


# This is here to make accessing the pydantic actions easier
# we just attach them to action so you can call `@action.pyddantic...`
# The IDE will like it better and thus be able to auto-complete/type-check
Expand Down
141 changes: 141 additions & 0 deletions tests/core/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
default,
derive_inputs_from_fn,
streaming_action,
type_eraser,
)


Expand Down Expand Up @@ -878,3 +879,143 @@ def good_action(state: MyState):
from burr.core.action import create_action

create_action(good_action, name="test")


class TestTypeEraser:
def test_sync_run(self):
class MyAction(Action):
@property
def reads(self) -> list[str]:
return ["counter"]

@type_eraser
def run(self, state: State, increment_by: int) -> dict:
return {"counter": state["counter"] + increment_by}

@property
def writes(self) -> list[str]:
return ["counter"]

def update(self, result: dict, state: State) -> State:
return state.update(**result)

@property
def inputs(self) -> list[str]:
return ["increment_by"]

a = MyAction()
result = a.run(State({"counter": 0}), increment_by=5)
assert result == {"counter": 5}
assert a.is_async() is False

def test_async_run(self):
class MyAsyncAction(Action):
@property
def reads(self) -> list[str]:
return ["counter"]

@type_eraser
async def run(self, state: State, increment_by: int) -> dict:
return {"counter": state["counter"] + increment_by}

@property
def writes(self) -> list[str]:
return ["counter"]

def update(self, result: dict, state: State) -> State:
return state.update(**result)

@property
def inputs(self) -> list[str]:
return ["increment_by"]

a = MyAsyncAction()
assert a.is_async() is True
result = asyncio.run(a.run(State({"counter": 0}), increment_by=5))
assert result == {"counter": 5}

def test_sync_stream_run(self):
class MyStreamingAction(StreamingAction):
@property
def reads(self) -> list[str]:
return ["items"]

@type_eraser
def stream_run(self, state: State, prefix: str) -> Generator[dict, None, None]:
for item in state["items"]:
yield {"val": f"{prefix}_{item}"}

@property
def writes(self) -> list[str]:
return ["result"]

def update(self, result: dict, state: State) -> State:
return state.update(**result)

@property
def inputs(self) -> list[str]:
return ["prefix"]

a = MyStreamingAction()
results = list(a.stream_run(State({"items": ["a", "b"]}), prefix="x"))
assert results == [{"val": "x_a"}, {"val": "x_b"}]

def test_async_stream_run(self):
class MyAsyncStreamingAction(AsyncStreamingAction):
@property
def reads(self) -> list[str]:
return ["items"]

@type_eraser
async def stream_run(self, state: State, prefix: str) -> AsyncGenerator[dict, None]:
for item in state["items"]:
yield {"val": f"{prefix}_{item}"}

@property
def writes(self) -> list[str]:
return ["result"]

def update(self, result: dict, state: State) -> State:
return state.update(**result)

@property
def inputs(self) -> list[str]:
return ["prefix"]

a = MyAsyncStreamingAction()
assert a.is_async() is True

async def collect():
return [item async for item in a.stream_run(State({"items": ["a", "b"]}), prefix="x")]

results = asyncio.run(collect())
assert results == [{"val": "x_a"}, {"val": "x_b"}]

def test_preserves_wrapped_name(self):
class MyAction(Action):
@property
def reads(self) -> list[str]:
return []

@type_eraser
def run(self, state: State, custom_param: str) -> dict:
return {}

@property
def writes(self) -> list[str]:
return []

def update(self, result: dict, state: State) -> State:
return state

@property
def inputs(self) -> list[str]:
return []

assert MyAction().run.__name__ == "run"
assert MyAction().run.__wrapped__.__name__ == "run"

def test_exported_from_burr_core(self):
from burr.core import type_eraser as te

assert te is type_eraser