Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/agents/realtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
RealtimeReasoningEffort,
RealtimeRunConfig,
RealtimeSessionModelSettings,
RealtimeToolExecutionConfig,
RealtimeTurnDetectionConfig,
RealtimeUserInput,
RealtimeUserInputMessage,
Expand Down Expand Up @@ -114,6 +115,7 @@
"RealtimeReasoningEffort",
"RealtimeRunConfig",
"RealtimeSessionModelSettings",
"RealtimeToolExecutionConfig",
"RealtimeTurnDetectionConfig",
"RealtimeUserInput",
"RealtimeUserInputMessage",
Expand Down
13 changes: 13 additions & 0 deletions src/agents/realtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ class RealtimeGuardrailsSettings(TypedDict):
"""


class RealtimeToolExecutionConfig(TypedDict):
"""SDK-side execution settings for local realtime tool calls."""

pre_approval_tool_input_guardrails: NotRequired[bool]
"""Run function tool input guardrails before emitting a pending approval event.
The same guardrails still run again immediately before tool execution after approval.
"""


class RealtimeModelTracingConfig(TypedDict):
"""Configuration for tracing in realtime model sessions."""

Expand Down Expand Up @@ -252,6 +262,9 @@ class RealtimeRunConfig(TypedDict):
async_tool_calls: NotRequired[bool]
"""Whether function tool calls should run asynchronously. Defaults to True."""

tool_execution: NotRequired[RealtimeToolExecutionConfig]
"""SDK-side execution settings for local realtime tool calls."""

tool_error_formatter: NotRequired[ToolErrorFormatter]
"""Optional callback that formats tool error messages returned to the model."""

Expand Down
105 changes: 101 additions & 4 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dataclasses
import inspect
import json
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Sequence
from typing import Any, cast

from pydantic import BaseModel
Expand All @@ -16,14 +16,15 @@
get_function_tool_namespace,
)
from ..agent import Agent
from ..exceptions import UserError
from ..exceptions import ToolInputGuardrailTripwireTriggered, UserError
from ..handoffs import Handoff
from ..items import ToolApprovalItem
from ..logger import logger
from ..run_config import ToolErrorFormatterArgs
from ..run_context import RunContextWrapper, TContext
from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, FunctionTool, invoke_function_tool
from ..tool_context import ToolContext
from ..tool_guardrails import ToolInputGuardrailData
from ..util._approvals import evaluate_needs_approval_setting
from .agent import RealtimeAgent
from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput
Expand Down Expand Up @@ -520,8 +521,8 @@ async def _maybe_request_tool_approval(
*,
function_tool: FunctionTool,
agent: RealtimeAgent,
) -> bool | None:
"""Return True/False when approved/rejected, or None when awaiting approval."""
) -> bool | None | _PendingToolOutput:
"""Return approval status, pending output for guardrail rejection, or None when awaiting."""
tool_lookup_key = get_function_tool_lookup_key_for_tool(function_tool)
approval_item = self._build_tool_approval_item(
function_tool,
Expand All @@ -545,6 +546,20 @@ async def _maybe_request_tool_approval(
if approval_status is False:
return False

if self._pre_approval_tool_input_guardrails_enabled():
rejected_message = await self._run_tool_input_guardrails(
tool=function_tool,
tool_call=tool_call,
agent=agent,
)
if rejected_message is not None:
return self._build_realtime_tool_output(
tool=function_tool,
tool_call=tool_call,
agent=agent,
output=rejected_message,
)

self._pending_tool_calls[tool_call.call_id] = (
tool_call,
agent,
Expand All @@ -562,6 +577,67 @@ async def _maybe_request_tool_approval(
)
return None

def _pre_approval_tool_input_guardrails_enabled(self) -> bool:
return (
self._run_config.get("tool_execution", {}).get(
"pre_approval_tool_input_guardrails", False
)
is True
)

async def _run_tool_input_guardrails(
self,
*,
tool: FunctionTool,
tool_call: RealtimeModelToolCallEvent,
agent: RealtimeAgent,
) -> str | None:
"""Run function tool input guardrails and return rejection output when blocked."""
guardrails = tool.tool_input_guardrails
if isinstance(guardrails, str | bytes) or not isinstance(guardrails, Sequence):
return None
if not guardrails:
return None

tool_context = ToolContext(
context=self._context_wrapper.context,
usage=self._context_wrapper.usage,
tool_name=tool_call.name,
tool_call_id=tool_call.call_id,
tool_arguments=tool_call.arguments,
agent=agent,
)
for guardrail in guardrails:
gr_out = await guardrail.run(
ToolInputGuardrailData(context=tool_context, agent=cast(Agent[Any], agent))
)
if gr_out.behavior["type"] == "raise_exception":
raise ToolInputGuardrailTripwireTriggered(guardrail=guardrail, output=gr_out)
if gr_out.behavior["type"] == "reject_content":
return gr_out.behavior["message"]
return None

def _build_realtime_tool_output(
self,
*,
tool: FunctionTool,
tool_call: RealtimeModelToolCallEvent,
agent: RealtimeAgent,
output: str,
) -> _PendingToolOutput:
return _PendingToolOutput(
tool_call=tool_call,
output=output,
start_response=True,
tool_end_event=RealtimeToolEnd(
info=self._event_info,
tool=tool,
output=output,
agent=agent,
arguments=tool_call.arguments,
),
)

async def _send_tool_rejection(
self,
event: RealtimeModelToolCallEvent,
Expand Down Expand Up @@ -749,13 +825,34 @@ async def _handle_tool_call(
approval_status = await self._maybe_request_tool_approval(
event, function_tool=func_tool, agent=agent
)
if isinstance(approval_status, _PendingToolOutput):
await self._send_tool_output_completion(approval_status)
mark_completed = True
return
if approval_status is False:
await self._send_tool_rejection(event, tool=func_tool, agent=agent)
mark_completed = True
return
if approval_status is None:
return

rejected_message = await self._run_tool_input_guardrails(
tool=func_tool,
tool_call=event,
agent=agent,
)
if rejected_message is not None:
await self._send_tool_output_completion(
self._build_realtime_tool_output(
tool=func_tool,
tool_call=event,
agent=agent,
output=rejected_message,
)
)
mark_completed = True
return

await self._put_event(
RealtimeToolStart(
info=self._event_info,
Expand Down
8 changes: 8 additions & 0 deletions src/agents/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,19 @@ class ToolExecutionConfig:
emitted in a turn. This does not change provider-side `parallel_tool_calls` behavior.
"""

pre_approval_tool_input_guardrails: bool = False
"""Run function tool input guardrails before emitting a pending approval interruption.

The same guardrails still run again immediately before tool execution after approval.
"""

def __post_init__(self) -> None:
if self.max_function_tool_concurrency is not None and (
self.max_function_tool_concurrency < 1
):
raise ValueError("tool_execution.max_function_tool_concurrency must be at least 1")
if not isinstance(self.pre_approval_tool_input_guardrails, bool):
raise ValueError("tool_execution.pre_approval_tool_input_guardrails must be a bool")


@dataclass
Expand Down
36 changes: 36 additions & 0 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,36 @@ async def _maybe_execute_tool_approval(
tool_lookup_key=tool_lookup_key,
)
if approval_status is None:
if self._should_run_pre_approval_tool_input_guardrails():
tool_context_namespace = get_tool_call_namespace(raw_tool_call)
if tool_context_namespace is None:
tool_context_namespace = get_tool_call_namespace(tool_call)
tool_context = ToolContext.from_agent_context(
self.context_wrapper,
tool_call.call_id,
tool_call=raw_tool_call,
tool_namespace=tool_context_namespace,
agent=self.public_agent,
run_config=self.config,
)
rejected_message = await _execute_tool_input_guardrails(
func_tool=func_tool,
tool_context=tool_context,
agent=self.public_agent,
tool_input_guardrail_results=self.tool_input_guardrail_results,
)
if rejected_message is not None:
return FunctionToolResult(
tool=func_tool,
output=rejected_message,
run_item=function_rejection_item(
self.public_agent,
tool_call,
rejection_message=rejected_message,
scope_id=self.tool_state_scope_id,
tool_origin=get_function_tool_origin(func_tool),
),
)
approval_item = ToolApprovalItem(
agent=self.public_agent,
raw_item=raw_tool_call,
Expand Down Expand Up @@ -1742,6 +1772,12 @@ async def _execute_single_tool_body(
task_state.invoke_task = invoke_task
return await self._await_invoke_task(outer_task=outer_task, invoke_task=invoke_task)

def _should_run_pre_approval_tool_input_guardrails(self) -> bool:
tool_execution = self.config.tool_execution
if tool_execution is None:
return False
return tool_execution.pre_approval_tool_input_guardrails

async def _invoke_tool_and_run_post_invoke(
self,
*,
Expand Down
Loading
Loading