Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,25 @@ for tool in agent.tools:
3. Functions can optionally take the `context` (must be the first argument). You can also set overrides, like the name of the tool, description, which docstring style to use, etc.
4. You can pass the decorated functions to the list of tools.

You can also decorate instance methods. Access the tool from an instance before passing it to
`Agent.tools`; the implicit `self` parameter is bound to that instance and omitted from the tool
schema.

```python
class CustomerTools:
def __init__(self, tenant_id: str) -> None:
self.tenant_id = tenant_id

@function_tool
def lookup_customer(self, customer_id: str) -> str:
"""Look up a customer by ID."""
return f"{self.tenant_id}:{customer_id}"


customer_tools = CustomerTools("tenant_123")
agent = Agent(name="Assistant", tools=[customer_tools.lookup_customer])
```

??? note "Expand to see output"

```
Expand Down
21 changes: 18 additions & 3 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class FuncSchema:
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
omitted_parameter_names: tuple[str, ...] = ()
"""Parameter names that are supplied by the SDK instead of model-generated JSON."""

def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
Expand All @@ -52,6 +54,8 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:

# Use enumerate() so we can skip the first parameter if it's context.
for idx, (name, param) in enumerate(self.signature.parameters.items()):
if name in self.omitted_parameter_names:
continue
# If the function takes a RunContextWrapper and this is the first parameter, skip it.
if self.takes_context and idx == 0:
continue
Expand Down Expand Up @@ -228,6 +232,7 @@ def function_schema(
description_override: str | None = None,
use_docstring_info: bool = True,
strict_json_schema: bool = True,
skip_first_parameter: bool = False,
) -> FuncSchema:
"""
Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
Expand All @@ -246,6 +251,8 @@ def function_schema(
the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
recommend setting this to True, as it increases the likelihood of the LLM producing
correct JSON input.
skip_first_parameter: If True, omit the first signature parameter from the tool schema and
call arguments. This is used for instance methods decorated with `@function_tool`.

Returns:
A `FuncSchema` object containing the function's name, description, parameter descriptions,
Expand Down Expand Up @@ -288,22 +295,29 @@ def function_schema(
params = list(sig.parameters.items())
takes_context = False
filtered_params = []
omitted_parameter_names: list[str] = []

params_to_check = params
if skip_first_parameter and params:
omitted_parameter_names.append(params[0][0])
params_to_check = params[1:]

if params:
first_name, first_param = params[0]
if params_to_check:
first_name, first_param = params_to_check[0]
# Prefer the evaluated type hint if available
ann = type_hints.get(first_name, first_param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper or origin is ToolContext:
takes_context = True # Mark that the function takes context
omitted_parameter_names.append(first_name)
else:
filtered_params.append((first_name, first_param))
else:
filtered_params.append((first_name, first_param))

# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
for name, param in params[1:]:
for name, param in params_to_check[1:]:
ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
Expand Down Expand Up @@ -421,4 +435,5 @@ def function_schema(
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
omitted_parameter_names=tuple(omitted_parameter_names),
)
109 changes: 100 additions & 9 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,60 @@ class FunctionTool:
_emit_tool_origin: bool = field(default=True, kw_only=True, repr=False)
"""Whether runtime item generation should emit tool origin metadata for this tool."""

_method_tool_factory: Callable[[Any], FunctionTool] | None = field(
default=None,
kw_only=True,
repr=False,
)
"""Internal descriptor hook used for instance methods decorated with `@function_tool`."""

_staticmethod_tool_factory: Callable[[], FunctionTool] | None = field(
default=None,
kw_only=True,
repr=False,
)
"""Internal fallback for class-scoped tools wrapped in `staticmethod`."""

_method_tool_bound_to_class: bool = field(default=False, kw_only=True, repr=False)
"""Whether Python installed this tool directly on a class via `__set_name__`."""

def __set_name__(self, owner: type[Any], name: str) -> None:
if self._staticmethod_tool_factory is not None:
self._method_tool_bound_to_class = True

def __getattribute__(self, name: str) -> Any:
if not name.startswith("_") and name not in {"__class__", "__dict__"}:
object.__getattribute__(self, "_maybe_apply_staticmethod_tool")()
return object.__getattribute__(self, name)

def __get__(self, instance: Any, owner: type[Any] | None = None) -> FunctionTool:
if instance is None or self._method_tool_factory is None:
return self
return self._method_tool_factory(instance)

def _maybe_apply_staticmethod_tool(self) -> None:
try:
staticmethod_tool_factory = object.__getattribute__(self, "_staticmethod_tool_factory")
method_tool_bound_to_class = object.__getattribute__(
self, "_method_tool_bound_to_class"
)
except AttributeError:
return

if staticmethod_tool_factory is None or method_tool_bound_to_class:
return

# `staticmethod` does not forward `__set_name__` to the wrapped FunctionTool.
# Rebuild as a normal tool before exposing schema or invocation state.
object.__setattr__(self, "_staticmethod_tool_factory", None)
staticmethod_tool = staticmethod_tool_factory()
for tool_field in dataclasses.fields(FunctionTool):
object.__setattr__(self, tool_field.name, getattr(staticmethod_tool, tool_field.name))

bind_to_function_tool = getattr(self.on_invoke_tool, "__agents_bind_function_tool__", None)
if callable(bind_to_function_tool):
self.on_invoke_tool = bind_to_function_tool(self)

@property
def qualified_name(self) -> str:
"""Return the public qualified name used to identify this function tool."""
Expand Down Expand Up @@ -1836,18 +1890,43 @@ def function_tool(
explicitly loads it.
"""

def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
def _is_instance_method_tool(the_func: ToolFunction[...]) -> bool:
parameters = tuple(inspect.signature(the_func).parameters.values())
if not parameters:
return False

parent_name = the_func.__qualname__.rsplit(".", 1)[0]
return "." in the_func.__qualname__ and not parent_name.endswith("<locals>")
Comment thread
adit24dhaya marked this conversation as resolved.

def _create_function_tool(
the_func: ToolFunction[...],
*,
method_tool_instance: Any | None = None,
treat_as_instance_method: bool | None = None,
enable_method_binding: bool = True,
) -> FunctionTool:
is_sync_function_tool = not inspect.iscoroutinefunction(the_func)
is_instance_method_tool = (
_is_instance_method_tool(the_func)
if treat_as_instance_method is None
else treat_as_instance_method
)
schema = function_schema(
func=the_func,
name_override=name_override,
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
skip_first_parameter=is_instance_method_tool,
)

async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
if is_instance_method_tool and method_tool_instance is None:
raise UserError(
f"Instance method tool {schema.name} must be accessed from an instance"
)

tool_name = ctx.tool_name
json_data = _parse_function_tool_json_input(tool_name=tool_name, input_json=input)
_log_function_tool_invocation(tool_name=tool_name, input_json=input)
Expand All @@ -1866,16 +1945,16 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
if not _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}")

leading_args: list[Any] = []
if is_instance_method_tool:
leading_args.append(method_tool_instance)
if schema.takes_context:
leading_args.append(ctx)

if not is_sync_function_tool:
if schema.takes_context:
result = await the_func(ctx, *args, **kwargs_dict)
else:
result = await the_func(*args, **kwargs_dict)
result = await the_func(*leading_args, *args, **kwargs_dict)
else:
if schema.takes_context:
result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict)
else:
result = await asyncio.to_thread(the_func, *args, **kwargs_dict)
result = await asyncio.to_thread(the_func, *leading_args, *args, **kwargs_dict)

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Tool {tool_name} completed.")
Expand Down Expand Up @@ -1906,6 +1985,18 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
defer_loading=defer_loading,
sync_invoker=is_sync_function_tool,
)
if enable_method_binding and is_instance_method_tool and method_tool_instance is None:
function_tool._method_tool_factory = lambda instance: _create_function_tool(
the_func,
method_tool_instance=instance,
treat_as_instance_method=True,
enable_method_binding=False,
)
function_tool._staticmethod_tool_factory = lambda: _create_function_tool(
the_func,
treat_as_instance_method=False,
enable_method_binding=False,
)
return function_tool

# If func is actually a callable, we were used as @function_tool with no parentheses
Expand Down
139 changes: 139 additions & 0 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,145 @@ async def test_simple_function():
)


@pytest.mark.asyncio
async def test_instance_method_function_tool_binds_self():
class AccountTools:
def __init__(self, prefix: str) -> None:
self.prefix = prefix

@function_tool
def lookup(self, account_id: str) -> str:
"""Look up an account."""
return f"{self.prefix}:{account_id}"

tools = AccountTools("acct")
tool = tools.lookup

assert isinstance(AccountTools.lookup, FunctionTool)
assert tool.name == "lookup"
assert "self" not in tool.params_json_schema["properties"]
assert "account_id" in tool.params_json_schema["properties"]

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""),
'{"account_id": "123"}',
)

assert result == "acct:123"


@pytest.mark.asyncio
async def test_instance_method_function_tool_binds_non_self_receiver_name():
class AccountTools:
def __init__(self, prefix: str) -> None:
self.prefix = prefix

@function_tool
def lookup(this, account_id: str) -> str:
"""Look up an account."""
return f"{this.prefix}:{account_id}"

tools = AccountTools("acct")
tool = tools.lookup

assert "this" not in tool.params_json_schema["properties"]
assert "account_id" in tool.params_json_schema["properties"]

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""),
'{"account_id": "123"}',
)

assert result == "acct:123"


@pytest.mark.asyncio
async def test_function_tool_does_not_treat_self_named_argument_as_method():
def lookup(self: str, account_id: str) -> str:
"""Look up an account."""
return f"{self}:{account_id}"

tool = function_tool(lookup)

assert "self" in tool.params_json_schema["properties"]
assert "account_id" in tool.params_json_schema["properties"]

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""),
'{"self": "acct", "account_id": "123"}',
)

assert result == "acct:123"


@pytest.mark.asyncio
async def test_staticmethod_function_tool_keeps_first_parameter():
class AccountTools:
@staticmethod
@function_tool
def lookup(account_id: str) -> str:
"""Look up an account."""
return f"acct:{account_id}"

tool = AccountTools.lookup

assert isinstance(tool, FunctionTool)
assert "account_id" in tool.params_json_schema["properties"]

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""),
'{"account_id": "123"}',
)

assert result == "acct:123"


@pytest.mark.asyncio
async def test_staticmethod_function_tool_allows_self_named_parameter():
class AccountTools:
@staticmethod
@function_tool
def lookup(self: str, account_id: str) -> str:
"""Look up an account."""
return f"{self}:{account_id}"

tool = AccountTools.lookup

assert isinstance(tool, FunctionTool)
assert "self" in tool.params_json_schema["properties"]
assert "account_id" in tool.params_json_schema["properties"]

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""),
'{"self": "acct", "account_id": "123"}',
)

assert result == "acct:123"


@pytest.mark.asyncio
async def test_instance_method_function_tool_supports_context_after_self():
class AccountTools:
@function_tool
def lookup(self, ctx: ToolContext[str], account_id: str) -> str:
"""Look up an account with context."""
return f"{ctx.context}:{account_id}"

tools = AccountTools()
tool = tools.lookup

assert "self" not in tool.params_json_schema["properties"]
assert "ctx" not in tool.params_json_schema["properties"]
assert "account_id" in tool.params_json_schema["properties"]

result = await tool.on_invoke_tool(
ToolContext("tenant", tool_name=tool.name, tool_call_id="1", tool_arguments=""),
'{"account_id": "123"}',
)

assert result == "tenant:123"


@pytest.mark.asyncio
async def test_sync_function_runs_via_to_thread(monkeypatch: pytest.MonkeyPatch) -> None:
calls = {"to_thread": 0, "func": 0}
Expand Down