From ac68e1ecb265bc436dc2273f89585d8960a768bc Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 22 May 2026 09:35:22 +0900 Subject: [PATCH] feat: add pre-approval tool input guardrails --- src/agents/realtime/__init__.py | 2 + src/agents/realtime/config.py | 13 ++ src/agents/realtime/session.py | 105 ++++++++++++- src/agents/run_config.py | 8 + src/agents/run_internal/tool_execution.py | 36 +++++ tests/realtime/test_session.py | 175 ++++++++++++++++++++++ tests/test_agent_runner.py | 129 ++++++++++++++++ tests/test_source_compat_constructors.py | 7 + 8 files changed, 471 insertions(+), 4 deletions(-) diff --git a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py index cd1702260a..8e3db27c25 100644 --- a/src/agents/realtime/__init__.py +++ b/src/agents/realtime/__init__.py @@ -11,6 +11,7 @@ RealtimeReasoningEffort, RealtimeRunConfig, RealtimeSessionModelSettings, + RealtimeToolExecutionConfig, RealtimeTurnDetectionConfig, RealtimeUserInput, RealtimeUserInputMessage, @@ -114,6 +115,7 @@ "RealtimeReasoningEffort", "RealtimeRunConfig", "RealtimeSessionModelSettings", + "RealtimeToolExecutionConfig", "RealtimeTurnDetectionConfig", "RealtimeUserInput", "RealtimeUserInputMessage", diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index defd4428b4..73ab0c094f 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -221,6 +221,16 @@ class RealtimeGuardrailsSettings(TypedDict): """ +class RealtimeToolExecutionConfig(TypedDict): + """SDK-side execution settings for local realtime tool calls.""" + + pre_approval_tool_input_guardrails: NotRequired[bool] + """Run function tool input guardrails before emitting a pending approval event. + + The same guardrails still run again immediately before tool execution after approval. + """ + + class RealtimeModelTracingConfig(TypedDict): """Configuration for tracing in realtime model sessions.""" @@ -252,6 +262,9 @@ class RealtimeRunConfig(TypedDict): async_tool_calls: NotRequired[bool] """Whether function tool calls should run asynchronously. Defaults to True.""" + tool_execution: NotRequired[RealtimeToolExecutionConfig] + """SDK-side execution settings for local realtime tool calls.""" + tool_error_formatter: NotRequired[ToolErrorFormatter] """Optional callback that formats tool error messages returned to the model.""" diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index ca809dd9c4..b8eec22a37 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -4,7 +4,7 @@ import dataclasses import inspect import json -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Sequence from typing import Any, cast from pydantic import BaseModel @@ -16,7 +16,7 @@ get_function_tool_namespace, ) from ..agent import Agent -from ..exceptions import UserError +from ..exceptions import ToolInputGuardrailTripwireTriggered, UserError from ..handoffs import Handoff from ..items import ToolApprovalItem from ..logger import logger @@ -24,6 +24,7 @@ from ..run_context import RunContextWrapper, TContext from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, FunctionTool, invoke_function_tool from ..tool_context import ToolContext +from ..tool_guardrails import ToolInputGuardrailData from ..util._approvals import evaluate_needs_approval_setting from .agent import RealtimeAgent from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput @@ -520,8 +521,8 @@ async def _maybe_request_tool_approval( *, function_tool: FunctionTool, agent: RealtimeAgent, - ) -> bool | None: - """Return True/False when approved/rejected, or None when awaiting approval.""" + ) -> bool | None | _PendingToolOutput: + """Return approval status, pending output for guardrail rejection, or None when awaiting.""" tool_lookup_key = get_function_tool_lookup_key_for_tool(function_tool) approval_item = self._build_tool_approval_item( function_tool, @@ -545,6 +546,20 @@ async def _maybe_request_tool_approval( if approval_status is False: return False + if self._pre_approval_tool_input_guardrails_enabled(): + rejected_message = await self._run_tool_input_guardrails( + tool=function_tool, + tool_call=tool_call, + agent=agent, + ) + if rejected_message is not None: + return self._build_realtime_tool_output( + tool=function_tool, + tool_call=tool_call, + agent=agent, + output=rejected_message, + ) + self._pending_tool_calls[tool_call.call_id] = ( tool_call, agent, @@ -562,6 +577,67 @@ async def _maybe_request_tool_approval( ) return None + def _pre_approval_tool_input_guardrails_enabled(self) -> bool: + return ( + self._run_config.get("tool_execution", {}).get( + "pre_approval_tool_input_guardrails", False + ) + is True + ) + + async def _run_tool_input_guardrails( + self, + *, + tool: FunctionTool, + tool_call: RealtimeModelToolCallEvent, + agent: RealtimeAgent, + ) -> str | None: + """Run function tool input guardrails and return rejection output when blocked.""" + guardrails = tool.tool_input_guardrails + if isinstance(guardrails, str | bytes) or not isinstance(guardrails, Sequence): + return None + if not guardrails: + return None + + tool_context = ToolContext( + context=self._context_wrapper.context, + usage=self._context_wrapper.usage, + tool_name=tool_call.name, + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + agent=agent, + ) + for guardrail in guardrails: + gr_out = await guardrail.run( + ToolInputGuardrailData(context=tool_context, agent=cast(Agent[Any], agent)) + ) + if gr_out.behavior["type"] == "raise_exception": + raise ToolInputGuardrailTripwireTriggered(guardrail=guardrail, output=gr_out) + if gr_out.behavior["type"] == "reject_content": + return gr_out.behavior["message"] + return None + + def _build_realtime_tool_output( + self, + *, + tool: FunctionTool, + tool_call: RealtimeModelToolCallEvent, + agent: RealtimeAgent, + output: str, + ) -> _PendingToolOutput: + return _PendingToolOutput( + tool_call=tool_call, + output=output, + start_response=True, + tool_end_event=RealtimeToolEnd( + info=self._event_info, + tool=tool, + output=output, + agent=agent, + arguments=tool_call.arguments, + ), + ) + async def _send_tool_rejection( self, event: RealtimeModelToolCallEvent, @@ -749,6 +825,10 @@ async def _handle_tool_call( approval_status = await self._maybe_request_tool_approval( event, function_tool=func_tool, agent=agent ) + if isinstance(approval_status, _PendingToolOutput): + await self._send_tool_output_completion(approval_status) + mark_completed = True + return if approval_status is False: await self._send_tool_rejection(event, tool=func_tool, agent=agent) mark_completed = True @@ -756,6 +836,23 @@ async def _handle_tool_call( if approval_status is None: return + rejected_message = await self._run_tool_input_guardrails( + tool=func_tool, + tool_call=event, + agent=agent, + ) + if rejected_message is not None: + await self._send_tool_output_completion( + self._build_realtime_tool_output( + tool=func_tool, + tool_call=event, + agent=agent, + output=rejected_message, + ) + ) + mark_completed = True + return + await self._put_event( RealtimeToolStart( info=self._event_info, diff --git a/src/agents/run_config.py b/src/agents/run_config.py index fcc9b01315..45dcca5b10 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -103,11 +103,19 @@ class ToolExecutionConfig: emitted in a turn. This does not change provider-side `parallel_tool_calls` behavior. """ + pre_approval_tool_input_guardrails: bool = False + """Run function tool input guardrails before emitting a pending approval interruption. + + The same guardrails still run again immediately before tool execution after approval. + """ + def __post_init__(self) -> None: if self.max_function_tool_concurrency is not None and ( self.max_function_tool_concurrency < 1 ): raise ValueError("tool_execution.max_function_tool_concurrency must be at least 1") + if not isinstance(self.pre_approval_tool_input_guardrails, bool): + raise ValueError("tool_execution.pre_approval_tool_input_guardrails must be a bool") @dataclass diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index 8f30e4a01f..4fbd923b3b 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -1651,6 +1651,36 @@ async def _maybe_execute_tool_approval( tool_lookup_key=tool_lookup_key, ) if approval_status is None: + if self._should_run_pre_approval_tool_input_guardrails(): + tool_context_namespace = get_tool_call_namespace(raw_tool_call) + if tool_context_namespace is None: + tool_context_namespace = get_tool_call_namespace(tool_call) + tool_context = ToolContext.from_agent_context( + self.context_wrapper, + tool_call.call_id, + tool_call=raw_tool_call, + tool_namespace=tool_context_namespace, + agent=self.public_agent, + run_config=self.config, + ) + rejected_message = await _execute_tool_input_guardrails( + func_tool=func_tool, + tool_context=tool_context, + agent=self.public_agent, + tool_input_guardrail_results=self.tool_input_guardrail_results, + ) + if rejected_message is not None: + return FunctionToolResult( + tool=func_tool, + output=rejected_message, + run_item=function_rejection_item( + self.public_agent, + tool_call, + rejection_message=rejected_message, + scope_id=self.tool_state_scope_id, + tool_origin=get_function_tool_origin(func_tool), + ), + ) approval_item = ToolApprovalItem( agent=self.public_agent, raw_item=raw_tool_call, @@ -1742,6 +1772,12 @@ async def _execute_single_tool_body( task_state.invoke_task = invoke_task return await self._await_invoke_task(outer_task=outer_task, invoke_task=invoke_task) + def _should_run_pre_approval_tool_input_guardrails(self) -> bool: + tool_execution = self.config.tool_execution + if tool_execution is None: + return False + return tool_execution.pre_approval_tool_input_guardrails + async def _invoke_tool_and_run_post_invoke( self, *, diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index e289bc3c9e..b9716bee68 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -64,6 +64,11 @@ from agents.run_context import RunContextWrapper from agents.tool import FunctionTool, tool_namespace from agents.tool_context import ToolContext +from agents.tool_guardrails import ( + ToolGuardrailFunctionOutput, + ToolInputGuardrailData, + tool_input_guardrail, +) class _DummyModel(RealtimeModel): @@ -1442,6 +1447,176 @@ async def test_function_tool_needs_approval_emits_event( assert approval_event.call_id == tool_call_event.call_id assert approval_event.tool == mock_function_tool + @pytest.mark.asyncio + async def test_tool_input_guardrail_rejects_before_realtime_function_execution( + self, mock_model + ): + """Tool input guardrails should run before regular realtime function tool execution.""" + executed = False + + @tool_input_guardrail + def reject_guardrail(_data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.reject_content("blocked before execution") + + async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str: + nonlocal executed + executed = True + return "ok" + + guarded_tool = FunctionTool( + name="test_function", + description="guarded", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=invoke_tool, + tool_input_guardrails=[reject_guardrail], + ) + agent = RealtimeAgent(name="agent", tools=[guarded_tool]) + session = RealtimeSession(mock_model, agent, None, run_config={"async_tool_calls": False}) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_guardrail_reject", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + + assert executed is False + assert len(mock_model.sent_tool_outputs) == 1 + _sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_output == "blocked before execution" + assert start_response is True + + @pytest.mark.asyncio + async def test_realtime_pending_approval_skips_tool_input_guardrails_by_default( + self, mock_model + ): + guardrail_runs = 0 + + @tool_input_guardrail + def count_guardrail(_data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + nonlocal guardrail_runs + guardrail_runs += 1 + return ToolGuardrailFunctionOutput.allow() + + async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str: + return "ok" + + guarded_tool = FunctionTool( + name="test_function", + description="guarded", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=invoke_tool, + needs_approval=True, + tool_input_guardrails=[count_guardrail], + ) + agent = RealtimeAgent(name="agent", tools=[guarded_tool]) + session = RealtimeSession(mock_model, agent, None, run_config={"async_tool_calls": False}) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_guardrail_pending", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + + assert tool_call_event.call_id in session._pending_tool_calls + assert guardrail_runs == 0 + + @pytest.mark.asyncio + async def test_realtime_pre_approval_tool_input_guardrail_rejects_pending_approval( + self, mock_model + ): + executed = False + + @tool_input_guardrail + def reject_guardrail(_data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.reject_content("blocked before approval") + + async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str: + nonlocal executed + executed = True + return "ok" + + guarded_tool = FunctionTool( + name="test_function", + description="guarded", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=invoke_tool, + needs_approval=True, + tool_input_guardrails=[reject_guardrail], + ) + agent = RealtimeAgent(name="agent", tools=[guarded_tool]) + session = RealtimeSession( + mock_model, + agent, + None, + run_config={ + "async_tool_calls": False, + "tool_execution": {"pre_approval_tool_input_guardrails": True}, + }, + ) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_pre_approval_reject", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + + assert executed is False + assert tool_call_event.call_id not in session._pending_tool_calls + assert len(mock_model.sent_tool_outputs) == 1 + _sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_output == "blocked before approval" + assert start_response is True + + @pytest.mark.asyncio + async def test_realtime_pre_approval_tool_input_guardrails_rerun_after_approval( + self, mock_model + ): + guardrail_runs = 0 + executed = 0 + + @tool_input_guardrail + def count_guardrail(_data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + nonlocal guardrail_runs + guardrail_runs += 1 + return ToolGuardrailFunctionOutput.allow() + + async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str: + nonlocal executed + executed += 1 + return "ok" + + guarded_tool = FunctionTool( + name="test_function", + description="guarded", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=invoke_tool, + needs_approval=True, + tool_input_guardrails=[count_guardrail], + ) + agent = RealtimeAgent(name="agent", tools=[guarded_tool]) + session = RealtimeSession( + mock_model, + agent, + None, + run_config={ + "async_tool_calls": False, + "tool_execution": {"pre_approval_tool_input_guardrails": True}, + }, + ) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_pre_approval_rerun", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + assert guardrail_runs == 1 + assert executed == 0 + + await session.approve_tool_call(tool_call_event.call_id) + + assert guardrail_runs == 2 + assert executed == 1 + assert len(mock_model.sent_tool_outputs) == 1 + _sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_output == "ok" + assert start_response is True + @pytest.mark.asyncio async def test_duplicate_pending_approval_call_id_is_ignored_and_approval_runs_once( self, mock_model, mock_agent, mock_function_tool diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index eb22c70f14..4b5ea867ce 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -35,10 +35,14 @@ RunContextWrapper, Runner, SQLiteSession, + ToolExecutionConfig, + ToolGuardrailFunctionOutput, + ToolInputGuardrailData, ToolTimeoutError, UserError, handoff, retry_policies, + tool_input_guardrail, tool_namespace, ) from agents.agent import ToolsToFinalOutputResult @@ -843,6 +847,131 @@ def approval_tool() -> str: assert "id" not in second_request_reasoning +@pytest.mark.asyncio +async def test_pending_approval_skips_tool_input_guardrails_by_default() -> None: + model = FakeModel() + guardrail_runs = 0 + + @tool_input_guardrail + def count_guardrail(_data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + nonlocal guardrail_runs + guardrail_runs += 1 + return ToolGuardrailFunctionOutput.allow() + + @function_tool( + name_override="approval_tool", + needs_approval=True, + tool_input_guardrails=[count_guardrail], + ) + def approval_tool() -> str: + return "ok" + + agent = Agent(name="test", model=model, tools=[approval_tool]) + model.set_next_output([get_function_tool_call("approval_tool", "{}", call_id="call_default")]) + + result = await Runner.run(agent, "hello") + + assert len(result.interruptions) == 1 + assert guardrail_runs == 0 + assert result.tool_input_guardrail_results == [] + + +@pytest.mark.asyncio +async def test_pre_approval_tool_input_guardrails_can_reject_before_pending_approval() -> None: + model = FakeModel() + executed = False + + @tool_input_guardrail + def reject_guardrail(_data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.reject_content("blocked before approval") + + @function_tool( + name_override="approval_tool", + needs_approval=True, + tool_input_guardrails=[reject_guardrail], + ) + def approval_tool() -> str: + nonlocal executed + executed = True + return "ok" + + agent = Agent(name="test", model=model, tools=[approval_tool]) + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("approval_tool", "{}", call_id="call_reject")], + [get_text_message("done")], + ] + ) + + result = await Runner.run( + agent, + "hello", + run_config=RunConfig( + tool_execution=ToolExecutionConfig(pre_approval_tool_input_guardrails=True) + ), + ) + + assert result.final_output == "done" + assert result.interruptions == [] + assert executed is False + assert len(result.tool_input_guardrail_results) == 1 + assert any( + isinstance(item, ToolCallOutputItem) and item.output == "blocked before approval" + for item in result.new_items + ) + + +@pytest.mark.asyncio +async def test_pre_approval_tool_input_guardrails_rerun_after_resume() -> None: + model = FakeModel() + guardrail_runs = 0 + executed = 0 + + @tool_input_guardrail + def count_guardrail(_data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + nonlocal guardrail_runs + guardrail_runs += 1 + return ToolGuardrailFunctionOutput.allow() + + @function_tool( + name_override="approval_tool", + needs_approval=True, + tool_input_guardrails=[count_guardrail], + ) + def approval_tool() -> str: + nonlocal executed + executed += 1 + return "ok" + + agent = Agent(name="test", model=model, tools=[approval_tool]) + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("approval_tool", "{}", call_id="call_resume")], + [get_text_message("done")], + ] + ) + run_config = RunConfig( + tool_execution=ToolExecutionConfig(pre_approval_tool_input_guardrails=True) + ) + + first = await Runner.run(agent, "hello", run_config=run_config) + assert len(first.interruptions) == 1 + assert guardrail_runs == 1 + assert executed == 0 + assert len(first.tool_input_guardrail_results) == 1 + + state = first.to_state() + state.approve(first.interruptions[0]) + restored_state = await RunState.from_string(agent, state.to_string()) + + resumed = await Runner.run(agent, restored_state, run_config=run_config) + + assert resumed.final_output == "done" + assert guardrail_runs == 2 + assert executed == 1 + assert len(resumed.tool_input_guardrail_results) == 1 + + @pytest.mark.asyncio async def test_tool_call_context_includes_current_agent() -> None: model = FakeModel() diff --git a/tests/test_source_compat_constructors.py b/tests/test_source_compat_constructors.py index 2b82b9c8d9..2033285691 100644 --- a/tests/test_source_compat_constructors.py +++ b/tests/test_source_compat_constructors.py @@ -167,6 +167,13 @@ def test_run_config_tool_not_found_behavior_append_preserves_tool_execution_posi assert config.tool_not_found_behavior == "return_error_to_model" +def test_tool_execution_config_pre_approval_append_preserves_max_concurrency() -> None: + config = ToolExecutionConfig(2, True) + + assert config.max_function_tool_concurrency == 2 + assert config.pre_approval_tool_input_guardrails is True + + def test_model_settings_context_management_append_preserves_retry_position() -> None: retry = ModelRetrySettings(max_retries=1) settings = ModelSettings(