diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 54bfab2ed2..fa3bde5447 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -370,3 +370,22 @@ async def on_tool_error_callback( allows the original error to be raised. """ pass + + async def on_pipeline_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> Exception: + """Callback executed when the runner pipeline encounters an error. + + This callback provides an opportunity to handle pipeline errors globally. + + Args: + invocation_context: The context for the entire invocation. + error: The exception that was raised during runner execution. + + Returns: + An Exception to be raised (either the original error or a new/modified one). + """ + return error diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index 5566349516..80fa455bf0 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -52,6 +52,7 @@ "after_model_callback", "on_tool_error_callback", "on_model_error_callback", + "on_pipeline_error_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -272,6 +273,27 @@ async def run_on_tool_error_callback( error=error, ) + async def run_on_pipeline_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> Exception: + """Runs the `on_pipeline_error_callback` for all plugins sequentially, chaining the error.""" + for plugin in self.plugins: + try: + error = await plugin.on_pipeline_error_callback( + invocation_context=invocation_context, error=error + ) + except Exception as e: + error_message = ( + f"Error in plugin '{plugin.name}' during " + f"'on_pipeline_error_callback' callback: {e}" + ) + logger.error(error_message, exc_info=True) + raise RuntimeError(error_message) from e + return error + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 5b2f23fec7..e6f38af67f 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -1370,118 +1370,125 @@ async def _exec_with_plugin( plugin_manager = invocation_context.plugin_manager - # Step 1: Run the before_run callbacks to see if we should early exit. - early_exit_result = await plugin_manager.run_before_run_callback( - invocation_context=invocation_context - ) - if isinstance(early_exit_result, types.Content): - early_exit_event = Event( - invocation_id=invocation_context.invocation_id, - author='model', - content=early_exit_result, - ) - _apply_run_config_custom_metadata( - early_exit_event, invocation_context.run_config + try: + # Step 1: Run the before_run callbacks to see if we should early exit. + early_exit_result = await plugin_manager.run_before_run_callback( + invocation_context=invocation_context ) - if self._should_append_event(early_exit_event, is_live_call): - await self.session_service.append_event( - session=invocation_context.session, - event=early_exit_event, + if isinstance(early_exit_result, types.Content): + early_exit_event = Event( + invocation_id=invocation_context.invocation_id, + author='model', + content=early_exit_result, ) - yield early_exit_event - else: - # Step 2: Otherwise continue with normal execution - # Note for live/bidi: - # the transcription may arrive later than the action(function call - # event and thus function response event). In this case, the order of - # transcription and function call event will be wrong if we just - # append as it arrives. To address this, we should check if there is - # transcription going on. If there is transcription going on, we - # should hold on appending the function call event until the - # transcription is finished. The transcription in progress can be - # identified by checking if the transcription event is partial. When - # the next transcription event is not partial, it means the previous - # transcription is finished. Then if there is any buffered function - # call event, we should append them after this finished(non-partial) - # transcription event. - buffered_events: list[Event] = [] - is_transcribing: bool = False - - async with aclosing(execute_fn(invocation_context)) as agen: - async for event in agen: - _apply_run_config_custom_metadata( - event, invocation_context.run_config - ) - # Step 3: Run the on_event callbacks before persisting so callback - # changes are stored in the session and match the streamed event. - modified_event = await plugin_manager.run_on_event_callback( - invocation_context=invocation_context, event=event - ) - output_event = self._get_output_event( - original_event=event, - modified_event=modified_event, - run_config=invocation_context.run_config, + _apply_run_config_custom_metadata( + early_exit_event, invocation_context.run_config + ) + if self._should_append_event(early_exit_event, is_live_call): + await self.session_service.append_event( + session=invocation_context.session, + event=early_exit_event, ) + yield early_exit_event + else: + # Step 2: Otherwise continue with normal execution + # Note for live/bidi: + # the transcription may arrive later than the action(function call + # event and thus function response event). In this case, the order of + # transcription and function call event will be wrong if we just + # append as it arrives. To address this, we should check if there is + # transcription going on. If there is transcription going on, we + # should hold on appending the function call event until the + # transcription is finished. The transcription in progress can be + # identified by checking if the transcription event is partial. When + # the next transcription event is not partial, it means the previous + # transcription is finished. Then if there is any buffered function + # call event, we should append them after this finished(non-partial) + # transcription event. + buffered_events: list[Event] = [] + is_transcribing: bool = False + + async with aclosing(execute_fn(invocation_context)) as agen: + async for event in agen: + _apply_run_config_custom_metadata( + event, invocation_context.run_config + ) + # Step 3: Run the on_event callbacks before persisting so callback + # changes are stored in the session and match the streamed event. + modified_event = await plugin_manager.run_on_event_callback( + invocation_context=invocation_context, event=event + ) + output_event = self._get_output_event( + original_event=event, + modified_event=modified_event, + run_config=invocation_context.run_config, + ) - if is_live_call: - if event.partial and _is_transcription(event): - is_transcribing = True - if is_transcribing and _is_tool_call_or_response(event): - # only buffer function call and function response event which is - # non-partial - buffered_events.append(output_event) - continue - # Note for live/bidi: for audio response, it's considered as - # non-partial event(event.partial=None) - # event.partial=False and event.partial=None are considered as - # non-partial event; event.partial=True is considered as partial - # event. - if event.partial is not True: - if _is_transcription(event) and ( - _has_non_empty_transcription_text(event.input_transcription) - or _has_non_empty_transcription_text( - event.output_transcription + if is_live_call: + if event.partial and _is_transcription(event): + is_transcribing = True + if is_transcribing and _is_tool_call_or_response(event): + # only buffer function call and function response event which is + # non-partial + buffered_events.append(output_event) + continue + # Note for live/bidi: for audio response, it's considered as + # non-partial event(event.partial=None) + # event.partial=False and event.partial=None are considered as + # non-partial event; event.partial=True is considered as partial + # event. + if event.partial is not True: + if _is_transcription(event) and ( + _has_non_empty_transcription_text(event.input_transcription) + or _has_non_empty_transcription_text( + event.output_transcription + ) + ): + # transcription end signal, append buffered events + is_transcribing = False + logger.debug( + 'Appending transcription finished event: %s', event ) - ): - # transcription end signal, append buffered events - is_transcribing = False - logger.debug( - 'Appending transcription finished event: %s', event + if self._should_append_event(event, is_live_call): + await self.session_service.append_event( + session=invocation_context.session, event=output_event + ) + + for buffered_event in buffered_events: + logger.debug('Appending buffered event: %s', buffered_event) + await self.session_service.append_event( + session=invocation_context.session, event=buffered_event + ) + yield buffered_event # yield buffered events to caller + buffered_events = [] + else: + # non-transcription event or empty transcription event, for + # example, event that stores blob reference, should be appended. + if self._should_append_event(event, is_live_call): + logger.debug('Appending non-buffered event: %s', event) + await self.session_service.append_event( + session=invocation_context.session, event=output_event + ) + else: + if event.partial is not True: + await self.session_service.append_event( + session=invocation_context.session, event=output_event ) - if self._should_append_event(event, is_live_call): - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) - for buffered_event in buffered_events: - logger.debug('Appending buffered event: %s', buffered_event) - await self.session_service.append_event( - session=invocation_context.session, event=buffered_event - ) - yield buffered_event # yield buffered events to caller - buffered_events = [] - else: - # non-transcription event or empty transcription event, for - # example, event that stores blob reference, should be appended. - if self._should_append_event(event, is_live_call): - logger.debug('Appending non-buffered event: %s', event) - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) - else: - if event.partial is not True: - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) - - yield output_event - - # Step 4: Run the after_run callbacks to perform global cleanup tasks or - # finalizing logs and metrics data. - # This does NOT emit any event. - await plugin_manager.run_after_run_callback( - invocation_context=invocation_context - ) + yield output_event + except Exception as e: + if plugin_manager: + e = await plugin_manager.run_on_pipeline_error_callback( + invocation_context=invocation_context, error=e + ) + raise e + finally: + # Step 4: Run the after_run callbacks to perform global cleanup tasks or + # finalizing logs and metrics data. + # This does NOT emit any event. + await plugin_manager.run_after_run_callback( + invocation_context=invocation_context + ) async def _append_new_message_to_session( self, diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py index 6c72a2a665..668efe8a12 100644 --- a/tests/unittests/plugins/test_plugin_manager.py +++ b/tests/unittests/plugins/test_plugin_manager.py @@ -91,6 +91,12 @@ async def after_model_callback(self, **kwargs): async def on_model_error_callback(self, **kwargs): return await self._handle_callback("on_model_error_callback") + async def on_pipeline_error_callback(self, error: Exception, **kwargs): + self.call_log.append("on_pipeline_error_callback") + if "on_pipeline_error_callback" in self.exceptions_to_raise: + raise self.exceptions_to_raise["on_pipeline_error_callback"] + return self.return_values.get("on_pipeline_error_callback", error) + @pytest.fixture def service() -> PluginManager: @@ -252,6 +258,10 @@ async def test_all_callbacks_are_supported( llm_request=mock_context, error=mock_context, ) + await service.run_on_pipeline_error_callback( + invocation_context=mock_context, + error=ValueError("err"), + ) # Verify all callbacks were logged expected_callbacks = [ @@ -267,6 +277,7 @@ async def test_all_callbacks_are_supported( "before_model_callback", "after_model_callback", "on_model_error_callback", + "on_pipeline_error_callback", ] assert set(plugin1.call_log) == set(expected_callbacks) @@ -363,3 +374,43 @@ async def test_set_skip_closing_plugins_false_reverts_to_closing( await service.close() plugin1.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_pipeline_error_callback_chaining( + service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin +): + """Tests that on_pipeline_error_callback is called and errors are chained.""" + error1 = ValueError("Original error") + error2 = RuntimeError("Chained error") + plugin1.return_values["on_pipeline_error_callback"] = error2 + + service.register_plugin(plugin1) + service.register_plugin(plugin2) + + result_err = await service.run_on_pipeline_error_callback( + invocation_context=Mock(), error=error1 + ) + + assert result_err is error2 + assert "on_pipeline_error_callback" in plugin1.call_log + assert "on_pipeline_error_callback" in plugin2.call_log + + +@pytest.mark.asyncio +async def test_pipeline_error_callback_exception_wrap( + service: PluginManager, plugin1: TestPlugin +): + """Tests that if on_pipeline_error_callback raises, it wraps in RuntimeError.""" + plugin1.exceptions_to_raise["on_pipeline_error_callback"] = ValueError( + "Callback crashed" + ) + service.register_plugin(plugin1) + + with pytest.raises(RuntimeError) as excinfo: + await service.run_on_pipeline_error_callback( + invocation_context=Mock(), error=ValueError("Original") + ) + + assert "Error in plugin 'plugin1'" in str(excinfo.value) + assert "on_pipeline_error_callback" in str(excinfo.value)