diff --git a/burr/core/application.py b/burr/core/application.py index dc8067c4b..296faf36a 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -338,31 +338,41 @@ def _run_single_step_streaming_action( result = None state_update = None count = 0 - for item in generator: - if not isinstance(item, tuple): - # TODO -- consider adding support for just returning a result. - raise ValueError( - f"Action {action.name} must yield a tuple of (result, state_update). " - f"For all non-final results (intermediate)," - f"the state update must be None" - ) - result, state_update = item - count += 1 + try: + for item in generator: + if not isinstance(item, tuple): + # TODO -- consider adding support for just returning a result. + raise ValueError( + f"Action {action.name} must yield a tuple of (result, state_update). " + f"For all non-final results (intermediate)," + f"the state update must be None" + ) + result, state_update = item + if state_update is None: + count += 1 + if first_stream_start_time is None: + first_stream_start_time = system.now() + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_item", + item=result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + yield result, None + except Exception as e: if state_update is None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - lifecycle_adapters.call_all_lifecycle_hooks_sync( - "post_stream_item", - item=result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - yield result, None + raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with final state from generator cleanup. Original error: %s", + action.name, type(e).__name__, count, e, + exc_info=True, + ) if state_update is None: raise ValueError( @@ -391,31 +401,42 @@ async def _arun_single_step_streaming_action( result = None state_update = None count = 0 - async for item in generator: - if not isinstance(item, tuple): - # TODO -- consider adding support for just returning a result. - raise ValueError( - f"Action {action.name} must yield a tuple of (result, state_update). " - f"For all non-final results (intermediate)," - f"the state update must be None" - ) - result, state_update = item + try: + async for item in generator: + if not isinstance(item, tuple): + # TODO -- consider adding support for just returning a result. + raise ValueError( + f"Action {action.name} must yield a tuple of (result, state_update). " + f"For all non-final results (intermediate)," + f"the state update must be None" + ) + result, state_update = item + if state_update is None: + count += 1 + if first_stream_start_time is None: + first_stream_start_time = system.now() + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_item", + item=result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + yield result, None + except Exception as e: if state_update is None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( - "post_stream_item", - item=result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield result, None + raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with final state from generator cleanup. Original error: %s", + action.name, type(e).__name__, count, e, + exc_info=True, + ) + if state_update is None: raise ValueError( f"Action {action.name} did not return a state update. For async actions, the last yield " @@ -450,28 +471,39 @@ def _run_multi_step_streaming_action( result = None first_stream_start_time = None count = 0 - for item in generator: - # We want to peek ahead so we can return the last one - # This is slightly eager, but only in the case in which we - # are using a multi-step streaming action - next_result = result - result = item - if next_result is not None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - lifecycle_adapters.call_all_lifecycle_hooks_sync( - "post_stream_item", - item=next_result, - item_index=count, - stream_initialize_time=stream_initialize_time, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield next_result, None + try: + for item in generator: + # We want to peek ahead so we can return the last one + # This is slightly eager, but only in the case in which we + # are using a multi-step streaming action + next_result = result + result = item + if next_result is not None: + if first_stream_start_time is None: + first_stream_start_time = system.now() + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_item", + item=next_result, + item_index=count, + stream_initialize_time=stream_initialize_time, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + count += 1 + yield next_result, None + except Exception as e: + if result is None: + raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with last yielded result for reducer. " + "Note: the reducer will run on potentially partial data. Original error: %s", + action.name, type(e).__name__, count, e, + exc_info=True, + ) state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) @@ -494,28 +526,39 @@ async def _arun_multi_step_streaming_action( result = None first_stream_start_time = None count = 0 - async for item in generator: - # We want to peek ahead so we can return the last one - # This is slightly eager, but only in the case in which we - # are using a multi-step streaming action - next_result = result - result = item - if next_result is not None: - if first_stream_start_time is None: - first_stream_start_time = system.now() - await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( - "post_stream_item", - item=next_result, - stream_initialize_time=stream_initialize_time, - item_index=count, - first_stream_item_start_time=first_stream_start_time, - action=action.name, - app_id=app_id, - partition_key=partition_key, - sequence_id=sequence_id, - ) - count += 1 - yield next_result, None + try: + async for item in generator: + # We want to peek ahead so we can return the last one + # This is slightly eager, but only in the case in which we + # are using a multi-step streaming action + next_result = result + result = item + if next_result is not None: + if first_stream_start_time is None: + first_stream_start_time = system.now() + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_item", + item=next_result, + stream_initialize_time=stream_initialize_time, + item_index=count, + first_stream_item_start_time=first_stream_start_time, + action=action.name, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + count += 1 + yield next_result, None + except Exception as e: + if result is None: + raise + logger.warning( + "Streaming action '%s' raised %s after yielding %d items. " + "Proceeding with last yielded result for reducer. " + "Note: the reducer will run on potentially partial data. Original error: %s", + action.name, type(e).__name__, count, e, + exc_info=True, + ) state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) diff --git a/tests/core/test_application.py b/tests/core/test_application.py index c90c40676..587c151f1 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -1275,6 +1275,296 @@ async def post_stream_item(self, item: Any, **future_kwargs: Any): assert len(hook.items) == 10 # one for each streaming callback +class SingleStepStreamingCounterWithException(SingleStepStreamingAction): + """Yields intermediate items, raises, then yields final state in finally block.""" + + def stream_run_and_update( + self, state: State, **run_kwargs + ) -> Generator[Tuple[dict, Optional[State]], None, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionNoState(SingleStepStreamingAction): + """Raises without ever yielding a final state update.""" + + def stream_run_and_update( + self, state: State, **run_kwargs + ) -> Generator[Tuple[dict, Optional[State]], None, None]: + count = state["count"] + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure with no state") + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionAsync(SingleStepStreamingAction): + """Async variant: yields intermediate items, raises, then yields final state in finally.""" + + async def stream_run_and_update( + self, state: State, **run_kwargs + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class SingleStepStreamingCounterWithExceptionNoStateAsync(SingleStepStreamingAction): + """Async variant: raises without ever yielding a final state update.""" + + async def stream_run_and_update( + self, state: State, **run_kwargs + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + count = state["count"] + for i in range(3): + yield {"count": count + ((i + 1) / 10)}, None + raise RuntimeError("simulated failure with no state") + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + +class MultiStepStreamingCounterWithException(StreamingAction): + """Yields intermediate items, raises, then yields final result in finally block.""" + + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)} + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1} + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionNoResult(StreamingAction): + """Raises without ever yielding any item.""" + + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + raise RuntimeError("simulated failure with no result") + yield # make this a generator function + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionAsync(AsyncStreamingAction): + """Async variant: yields intermediate items, raises, then yields final result in finally.""" + + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + count = state["count"] + try: + for i in range(3): + yield {"count": count + ((i + 1) / 10)} + raise RuntimeError("simulated failure") + finally: + yield {"count": count + 1} + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +class MultiStepStreamingCounterWithExceptionNoResultAsync(AsyncStreamingAction): + """Async variant: raises without ever yielding any item.""" + + async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + raise RuntimeError("simulated failure with no result") + yield # make this an async generator + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result).append(tracker=result["count"]) + + +def test__run_single_step_streaming_action_graceful_exception(): + """When the generator raises but yields a final state in finally, stream completes gracefully.""" + action = SingleStepStreamingCounterWithException().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + results = list(generator) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +def test__run_single_step_streaming_action_exception_propagates(): + """When the generator raises without yielding a final state, exception propagates.""" + action = SingleStepStreamingCounterWithExceptionNoState().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + with pytest.raises(RuntimeError, match="simulated failure with no state"): + list(generator) + + +async def test__run_single_step_streaming_action_graceful_exception_async(): + """Async: when the generator raises but yields a final state in finally, stream completes.""" + action = SingleStepStreamingCounterWithExceptionAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + results = [] + async for item in generator: + results.append(item) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +async def test__run_single_step_streaming_action_exception_propagates_async(): + """Async: when the generator raises without yielding a final state, exception propagates.""" + action = SingleStepStreamingCounterWithExceptionNoStateAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + with pytest.raises(RuntimeError, match="simulated failure with no state"): + async for _ in generator: + pass + + +def test__run_multi_step_streaming_action_graceful_exception(): + """When the generator raises but yields a final result in finally, stream completes.""" + action = MultiStepStreamingCounterWithException().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_multi_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + results = list(generator) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +def test__run_multi_step_streaming_action_exception_propagates(): + """When the generator raises without yielding any result, exception propagates.""" + action = MultiStepStreamingCounterWithExceptionNoResult().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_multi_step_streaming_action( + action, state, inputs={}, sequence_id=0, partition_key="pk", app_id="app" + ) + with pytest.raises(RuntimeError, match="simulated failure with no result"): + list(generator) + + +async def test__run_multi_step_streaming_action_graceful_exception_async(): + """Async: when the generator raises but yields a final result in finally, stream completes.""" + action = MultiStepStreamingCounterWithExceptionAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_multi_step_streaming_action( + action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + results = [] + async for item in generator: + results.append(item) + intermediate = [(r, s) for r, s in results if s is None] + final = [(r, s) for r, s in results if s is not None] + assert len(intermediate) == 3 + assert len(final) == 1 + assert final[0][0] == {"count": 1} + assert final[0][1].subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +async def test__run_multi_step_streaming_action_exception_propagates_async(): + """Async: when the generator raises without yielding any result, exception propagates.""" + action = MultiStepStreamingCounterWithExceptionNoResultAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_multi_step_streaming_action( + action=action, state=state, inputs={}, sequence_id=0, app_id="app", partition_key="pk", + lifecycle_adapters=LifecycleAdapterSet(), + ) + with pytest.raises(RuntimeError, match="simulated failure with no result"): + async for _ in generator: + pass + + class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion): async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"])