From 12acfa4abca8fbc9b4ae08da360d227ec56411fc Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 22 May 2026 09:12:48 +0900 Subject: [PATCH 1/2] feat: add SDK-only custom data for tool outputs --- src/agents/__init__.py | 16 ++ src/agents/items.py | 7 + src/agents/mcp/__init__.py | 4 + src/agents/mcp/server.py | 21 ++ src/agents/mcp/util.py | 89 +++++++- src/agents/run_internal/tool_actions.py | 115 +++++++--- src/agents/run_internal/tool_execution.py | 17 ++ src/agents/run_state.py | 13 +- src/agents/tool.py | 129 ++++++++++++ src/agents/tool_context.py | 2 + src/agents/util/_custom_data.py | 57 +++++ tests/mcp/helpers.py | 6 +- tests/test_run_state.py | 32 +++ tests/test_tool_custom_data.py | 246 ++++++++++++++++++++++ 14 files changed, 722 insertions(+), 32 deletions(-) create mode 100644 src/agents/util/_custom_data.py create mode 100644 tests/test_tool_custom_data.py diff --git a/src/agents/__init__.py b/src/agents/__init__.py index e3bc96abff..8687585760 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -130,12 +130,20 @@ ) from .tool import ( ApplyPatchTool, + ApplyPatchToolCustomDataContext, + ApplyPatchToolCustomDataExtractor, CodeInterpreterTool, ComputerProvider, ComputerTool, + ComputerToolCustomDataContext, + ComputerToolCustomDataExtractor, CustomTool, + CustomToolCustomDataContext, + CustomToolCustomDataExtractor, FileSearchTool, FunctionTool, + FunctionToolCustomDataContext, + FunctionToolCustomDataExtractor, FunctionToolResult, HostedMCPTool, ImageGenerationTool, @@ -446,10 +454,16 @@ def enable_verbose_stdout_logging(): "AgentUpdatedStreamEvent", "StreamEvent", "FunctionTool", + "FunctionToolCustomDataContext", + "FunctionToolCustomDataExtractor", "FunctionToolResult", "ComputerTool", + "ComputerToolCustomDataContext", + "ComputerToolCustomDataExtractor", "ComputerProvider", "CustomTool", + "CustomToolCustomDataContext", + "CustomToolCustomDataExtractor", "FileSearchTool", "CodeInterpreterTool", "ImageGenerationTool", @@ -482,6 +496,8 @@ def enable_verbose_stdout_logging(): "ApplyPatchOperation", "ApplyPatchResult", "ApplyPatchTool", + "ApplyPatchToolCustomDataContext", + "ApplyPatchToolCustomDataExtractor", "Tool", "WebSearchTool", "HostedMCPTool", diff --git a/src/agents/items.py b/src/agents/items.py index c761cc221f..f5545c7b48 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -403,6 +403,13 @@ class ToolCallOutputItem(RunItemBase[Any]): tool_origin: ToolOrigin | None = None """Optional metadata describing the source of a function-tool-backed item.""" + custom_data: dict[str, Any] | None = None + """SDK-only custom data attached to this tool output. + + This data is not part of ``raw_item`` and is not sent back to the model when the output item is + replayed as input. + """ + @property def call_id(self) -> str | None: """Return the call identifier from the raw item, if available.""" diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index f0de5bda66..e83f2fe209 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -17,6 +17,8 @@ ) from .util import ( + MCPToolCustomDataContext, + MCPToolCustomDataExtractor, MCPToolMetaContext, MCPToolMetaResolver, MCPUtil, @@ -50,6 +52,8 @@ "MCPServerManager", "LocalMCPApprovalCallable", "MCPUtil", + "MCPToolCustomDataContext", + "MCPToolCustomDataExtractor", "MCPToolMetaContext", "MCPToolMetaResolver", "ToolFilter", diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 8d3bdd752a..2686426b88 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -45,6 +45,7 @@ from ..util._types import MaybeAwaitable from .util import ( HttpClientFactory, + MCPToolCustomDataExtractor, MCPToolMetaResolver, ToolFilter, ToolFilterContext, @@ -229,6 +230,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + custom_data_extractor: MCPToolCustomDataExtractor | None = None, ): """ Args: @@ -248,6 +250,8 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + custom_data_extractor: Optional callable that produces SDK-only custom data for + emitted MCP tool output items. """ self.use_structured_content = use_structured_content self._needs_approval_policy = self._normalize_needs_approval( @@ -255,6 +259,7 @@ def __init__( ) self._failure_error_function = failure_error_function self.tool_meta_resolver = tool_meta_resolver + self.custom_data_extractor = custom_data_extractor @abc.abstractmethod async def connect(self): @@ -544,6 +549,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + custom_data_extractor: MCPToolCustomDataExtractor | None = None, ): """ Args: @@ -576,12 +582,15 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + custom_data_extractor: Optional callable that produces SDK-only custom data for + emitted MCP tool output items. """ super().__init__( use_structured_content=use_structured_content, require_approval=require_approval, failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + custom_data_extractor=custom_data_extractor, ) self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() @@ -1108,6 +1117,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + custom_data_extractor: MCPToolCustomDataExtractor | None = None, ): """Create a new MCP server based on the stdio transport. @@ -1145,6 +1155,8 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + custom_data_extractor: Optional callable that produces SDK-only custom data for + emitted MCP tool output items. """ super().__init__( cache_tools_list=cache_tools_list, @@ -1157,6 +1169,7 @@ def __init__( require_approval=require_approval, failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + custom_data_extractor=custom_data_extractor, ) self.params = StdioServerParameters( @@ -1229,6 +1242,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + custom_data_extractor: MCPToolCustomDataExtractor | None = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -1268,6 +1282,8 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + custom_data_extractor: Optional callable that produces SDK-only custom data for + emitted MCP tool output items. """ super().__init__( cache_tools_list=cache_tools_list, @@ -1280,6 +1296,7 @@ def __init__( require_approval=require_approval, failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + custom_data_extractor=custom_data_extractor, ) self.params = params @@ -1365,6 +1382,7 @@ def __init__( require_approval: RequireApprovalSetting = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + custom_data_extractor: MCPToolCustomDataExtractor | None = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -1405,6 +1423,8 @@ def __init__( SDK default) will be used. tool_meta_resolver: Optional callable that produces MCP request metadata (`_meta`) for tool calls. It is invoked by the Agents SDK before calling `call_tool`. + custom_data_extractor: Optional callable that produces SDK-only custom data for + emitted MCP tool output items. """ super().__init__( cache_tools_list=cache_tools_list, @@ -1417,6 +1437,7 @@ def __init__( require_approval=require_approval, failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + custom_data_extractor=custom_data_extractor, ) self.params = params diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index bf00cb2b79..4ec3a3bac0 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -7,8 +7,9 @@ import inspect import json from collections import Counter -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass +from types import MappingProxyType from typing import TYPE_CHECKING, Any, Protocol, Union import httpx @@ -39,6 +40,7 @@ ) from ..tool_context import ToolContext from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span +from ..util._custom_data import maybe_extract_custom_data from ..util._types import MaybeAwaitable if TYPE_CHECKING: @@ -149,13 +151,50 @@ class MCPToolMetaContext: """The parsed tool arguments.""" +@dataclass(frozen=True) +class MCPToolCustomDataContext: + """Context passed to MCP tool custom data extractors.""" + + run_context: RunContextWrapper[Any] + """The current run context.""" + + server_name: str + """The name of the MCP server.""" + + tool_name: str + """The original MCP tool name invoked on the server.""" + + tool_display_name: str + """The public tool name exposed through the Agents SDK.""" + + arguments: Mapping[str, Any] + """The parsed tool arguments.""" + + result_meta: Mapping[str, Any] | None + """The MCP tool result ``_meta`` payload, if present.""" + + structured_content: Mapping[str, Any] | None + """The MCP tool result ``structuredContent`` payload, if present.""" + + is_error: bool | None + """The MCP tool result ``isError`` flag, if present.""" + + tool_output: ToolOutput + """The model-visible tool output produced by the Agents SDK.""" + + if TYPE_CHECKING: MCPToolMetaResolver = Callable[ [MCPToolMetaContext], MaybeAwaitable[dict[str, Any] | None], ] + MCPToolCustomDataExtractor = Callable[ + [MCPToolCustomDataContext], + MaybeAwaitable[Mapping[str, Any] | None], + ] else: MCPToolMetaResolver = Callable[..., Any] + MCPToolCustomDataExtractor = Callable[..., Any] """A function that produces MCP request metadata for tool calls. Args: @@ -164,6 +203,7 @@ class MCPToolMetaContext: Returns: A dict to send as MCP `_meta`, or None to omit metadata. """ +"""A function that produces SDK-only custom data for MCP tool output items.""" def create_static_tool_filter( @@ -541,6 +581,41 @@ def _merge_mcp_meta( merged.update(copy.deepcopy(explicit_meta)) return merged + @staticmethod + def _copy_mapping_proxy(value: Any) -> Mapping[str, Any] | None: + if not isinstance(value, dict): + return None + return MappingProxyType(copy.deepcopy(value)) + + @classmethod + async def _extract_custom_data( + cls, + *, + server: MCPServer, + context: RunContextWrapper[Any], + tool_name: str, + tool_display_name: str, + arguments: dict[str, Any], + result: Any, + tool_output: ToolOutput, + ) -> dict[str, Any] | None: + extractor = getattr(server, "custom_data_extractor", None) + if extractor is None: + return None + + extractor_context = MCPToolCustomDataContext( + run_context=context, + server_name=server.name, + tool_name=tool_name, + tool_display_name=tool_display_name, + arguments=MappingProxyType(copy.deepcopy(arguments)), + result_meta=cls._copy_mapping_proxy(getattr(result, "meta", None)), + structured_content=cls._copy_mapping_proxy(getattr(result, "structuredContent", None)), + is_error=getattr(result, "isError", None), + tool_output=copy.deepcopy(tool_output), + ) + return await maybe_extract_custom_data(extractor, extractor_context) + @classmethod async def _resolve_meta( cls, @@ -688,6 +763,18 @@ async def invoke_mcp_tool( else: tool_output = tool_output_list + custom_data = await cls._extract_custom_data( + server=server, + context=context, + tool_name=tool.name, + tool_display_name=tool_name_for_display, + arguments=json_data, + result=result, + tool_output=tool_output, + ) + if custom_data and isinstance(context, ToolContext): + context._custom_data = custom_data + current_span = get_current_span() if current_span: if isinstance(current_span.span_data, FunctionSpanData): diff --git a/src/agents/run_internal/tool_actions.py b/src/agents/run_internal/tool_actions.py index 310fdc2592..ac3a838312 100644 --- a/src/agents/run_internal/tool_actions.py +++ b/src/agents/run_internal/tool_actions.py @@ -26,7 +26,10 @@ from ..run_context import RunContextWrapper from ..tool import ( ApplyPatchTool, + ApplyPatchToolCustomDataContext, + ComputerToolCustomDataContext, CustomTool, + CustomToolCustomDataContext, LocalShellCommandRequest, ShellCommandRequest, ShellResult, @@ -36,6 +39,7 @@ from ..tracing import SpanError from ..util import _coro from ..util._approvals import evaluate_needs_approval_setting +from ..util._custom_data import maybe_extract_custom_data from .items import apply_patch_rejection_item, shell_rejection_item from .tool_execution import ( coerce_apply_patch_operations, @@ -150,6 +154,27 @@ async def _run_action(span: Any | None) -> RunItem: logger.error("Failed to execute computer action: %s", exc, exc_info=True) raise + image_url = f"data:image/png;base64,{output}" if output else "" + raw_item = ComputerCallOutput( + call_id=action.tool_call.call_id, + output={ + "type": "computer_screenshot", + "image_url": image_url, + }, + type="computer_call_output", + acknowledged_safety_checks=acknowledged_safety_checks, + ) + custom_data = await maybe_extract_custom_data( + action.computer_tool.custom_data_extractor, + ComputerToolCustomDataContext( + run_context=context_wrapper, + tool=action.computer_tool, + tool_call=action.tool_call, + output=image_url, + raw_item=raw_item, + ), + ) + await asyncio.gather( hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output), ( @@ -159,22 +184,14 @@ async def _run_action(span: Any | None) -> RunItem: ), ) - image_url = f"data:image/png;base64,{output}" if output else "" if span and config.trace_include_sensitive_data: span.span_data.output = image_url return ToolCallOutputItem( agent=agent, output=image_url, - raw_item=ComputerCallOutput( - call_id=action.tool_call.call_id, - output={ - "type": "computer_screenshot", - "image_url": image_url, - }, - type="computer_call_output", - acknowledged_safety_checks=acknowledged_safety_checks, - ), + raw_item=raw_item, + custom_data=custom_data, ) return await with_tool_function_span( @@ -686,6 +703,18 @@ async def _run_call(span: Any | None) -> RunItem: ) logger.error("Custom tool failed: %s", exc, exc_info=True) + raw_item = cls._raw_tool_output_item(call_id, output_text) + custom_data = await maybe_extract_custom_data( + custom_tool.custom_data_extractor, + CustomToolCustomDataContext( + tool_context=tool_context, + tool=custom_tool, + input=tool_input, + output=output_text, + raw_item=raw_item, + ), + ) + await asyncio.gather( hooks.on_tool_end(tool_context, agent, custom_tool, output_text), ( @@ -697,8 +726,13 @@ async def _run_call(span: Any | None) -> RunItem: if span and config.trace_include_sensitive_data: span.span_data.output = output_text - - return cls._tool_output_item(agent, call_id, output_text) + return cls._tool_output_item( + agent, + call_id, + output_text, + raw_item=raw_item, + custom_data=custom_data, + ) return await with_tool_function_span( config=config, @@ -711,18 +745,28 @@ def _normalize_output(output: Any) -> str: return output if isinstance(output, str) else str(output) @staticmethod - def _tool_output_item(agent: Agent[Any], call_id: str, output: str) -> ToolCallOutputItem: + def _raw_tool_output_item(call_id: str, output: str) -> dict[str, Any]: + return { + "type": "custom_tool_call_output", + "call_id": call_id, + "output": output, + } + + @classmethod + def _tool_output_item( + cls, + agent: Agent[Any], + call_id: str, + output: str, + *, + raw_item: dict[str, Any] | None = None, + custom_data: dict[str, Any] | None = None, + ) -> ToolCallOutputItem: return ToolCallOutputItem( agent=agent, output=output, - raw_item=cast( - Any, - { - "type": "custom_tool_call_output", - "call_id": call_id, - "output": output, - }, - ), + raw_item=cast(Any, raw_item or cls._raw_tool_output_item(call_id, output)), + custom_data=custom_data, ) @@ -853,6 +897,26 @@ async def _run_call(span: Any | None) -> RunItem: ) logger.error("Apply patch editor failed: %s", exc, exc_info=True) + raw_item: dict[str, Any] = { + "type": "apply_patch_call_output", + "call_id": call_id, + "status": status, + } + if output_text: + raw_item["output"] = output_text + + custom_data = await maybe_extract_custom_data( + apply_patch_tool.custom_data_extractor, + ApplyPatchToolCustomDataContext( + run_context=context_wrapper, + tool=apply_patch_tool, + operations=operations, + output=output_text, + status=status, + raw_item=raw_item, + ), + ) + await asyncio.gather( hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text), ( @@ -862,14 +926,6 @@ async def _run_call(span: Any | None) -> RunItem: ), ) - raw_item: dict[str, Any] = { - "type": "apply_patch_call_output", - "call_id": call_id, - "status": status, - } - if output_text: - raw_item["output"] = output_text - if span and config.trace_include_sensitive_data: span.span_data.output = output_text @@ -877,6 +933,7 @@ async def _run_call(span: Any | None) -> RunItem: agent=agent, output=output_text, raw_item=raw_item, + custom_data=custom_data, ) return await with_tool_function_span( diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index 8f30e4a01f..0ed4fa697e 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -65,6 +65,7 @@ ComputerTool, ComputerToolSafetyCheckData, FunctionTool, + FunctionToolCustomDataContext, FunctionToolResult, ShellActionRequest, ShellCallData, @@ -87,6 +88,7 @@ from ..tracing import Span, SpanError, function_span, get_current_trace from ..util import _coro, _error_tracing from ..util._approvals import evaluate_needs_approval_setting +from ..util._custom_data import maybe_extract_custom_data, merge_custom_data from ..util._tool_errors import get_trace_tool_error from ..util._types import MaybeAwaitable from ._asyncio_progress import get_function_tool_task_progress_deadline @@ -1380,6 +1382,7 @@ def __init__( self.task_states: dict[asyncio.Task[Any], _FunctionToolTaskState] = {} self.teardown_cancelled_tasks: set[asyncio.Task[Any]] = set() self.results_by_tool_run: dict[int, Any] = {} + self.custom_data_by_tool_run: dict[int, dict[str, Any]] = {} self.pending_tasks: set[asyncio.Task[Any]] = set() self.propagating_failure: BaseException | None = None self.available_function_tools: list[FunctionTool] = [] @@ -1791,6 +1794,19 @@ async def _invoke_tool_and_run_post_invoke( real_result=real_result, tool_output_guardrail_results=self.tool_output_guardrail_results, ) + raw_output_item = ItemHelpers.tool_call_output_item(tool_call, final_result) + extracted_custom_data = await maybe_extract_custom_data( + func_tool.custom_data_extractor, + FunctionToolCustomDataContext( + tool_context=tool_context, + tool=func_tool, + output=final_result, + raw_item=raw_output_item, + ), + ) + custom_data = merge_custom_data(tool_context._custom_data, extracted_custom_data) + if custom_data: + self.custom_data_by_tool_run[id(task_state.tool_run)] = custom_data await asyncio.gather( self.hooks.on_tool_end(tool_context, self.public_agent, func_tool, final_result), @@ -1898,6 +1914,7 @@ def _build_function_tool_results(self) -> list[FunctionToolResult]: raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), agent=self.public_agent, tool_origin=get_function_tool_origin(tool_run.function_tool), + custom_data=self.custom_data_by_tool_run.get(id(tool_run)), ) else: # Skip tool output until nested interruptions are resolved. diff --git a/src/agents/run_state.py b/src/agents/run_state.py index c5bb8c9faf..10ad571976 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -128,7 +128,7 @@ # 3. to_json() always emits CURRENT_SCHEMA_VERSION. # 4. Forward compatibility is intentionally fail-fast (older SDKs reject newer or unsupported # versions). -CURRENT_SCHEMA_VERSION = "1.10" +CURRENT_SCHEMA_VERSION = "1.11" # Keep this mapping in chronological order. Every schema bump must add a one-line summary here. SCHEMA_VERSION_SUMMARIES: dict[str, str] = { "1.0": "Initial RunState snapshot format for HITL pause/resume flows.", @@ -145,6 +145,7 @@ "1.8": "Persists SDK-generated prompt cache keys across resume flows.", "1.9": "Persists pending custom tool calls and tool origin metadata across resume flows.", "1.10": "Allows serialized RunState snapshots to disable max_turns with null.", + "1.11": "Persists SDK-only custom data on tool output items across resume flows.", } SUPPORTED_SCHEMA_VERSIONS = frozenset(SCHEMA_VERSION_SUMMARIES) @@ -908,6 +909,9 @@ def _serialize_item( tool_origin = getattr(item, "tool_origin", None) if isinstance(tool_origin, ToolOrigin): result["tool_origin"] = tool_origin.to_json_dict() + custom_data = getattr(item, "custom_data", None) + if isinstance(custom_data, dict) and custom_data: + result["custom_data"] = _ensure_json_compatible(custom_data) return result @@ -3192,12 +3196,19 @@ def _resolve_agent_info( raw_item_output = _deserialize_tool_call_output_raw_item(normalized_raw_item) if raw_item_output is None: continue + stored_custom_data = item_data.get("custom_data") + custom_data = ( + stored_custom_data + if isinstance(stored_custom_data, dict) and stored_custom_data + else None + ) result.append( ToolCallOutputItem( agent=agent, raw_item=raw_item_output, output=item_data.get("output", ""), tool_origin=_deserialize_tool_origin(item_data.get("tool_origin")), + custom_data=custom_data, ) ) diff --git a/src/agents/tool.py b/src/agents/tool.py index 42c41397cb..07a48b44d9 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -80,6 +80,104 @@ | ToolFunctionWithToolContext[ToolParams] ) + +@dataclass(frozen=True) +class FunctionToolCustomDataContext: + """Context passed to function-tool custom data extractors.""" + + tool_context: ToolContext[Any] + """The tool invocation context.""" + + tool: FunctionTool + """The function tool that was invoked.""" + + output: Any + """The model-visible tool output.""" + + raw_item: Mapping[str, Any] + """The raw tool output item that will be replayed to the model.""" + + +@dataclass(frozen=True) +class CustomToolCustomDataContext: + """Context passed to custom-tool custom data extractors.""" + + tool_context: ToolContext[Any] + """The tool invocation context.""" + + tool: CustomTool + """The custom tool that was invoked.""" + + input: str + """The raw model-provided custom tool input.""" + + output: str + """The model-visible custom tool output.""" + + raw_item: Mapping[str, Any] + """The raw custom tool output item that will be replayed to the model.""" + + +@dataclass(frozen=True) +class ComputerToolCustomDataContext: + """Context passed to computer-tool custom data extractors.""" + + run_context: RunContextWrapper[Any] + """The current run context.""" + + tool: ComputerTool[Any] + """The computer tool that was invoked.""" + + tool_call: ResponseComputerToolCall + """The computer tool call produced by the model.""" + + output: str + """The screenshot data URL returned to the model.""" + + raw_item: Any + """The raw computer call output item that will be replayed to the model.""" + + +@dataclass(frozen=True) +class ApplyPatchToolCustomDataContext: + """Context passed to apply-patch custom data extractors.""" + + run_context: RunContextWrapper[Any] + """The current run context.""" + + tool: ApplyPatchTool + """The apply_patch tool that was invoked.""" + + operations: list[ApplyPatchOperation] + """The patch operations requested by the model.""" + + output: str + """The model-visible apply_patch output.""" + + status: Literal["completed", "failed"] + """The serialized apply_patch output status.""" + + raw_item: Mapping[str, Any] + """The raw apply_patch output item that will be replayed to the model.""" + + +FunctionToolCustomDataExtractor = Callable[ + [FunctionToolCustomDataContext], + MaybeAwaitable[Mapping[str, Any] | None], +] +CustomToolCustomDataExtractor = Callable[ + [CustomToolCustomDataContext], + MaybeAwaitable[Mapping[str, Any] | None], +] +ComputerToolCustomDataExtractor = Callable[ + [ComputerToolCustomDataContext], + MaybeAwaitable[Mapping[str, Any] | None], +] +ApplyPatchToolCustomDataExtractor = Callable[ + [ApplyPatchToolCustomDataContext], + MaybeAwaitable[Mapping[str, Any] | None], +] + DEFAULT_APPROVAL_REJECTION_MESSAGE = "Tool execution was not approved." ToolTimeoutBehavior = Literal["error_as_result", "raise_exception"] ToolErrorFunction = Callable[[RunContextWrapper[Any], Exception], MaybeAwaitable[str]] @@ -351,6 +449,12 @@ class FunctionTool: defer_loading: bool = False """Whether the Responses API should hide this tool definition until tool search loads it.""" + custom_data_extractor: FunctionToolCustomDataExtractor | None = field( + default=None, + kw_only=True, + ) + """Optional callback that attaches SDK-only custom data to the tool output item.""" + _failure_error_function: ToolErrorFunction | None = field( default=None, kw_only=True, @@ -511,6 +615,7 @@ def _build_wrapped_function_tool( timeout_behavior: ToolTimeoutBehavior = "error_as_result", timeout_error_function: ToolErrorFunction | None = None, defer_loading: bool = False, + custom_data_extractor: FunctionToolCustomDataExtractor | None = None, sync_invoker: bool = False, mcp_title: str | None = None, tool_origin: ToolOrigin | None = None, @@ -538,6 +643,7 @@ def _build_wrapped_function_tool( timeout_behavior=timeout_behavior, timeout_error_function=timeout_error_function, defer_loading=defer_loading, + custom_data_extractor=custom_data_extractor, _mcp_title=mcp_title, _tool_origin=tool_origin, ), @@ -615,6 +721,12 @@ class ComputerTool(Generic[ComputerT]): on_safety_check: Callable[[ComputerToolSafetyCheckData], MaybeAwaitable[bool]] | None = None """Optional callback to acknowledge computer tool safety checks.""" + custom_data_extractor: ComputerToolCustomDataExtractor | None = field( + default=None, + kw_only=True, + ) + """Optional callback that attaches SDK-only custom data to the tool output item.""" + def __post_init__(self) -> None: _store_computer_initializer(self) @@ -1166,6 +1278,12 @@ class ApplyPatchTool: If provided, it will be invoked immediately when an approval is needed. """ + custom_data_extractor: ApplyPatchToolCustomDataExtractor | None = field( + default=None, + kw_only=True, + ) + """Optional callback that attaches SDK-only custom data to the tool output item.""" + @property def type(self) -> str: return "apply_patch" @@ -1184,6 +1302,11 @@ class CustomTool: on_approval: CustomToolOnApprovalFunction | None = None """Optional handler to auto-approve or reject when approval is required.""" defer_loading: bool = False + custom_data_extractor: CustomToolCustomDataExtractor | None = field( + default=None, + kw_only=True, + ) + """Optional callback that attaches SDK-only custom data to the tool output item.""" tool_config: CustomToolParam = field(init=False, repr=False) @@ -1743,6 +1866,7 @@ def function_tool( timeout_behavior: ToolTimeoutBehavior = "error_as_result", timeout_error_function: ToolErrorFunction | None = None, defer_loading: bool = False, + custom_data_extractor: FunctionToolCustomDataExtractor | None = None, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -1766,6 +1890,7 @@ def function_tool( timeout_behavior: ToolTimeoutBehavior = "error_as_result", timeout_error_function: ToolErrorFunction | None = None, defer_loading: bool = False, + custom_data_extractor: FunctionToolCustomDataExtractor | None = None, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -1789,6 +1914,7 @@ def function_tool( timeout_behavior: ToolTimeoutBehavior = "error_as_result", timeout_error_function: ToolErrorFunction | None = None, defer_loading: bool = False, + custom_data_extractor: FunctionToolCustomDataExtractor | None = None, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -1834,6 +1960,8 @@ def function_tool( timeout_behavior="error_as_result". defer_loading: Whether to hide this tool definition until Responses API tool search explicitly loads it. + custom_data_extractor: Optional callback that returns SDK-only custom data to attach to + the emitted ``ToolCallOutputItem``. The returned mapping is not sent to the model. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -1904,6 +2032,7 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: timeout_behavior=timeout_behavior, timeout_error_function=timeout_error_function, defer_loading=defer_loading, + custom_data_extractor=custom_data_extractor, sync_invoker=is_sync_function_tool, ) return function_tool diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index eaad0cc167..75947630cf 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -103,6 +103,8 @@ def __init__( ) self.agent = agent self.run_config = run_config + # Internal adapter hook used to attach SDK-only custom data to the emitted output item. + self._custom_data: dict[str, Any] | None = None @property def qualified_tool_name(self) -> str: diff --git a/src/agents/util/_custom_data.py b/src/agents/util/_custom_data.py new file mode 100644 index 0000000000..4ddc140c6d --- /dev/null +++ b/src/agents/util/_custom_data.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import copy +import inspect +import json +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, TypeVar, cast + +from ..exceptions import UserError + +TContext = TypeVar("TContext") + +CustomDataExtractor = Callable[ + [TContext], Awaitable[Mapping[str, Any] | None] | Mapping[str, Any] | None +] + + +def normalize_custom_data(value: Mapping[str, Any] | None) -> dict[str, Any] | None: + """Return a JSON-compatible copy of custom tool-output data.""" + if value is None: + return None + if not isinstance(value, Mapping): + raise UserError("custom_data_extractor must return a mapping or None.") + if not value: + return None + if not all(isinstance(key, str) for key in value): + raise UserError("custom_data_extractor must return a mapping with string keys.") + + copied = copy.deepcopy(dict(value)) + try: + return cast(dict[str, Any], json.loads(json.dumps(copied))) + except (TypeError, ValueError) as exc: + raise UserError("custom_data_extractor must return JSON-compatible data.") from exc + + +async def maybe_extract_custom_data( + extractor: CustomDataExtractor[TContext] | None, + context: TContext, +) -> dict[str, Any] | None: + """Invoke a sync or async custom-data extractor and normalize its result.""" + if extractor is None: + return None + + result = extractor(context) + if inspect.isawaitable(result): + result = await result + return normalize_custom_data(result) + + +def merge_custom_data(*values: Mapping[str, Any] | None) -> dict[str, Any] | None: + """Merge optional custom-data mappings, with later mappings taking precedence.""" + merged: dict[str, Any] = {} + for value in values: + normalized = normalize_custom_data(value) + if normalized: + merged.update(normalized) + return merged or None diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index ef820fad99..59a5b9a8f9 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -20,7 +20,7 @@ from agents.mcp import MCPServer from agents.mcp.server import _UNSET, _MCPServerWithClientSession, _UnsetType -from agents.mcp.util import MCPToolMetaResolver, ToolFilter +from agents.mcp.util import MCPToolCustomDataExtractor, MCPToolMetaResolver, ToolFilter from agents.tool import ToolErrorFunction tee = shutil.which("tee") or "" @@ -76,12 +76,14 @@ def __init__( require_approval: object | None = None, failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET, tool_meta_resolver: MCPToolMetaResolver | None = None, + custom_data_extractor: MCPToolCustomDataExtractor | None = None, ): super().__init__( use_structured_content=False, require_approval=require_approval, # type: ignore[arg-type] failure_error_function=failure_error_function, tool_meta_resolver=tool_meta_resolver, + custom_data_extractor=custom_data_extractor, ) self.tools: list[MCPTool] = tools or [] self.tool_calls: list[str] = [] @@ -90,6 +92,7 @@ def __init__( self.tool_filter = tool_filter self._server_name = server_name self._custom_content: list[Content] | None = None + self._response_meta: dict[str, Any] | None = None def add_tool(self, name: str, input_schema: dict[str, Any]): self.tools.append(MCPTool(name=name, inputSchema=input_schema)) @@ -127,6 +130,7 @@ async def call_tool( return CallToolResult( content=[TextContent(text=self.tool_results[-1], type="text")], + _meta=self._response_meta, ) async def list_prompts(self, run_context=None, agent=None) -> ListPromptsResult: diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 7b2de6b859..0911ceb43d 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -1921,6 +1921,37 @@ async def test_deserializes_custom_tool_call_output_items(self): assert restored_item.raw_item == custom_tool_output assert restored_item.output == "custom result" + async def test_deserializes_tool_call_output_custom_data(self): + """SDK-only tool output custom data should survive RunState roundtrips.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="ItemAgent") + state = make_state(agent, context=context, original_input="test", max_turns=5) + + raw_tool_output = { + "type": "function_call_output", + "call_id": "call_custom_data", + "output": "result", + } + state._generated_items.append( + ToolCallOutputItem( + agent=agent, + raw_item=raw_tool_output, + output="result", + custom_data={"ui": {"kind": "chart"}, "ids": ["a", "b"]}, + ) + ) + + json_data = state.to_json() + serialized_item = json_data["generated_items"][0] + assert serialized_item["custom_data"] == {"ui": {"kind": "chart"}, "ids": ["a", "b"]} + assert "custom_data" not in serialized_item["raw_item"] + + new_state = await RunState.from_json(agent, json_data) + + restored_item = new_state._generated_items[0] + assert isinstance(restored_item, ToolCallOutputItem) + assert restored_item.custom_data == {"ui": {"kind": "chart"}, "ids": ["a", "b"]} + async def test_serializes_original_input_with_function_call_output(self): """Test that original_input with function_call_output items is preserved.""" context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) @@ -4636,6 +4667,7 @@ def test_supported_schema_versions_match_released_boundary(self): "1.7", "1.8", "1.9", + "1.10", CURRENT_SCHEMA_VERSION, } ) diff --git a/tests/test_tool_custom_data.py b/tests/test_tool_custom_data.py new file mode 100644 index 0000000000..0043674e02 --- /dev/null +++ b/tests/test_tool_custom_data.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +from typing import Any, cast + +import pytest +from openai.types.responses import ResponseCustomToolCall +from openai.types.responses.response_computer_tool_call import ( + ActionScreenshot, + ResponseComputerToolCall, +) + +from agents import ( + Agent, + ApplyPatchTool, + Computer, + ComputerTool, + CustomTool, + RunConfig, + RunContextWrapper, + RunHooks, + Runner, + UserError, + function_tool, +) +from agents.editor import ApplyPatchOperation, ApplyPatchResult +from agents.items import ToolCallOutputItem +from agents.run_internal.run_loop import ( + ToolRunApplyPatchCall, + ToolRunComputerAction, +) +from agents.run_internal.run_steps import ToolRunCustom +from agents.run_internal.tool_actions import ( + ApplyPatchAction, + ComputerAction, + CustomToolAction, +) +from agents.tool_context import ToolContext + +from .fake_model import FakeModel +from .mcp.helpers import FakeMCPServer +from .test_apply_patch_tool import DummyApplyPatchCall +from .test_responses import get_function_tool_call, get_text_message + + +def _tool_output_items(items: list[Any]) -> list[ToolCallOutputItem]: + return [item for item in items if isinstance(item, ToolCallOutputItem)] + + +@pytest.mark.asyncio +async def test_function_tool_custom_data_is_attached_but_not_replayed() -> None: + def extract_custom_data(ctx: Any) -> dict[str, Any]: + return {"call_id": ctx.raw_item["call_id"], "output": ctx.output} + + @function_tool(custom_data_extractor=extract_custom_data) + def get_data() -> str: + return "tool_result" + + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_text_message("call tool"), get_function_tool_call("get_data", "{}")], + [get_text_message("done")], + ] + ) + agent = Agent(name="test", model=model, tools=[get_data]) + + result = await Runner.run(agent, input="user") + + tool_output = _tool_output_items(result.new_items)[0] + assert tool_output.custom_data == {"call_id": "2", "output": "tool_result"} + replay_payload = tool_output.to_input_item() + assert isinstance(replay_payload, dict) + assert "custom_data" not in replay_payload + assert all( + not (isinstance(item, dict) and "custom_data" in item) + for item in model.last_turn_args["input"] + ) + + +@pytest.mark.asyncio +async def test_function_tool_custom_data_rejects_non_json_compatible_data() -> None: + @function_tool(custom_data_extractor=lambda _ctx: {"bad": object()}) + def get_data() -> str: + return "tool_result" + + model = FakeModel() + model.add_multiple_turn_outputs( + [[get_text_message("call tool"), get_function_tool_call("get_data", "{}")]] + ) + agent = Agent(name="test", model=model, tools=[get_data]) + + with pytest.raises(UserError, match="custom_data_extractor must return JSON-compatible data"): + await Runner.run(agent, input="user") + + +@pytest.mark.asyncio +async def test_mcp_custom_data_extractor_maps_result_meta_to_tool_output_item() -> None: + def extract_custom_data(ctx: Any) -> dict[str, Any]: + return {"mcp_response_meta": dict(ctx.result_meta or {})} + + server = FakeMCPServer(custom_data_extractor=extract_custom_data) + server.add_tool("meta_tool", {}) + server._response_meta = {"chart": {"type": "line"}} + + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_text_message("call tool"), get_function_tool_call("meta_tool", "{}")], + [get_text_message("done")], + ] + ) + agent = Agent(name="test", model=model, mcp_servers=[server]) + + result = await Runner.run(agent, input="user") + + tool_output = _tool_output_items(result.new_items)[0] + assert tool_output.custom_data == {"mcp_response_meta": {"chart": {"type": "line"}}} + + +@pytest.mark.asyncio +async def test_custom_tool_custom_data_is_attached() -> None: + async def invoke(_ctx: ToolContext[Any], raw_input: str) -> str: + return raw_input.upper() + + tool = CustomTool( + name="raw_editor", + description="Edit raw text.", + on_invoke_tool=invoke, + format={"type": "text"}, + custom_data_extractor=lambda ctx: {"input": ctx.input, "output": ctx.output}, + ) + agent = Agent(name="custom-agent", tools=[tool]) + tool_call = ResponseCustomToolCall( + type="custom_tool_call", + name="raw_editor", + call_id="call_custom", + input="hello", + ) + + result = await CustomToolAction.execute( + agent=agent, + call=ToolRunCustom(tool_call=tool_call, custom_tool=tool), + hooks=RunHooks[Any](), + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + + assert isinstance(result, ToolCallOutputItem) + assert result.custom_data == {"input": "hello", "output": "HELLO"} + + +class ScreenshotComputer(Computer): + def screenshot(self) -> str: + return "base64png" + + def click(self, x: int, y: int, button: str) -> None: + pass + + def double_click(self, x: int, y: int) -> None: + pass + + def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None: + pass + + def type(self, text: str) -> None: + pass + + def wait(self) -> None: + pass + + def move(self, x: int, y: int) -> None: + pass + + def keypress(self, keys: list[str]) -> None: + pass + + def drag(self, path: list[tuple[int, int]]) -> None: + pass + + +@pytest.mark.asyncio +async def test_computer_tool_custom_data_is_attached() -> None: + computer_tool = ComputerTool( + computer=ScreenshotComputer(), + custom_data_extractor=lambda ctx: {"call_id": ctx.tool_call.call_id}, + ) + tool_call = ResponseComputerToolCall( + id="computer_1", + type="computer_call", + action=ActionScreenshot(type="screenshot"), + call_id="call_computer", + pending_safety_checks=[], + status="completed", + ) + agent = Agent(name="computer-agent", tools=[computer_tool]) + + result = await ComputerAction.execute( + agent=agent, + action=ToolRunComputerAction(tool_call=tool_call, computer_tool=computer_tool), + hooks=RunHooks[Any](), + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + + assert isinstance(result, ToolCallOutputItem) + assert result.custom_data == {"call_id": "call_computer"} + + +class RecordingEditor: + def update_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + return ApplyPatchResult(output=f"Updated {operation.path}") + + def create_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + return ApplyPatchResult(output=f"Created {operation.path}") + + def delete_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + return ApplyPatchResult(output=f"Deleted {operation.path}") + + +@pytest.mark.asyncio +async def test_apply_patch_tool_custom_data_is_attached() -> None: + tool = ApplyPatchTool( + editor=RecordingEditor(), + custom_data_extractor=lambda ctx: { + "status": ctx.status, + "paths": [operation.path for operation in ctx.operations], + }, + ) + call = DummyApplyPatchCall( + type="apply_patch_call", + call_id="call_patch", + operation={"type": "update_file", "path": "tasks.md", "diff": "-a\n+b\n"}, + ) + agent = Agent(name="patch-agent", tools=[tool]) + + result = await ApplyPatchAction.execute( + agent=agent, + call=ToolRunApplyPatchCall(tool_call=call, apply_patch_tool=tool), + hooks=RunHooks[Any](), + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + + assert isinstance(result, ToolCallOutputItem) + assert result.custom_data == {"status": "completed", "paths": ["tasks.md"]} + assert "custom_data" not in cast(dict[str, Any], result.to_input_item()) From 2952e4752f4ef60e73936ebd9b85b83dffa8ca9f Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 22 May 2026 09:41:02 +0900 Subject: [PATCH 2/2] fix review comments --- src/agents/run_internal/tool_actions.py | 7 +++-- src/agents/run_internal/tool_execution.py | 3 +- tests/test_tool_custom_data.py | 38 ++++++++++++++++++----- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/src/agents/run_internal/tool_actions.py b/src/agents/run_internal/tool_actions.py index ac3a838312..611cb21b66 100644 --- a/src/agents/run_internal/tool_actions.py +++ b/src/agents/run_internal/tool_actions.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +import copy import dataclasses import inspect import json @@ -171,7 +172,7 @@ async def _run_action(span: Any | None) -> RunItem: tool=action.computer_tool, tool_call=action.tool_call, output=image_url, - raw_item=raw_item, + raw_item=copy.deepcopy(raw_item), ), ) @@ -711,7 +712,7 @@ async def _run_call(span: Any | None) -> RunItem: tool=custom_tool, input=tool_input, output=output_text, - raw_item=raw_item, + raw_item=copy.deepcopy(raw_item), ), ) @@ -913,7 +914,7 @@ async def _run_call(span: Any | None) -> RunItem: operations=operations, output=output_text, status=status, - raw_item=raw_item, + raw_item=copy.deepcopy(raw_item), ), ) diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index 0ed4fa697e..64fecb91d9 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +import copy import dataclasses import functools import inspect @@ -1801,7 +1802,7 @@ async def _invoke_tool_and_run_post_invoke( tool_context=tool_context, tool=func_tool, output=final_result, - raw_item=raw_output_item, + raw_item=copy.deepcopy(raw_output_item), ), ) custom_data = merge_custom_data(tool_context._custom_data, extracted_custom_data) diff --git a/tests/test_tool_custom_data.py b/tests/test_tool_custom_data.py index 0043674e02..9fa9cd1a6c 100644 --- a/tests/test_tool_custom_data.py +++ b/tests/test_tool_custom_data.py @@ -49,6 +49,7 @@ def _tool_output_items(items: list[Any]) -> list[ToolCallOutputItem]: @pytest.mark.asyncio async def test_function_tool_custom_data_is_attached_but_not_replayed() -> None: def extract_custom_data(ctx: Any) -> dict[str, Any]: + ctx.raw_item["renderer"] = "chart" return {"call_id": ctx.raw_item["call_id"], "output": ctx.output} @function_tool(custom_data_extractor=extract_custom_data) @@ -71,6 +72,8 @@ def get_data() -> str: replay_payload = tool_output.to_input_item() assert isinstance(replay_payload, dict) assert "custom_data" not in replay_payload + assert "renderer" not in replay_payload + assert "renderer" not in cast(dict[str, Any], tool_output.raw_item) assert all( not (isinstance(item, dict) and "custom_data" in item) for item in model.last_turn_args["input"] @@ -122,12 +125,16 @@ async def test_custom_tool_custom_data_is_attached() -> None: async def invoke(_ctx: ToolContext[Any], raw_input: str) -> str: return raw_input.upper() + def extract_custom_data(ctx: Any) -> dict[str, Any]: + ctx.raw_item["renderer"] = "chart" + return {"input": ctx.input, "output": ctx.output} + tool = CustomTool( name="raw_editor", description="Edit raw text.", on_invoke_tool=invoke, format={"type": "text"}, - custom_data_extractor=lambda ctx: {"input": ctx.input, "output": ctx.output}, + custom_data_extractor=extract_custom_data, ) agent = Agent(name="custom-agent", tools=[tool]) tool_call = ResponseCustomToolCall( @@ -147,6 +154,7 @@ async def invoke(_ctx: ToolContext[Any], raw_input: str) -> str: assert isinstance(result, ToolCallOutputItem) assert result.custom_data == {"input": "hello", "output": "HELLO"} + assert "renderer" not in cast(dict[str, Any], result.raw_item) class ScreenshotComputer(Computer): @@ -180,9 +188,13 @@ def drag(self, path: list[tuple[int, int]]) -> None: @pytest.mark.asyncio async def test_computer_tool_custom_data_is_attached() -> None: + def extract_custom_data(ctx: Any) -> dict[str, Any]: + ctx.raw_item["output"]["image_url"] = "mutated" + return {"call_id": ctx.tool_call.call_id} + computer_tool = ComputerTool( computer=ScreenshotComputer(), - custom_data_extractor=lambda ctx: {"call_id": ctx.tool_call.call_id}, + custom_data_extractor=extract_custom_data, ) tool_call = ResponseComputerToolCall( id="computer_1", @@ -204,6 +216,10 @@ async def test_computer_tool_custom_data_is_attached() -> None: assert isinstance(result, ToolCallOutputItem) assert result.custom_data == {"call_id": "call_computer"} + assert ( + cast(dict[str, Any], result.raw_item)["output"]["image_url"] + == "data:image/png;base64,base64png" + ) class RecordingEditor: @@ -219,12 +235,17 @@ def delete_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: @pytest.mark.asyncio async def test_apply_patch_tool_custom_data_is_attached() -> None: - tool = ApplyPatchTool( - editor=RecordingEditor(), - custom_data_extractor=lambda ctx: { + def extract_custom_data(ctx: Any) -> dict[str, Any]: + ctx.raw_item["status"] = "failed" + ctx.raw_item["renderer"] = "patch" + return { "status": ctx.status, "paths": [operation.path for operation in ctx.operations], - }, + } + + tool = ApplyPatchTool( + editor=RecordingEditor(), + custom_data_extractor=extract_custom_data, ) call = DummyApplyPatchCall( type="apply_patch_call", @@ -243,4 +264,7 @@ async def test_apply_patch_tool_custom_data_is_attached() -> None: assert isinstance(result, ToolCallOutputItem) assert result.custom_data == {"status": "completed", "paths": ["tasks.md"]} - assert "custom_data" not in cast(dict[str, Any], result.to_input_item()) + replay_payload = cast(dict[str, Any], result.to_input_item()) + assert "custom_data" not in replay_payload + assert "renderer" not in replay_payload + assert replay_payload["status"] == "completed"