Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/google/adk/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,22 @@ async def on_tool_error_callback(
allows the original error to be raised.
"""
pass

async def on_pipeline_error_callback(
self,
*,
invocation_context: InvocationContext,
error: Exception,
) -> Exception:
"""Callback executed when the runner pipeline encounters an error.

This callback provides an opportunity to handle pipeline errors globally.

Args:
invocation_context: The context for the entire invocation.
error: The exception that was raised during runner execution.

Returns:
An Exception to be raised (either the original error or a new/modified one).
"""
return error
22 changes: 22 additions & 0 deletions src/google/adk/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
"on_pipeline_error_callback",
]

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -272,6 +273,27 @@ async def run_on_tool_error_callback(
error=error,
)

async def run_on_pipeline_error_callback(
self,
*,
invocation_context: InvocationContext,
error: Exception,
) -> Exception:
"""Runs the `on_pipeline_error_callback` for all plugins sequentially, chaining the error."""
for plugin in self.plugins:
try:
error = await plugin.on_pipeline_error_callback(
invocation_context=invocation_context, error=error
)
except Exception as e:
error_message = (
f"Error in plugin '{plugin.name}' during "
f"'on_pipeline_error_callback' callback: {e}"
)
logger.error(error_message, exc_info=True)
raise RuntimeError(error_message) from e
return error

async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
Expand Down
217 changes: 112 additions & 105 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,118 +1370,125 @@ async def _exec_with_plugin(

plugin_manager = invocation_context.plugin_manager

# Step 1: Run the before_run callbacks to see if we should early exit.
early_exit_result = await plugin_manager.run_before_run_callback(
invocation_context=invocation_context
)
if isinstance(early_exit_result, types.Content):
early_exit_event = Event(
invocation_id=invocation_context.invocation_id,
author='model',
content=early_exit_result,
)
_apply_run_config_custom_metadata(
early_exit_event, invocation_context.run_config
try:
# Step 1: Run the before_run callbacks to see if we should early exit.
early_exit_result = await plugin_manager.run_before_run_callback(
invocation_context=invocation_context
)
if self._should_append_event(early_exit_event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session,
event=early_exit_event,
if isinstance(early_exit_result, types.Content):
early_exit_event = Event(
invocation_id=invocation_context.invocation_id,
author='model',
content=early_exit_result,
)
yield early_exit_event
else:
# Step 2: Otherwise continue with normal execution
# Note for live/bidi:
# the transcription may arrive later than the action(function call
# event and thus function response event). In this case, the order of
# transcription and function call event will be wrong if we just
# append as it arrives. To address this, we should check if there is
# transcription going on. If there is transcription going on, we
# should hold on appending the function call event until the
# transcription is finished. The transcription in progress can be
# identified by checking if the transcription event is partial. When
# the next transcription event is not partial, it means the previous
# transcription is finished. Then if there is any buffered function
# call event, we should append them after this finished(non-partial)
# transcription event.
buffered_events: list[Event] = []
is_transcribing: bool = False

async with aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
event, invocation_context.run_config
)
# Step 3: Run the on_event callbacks before persisting so callback
# changes are stored in the session and match the streamed event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
output_event = self._get_output_event(
original_event=event,
modified_event=modified_event,
run_config=invocation_context.run_config,
_apply_run_config_custom_metadata(
early_exit_event, invocation_context.run_config
)
if self._should_append_event(early_exit_event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session,
event=early_exit_event,
)
yield early_exit_event
else:
# Step 2: Otherwise continue with normal execution
# Note for live/bidi:
# the transcription may arrive later than the action(function call
# event and thus function response event). In this case, the order of
# transcription and function call event will be wrong if we just
# append as it arrives. To address this, we should check if there is
# transcription going on. If there is transcription going on, we
# should hold on appending the function call event until the
# transcription is finished. The transcription in progress can be
# identified by checking if the transcription event is partial. When
# the next transcription event is not partial, it means the previous
# transcription is finished. Then if there is any buffered function
# call event, we should append them after this finished(non-partial)
# transcription event.
buffered_events: list[Event] = []
is_transcribing: bool = False

async with aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
event, invocation_context.run_config
)
# Step 3: Run the on_event callbacks before persisting so callback
# changes are stored in the session and match the streamed event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
output_event = self._get_output_event(
original_event=event,
modified_event=modified_event,
run_config=invocation_context.run_config,
)

if is_live_call:
if event.partial and _is_transcription(event):
is_transcribing = True
if is_transcribing and _is_tool_call_or_response(event):
# only buffer function call and function response event which is
# non-partial
buffered_events.append(output_event)
continue
# Note for live/bidi: for audio response, it's considered as
# non-partial event(event.partial=None)
# event.partial=False and event.partial=None are considered as
# non-partial event; event.partial=True is considered as partial
# event.
if event.partial is not True:
if _is_transcription(event) and (
_has_non_empty_transcription_text(event.input_transcription)
or _has_non_empty_transcription_text(
event.output_transcription
if is_live_call:
if event.partial and _is_transcription(event):
is_transcribing = True
if is_transcribing and _is_tool_call_or_response(event):
# only buffer function call and function response event which is
# non-partial
buffered_events.append(output_event)
continue
# Note for live/bidi: for audio response, it's considered as
# non-partial event(event.partial=None)
# event.partial=False and event.partial=None are considered as
# non-partial event; event.partial=True is considered as partial
# event.
if event.partial is not True:
if _is_transcription(event) and (
_has_non_empty_transcription_text(event.input_transcription)
or _has_non_empty_transcription_text(
event.output_transcription
)
):
# transcription end signal, append buffered events
is_transcribing = False
logger.debug(
'Appending transcription finished event: %s', event
)
):
# transcription end signal, append buffered events
is_transcribing = False
logger.debug(
'Appending transcription finished event: %s', event
if self._should_append_event(event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)

for buffered_event in buffered_events:
logger.debug('Appending buffered event: %s', buffered_event)
await self.session_service.append_event(
session=invocation_context.session, event=buffered_event
)
yield buffered_event # yield buffered events to caller
buffered_events = []
else:
# non-transcription event or empty transcription event, for
# example, event that stores blob reference, should be appended.
if self._should_append_event(event, is_live_call):
logger.debug('Appending non-buffered event: %s', event)
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)
else:
if event.partial is not True:
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)
if self._should_append_event(event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)

for buffered_event in buffered_events:
logger.debug('Appending buffered event: %s', buffered_event)
await self.session_service.append_event(
session=invocation_context.session, event=buffered_event
)
yield buffered_event # yield buffered events to caller
buffered_events = []
else:
# non-transcription event or empty transcription event, for
# example, event that stores blob reference, should be appended.
if self._should_append_event(event, is_live_call):
logger.debug('Appending non-buffered event: %s', event)
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)
else:
if event.partial is not True:
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)

yield output_event

# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
# This does NOT emit any event.
await plugin_manager.run_after_run_callback(
invocation_context=invocation_context
)
yield output_event
except Exception as e:
if plugin_manager:
e = await plugin_manager.run_on_pipeline_error_callback(
invocation_context=invocation_context, error=e
)
raise e
finally:
# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
# This does NOT emit any event.
await plugin_manager.run_after_run_callback(
invocation_context=invocation_context
)

async def _append_new_message_to_session(
self,
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ async def after_model_callback(self, **kwargs):
async def on_model_error_callback(self, **kwargs):
return await self._handle_callback("on_model_error_callback")

async def on_pipeline_error_callback(self, error: Exception, **kwargs):
self.call_log.append("on_pipeline_error_callback")
if "on_pipeline_error_callback" in self.exceptions_to_raise:
raise self.exceptions_to_raise["on_pipeline_error_callback"]
return self.return_values.get("on_pipeline_error_callback", error)


@pytest.fixture
def service() -> PluginManager:
Expand Down Expand Up @@ -252,6 +258,10 @@ async def test_all_callbacks_are_supported(
llm_request=mock_context,
error=mock_context,
)
await service.run_on_pipeline_error_callback(
invocation_context=mock_context,
error=ValueError("err"),
)

# Verify all callbacks were logged
expected_callbacks = [
Expand All @@ -267,6 +277,7 @@ async def test_all_callbacks_are_supported(
"before_model_callback",
"after_model_callback",
"on_model_error_callback",
"on_pipeline_error_callback",
]
assert set(plugin1.call_log) == set(expected_callbacks)

Expand Down Expand Up @@ -363,3 +374,43 @@ async def test_set_skip_closing_plugins_false_reverts_to_closing(
await service.close()

plugin1.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_pipeline_error_callback_chaining(
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
):
"""Tests that on_pipeline_error_callback is called and errors are chained."""
error1 = ValueError("Original error")
error2 = RuntimeError("Chained error")
plugin1.return_values["on_pipeline_error_callback"] = error2

service.register_plugin(plugin1)
service.register_plugin(plugin2)

result_err = await service.run_on_pipeline_error_callback(
invocation_context=Mock(), error=error1
)

assert result_err is error2
assert "on_pipeline_error_callback" in plugin1.call_log
assert "on_pipeline_error_callback" in plugin2.call_log


@pytest.mark.asyncio
async def test_pipeline_error_callback_exception_wrap(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that if on_pipeline_error_callback raises, it wraps in RuntimeError."""
plugin1.exceptions_to_raise["on_pipeline_error_callback"] = ValueError(
"Callback crashed"
)
service.register_plugin(plugin1)

with pytest.raises(RuntimeError) as excinfo:
await service.run_on_pipeline_error_callback(
invocation_context=Mock(), error=ValueError("Original")
)

assert "Error in plugin 'plugin1'" in str(excinfo.value)
assert "on_pipeline_error_callback" in str(excinfo.value)