Skip to content
3 changes: 3 additions & 0 deletions docs/ref/tool_context.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `Tool context`

::: agents.tool_context
2 changes: 1 addition & 1 deletion docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ Sometimes, you don't want to use a Python function as a tool. You can directly c
- `name`
- `description`
- `params_json_schema`, which is the JSON schema for the arguments
- `on_invoke_tool`, which is an async function that receives the context and the arguments as a JSON string, and must return the tool output as a string.
- `on_invoke_tool`, which is an async function that receives a [`ToolContext`][agents.tool_context.ToolContext] and the arguments as a JSON string, and must return the tool output as a string.

```python
from typing import Any
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ plugins:
- ref/lifecycle.md
- ref/items.md
- ref/run_context.md
- ref/tool_context.md
- ref/usage.md
- ref/exceptions.md
- ref/guardrail.md
Expand Down
6 changes: 5 additions & 1 deletion src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,11 @@ async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
with function_span(func_tool.name) as span_fn:
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
tool_context = ToolContext.from_agent_context(
context_wrapper,
tool_call.call_id,
tool_call=tool_call,
)
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
Expand Down
7 changes: 6 additions & 1 deletion src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,12 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
)

func_tool = function_map[event.name]
tool_context = ToolContext.from_agent_context(self._context_wrapper, event.call_id)
tool_context = ToolContext(
context=self._context_wrapper.context,
usage=self._context_wrapper.usage,
tool_name=event.name,
tool_call_id=event.call_id,
)
result = await func_tool.on_invoke_tool(tool_context, event.arguments)

await self._model.send_tool_output(event, str(result), True)
Expand Down
19 changes: 16 additions & 3 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass, field, fields
from typing import Any
from typing import Any, Optional

from openai.types.responses import ResponseFunctionToolCall

from .run_context import RunContextWrapper, TContext

Expand All @@ -8,16 +10,26 @@ def _assert_must_pass_tool_call_id() -> str:
raise ValueError("tool_call_id must be passed to ToolContext")


def _assert_must_pass_tool_name() -> str:
raise ValueError("tool_name must be passed to ToolContext")


@dataclass
class ToolContext(RunContextWrapper[TContext]):
"""The context of a tool call."""

tool_name: str = field(default_factory=_assert_must_pass_tool_name)
"""The name of the tool being invoked."""

tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
"""The ID of the tool call."""

@classmethod
def from_agent_context(
cls, context: RunContextWrapper[TContext], tool_call_id: str
cls,
context: RunContextWrapper[TContext],
tool_call_id: str,
tool_call: Optional[ResponseFunctionToolCall] = None,
) -> "ToolContext":
"""
Create a ToolContext from a RunContextWrapper.
Expand All @@ -26,4 +38,5 @@ def from_agent_context(
base_values: dict[str, Any] = {
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
}
return cls(tool_call_id=tool_call_id, **base_values)
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
49 changes: 34 additions & 15 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ async def test_argless_function():
tool = function_tool(argless_function)
assert tool.name == "argless_function"

result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
result = await tool.on_invoke_tool(
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
)
assert result == "ok"


Expand All @@ -39,11 +41,13 @@ async def test_argless_with_context():
tool = function_tool(argless_with_context)
assert tool.name == "argless_with_context"

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
assert result == "ok"

# Extra JSON should not raise an error
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
)
assert result == "ok"


Expand All @@ -56,15 +60,19 @@ async def test_simple_function():
tool = function_tool(simple_function, failure_error_function=None)
assert tool.name == "simple_function"

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
)
assert result == 6

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
)
assert result == 3

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")


class Foo(BaseModel):
Expand Down Expand Up @@ -92,7 +100,9 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "6 hello10 hello"

valid_json = json.dumps(
Expand All @@ -101,7 +111,9 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "3 hello10 hello"

valid_json = json.dumps(
Expand All @@ -111,12 +123,16 @@ async def test_complex_args_function():
"baz": "world",
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "3 hello10 world"

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
)


def test_function_config_overrides():
Expand Down Expand Up @@ -176,7 +192,9 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert tool.params_json_schema[key] == value
assert tool.strict_json_schema

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
)
assert result == "hello_done"

tool_not_strict = FunctionTool(
Expand All @@ -191,7 +209,8 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert "additionalProperties" not in tool_not_strict.params_json_schema

result = await tool_not_strict.on_invoke_tool(
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
'{"data": "hello", "bar": "baz"}',
)
assert result == "hello_done"

Expand All @@ -202,7 +221,7 @@ def my_func(a: int, b: int = 5):
raise ValueError("test")

tool = function_tool(my_func)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert "Invalid JSON" in str(result)
Expand All @@ -226,7 +245,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand All @@ -250,7 +269,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_function_tool_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):


def ctx_wrapper() -> ToolContext[DummyContext]:
return ToolContext(context=DummyContext(), tool_call_id="1")
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")


@function_tool
Expand Down