Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -630,12 +630,12 @@ try:
agent.run_sync('Please call the tool twice', usage_limits=UsageLimits(tool_calls_limit=1))
except UsageLimitExceeded as e:
print(e)
#> The next tool call would exceed the tool_calls_limit of 1 (tool_calls=1)
#> The next tool call(s) would exceed the tool_calls_limit of 1 (tool_calls=2).
```

!!! note
- Usage limits are especially relevant if you've registered many tools. Use `request_limit` to bound the number of model turns, and `tool_calls_limit` to cap the number of successful tool executions within a run.
- These limits are enforced at the final stage before the LLM is called. If your limits are stricter than your retry settings, the usage limit will be reached before all retries are attempted.
- The `tool_calls_limit` is checked before executing tool calls. If the model returns parallel tool calls that would exceed the limit, no tools will be executed.

#### Model (Run) Settings

Expand Down
19 changes: 12 additions & 7 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ async def run( # noqa: C901
# Build the run context after `ctx.deps.prompt` has been updated
run_context = build_run_context(ctx)

parts: list[_messages.ModelRequestPart] = []
if messages:
await self._reevaluate_dynamic_prompts(messages, run_context)

Expand Down Expand Up @@ -821,6 +820,7 @@ async def process_tool_calls( # noqa: C901
tool_calls=calls_to_run,
tool_call_results=calls_to_run_results,
tracer=ctx.deps.tracer,
usage=ctx.state.usage,
usage_limits=ctx.deps.usage_limits,
output_parts=output_parts,
output_deferred_calls=deferred_calls,
Expand Down Expand Up @@ -867,14 +867,20 @@ async def _call_tools(
tool_calls: list[_messages.ToolCallPart],
tool_call_results: dict[str, DeferredToolResult],
tracer: Tracer,
usage_limits: _usage.UsageLimits | None,
usage: _usage.RunUsage,
usage_limits: _usage.UsageLimits,
output_parts: list[_messages.ModelRequestPart],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
) -> AsyncIterator[_messages.HandleResponseEvent]:
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}

if usage_limits.tool_calls_limit is not None:
projected_usage = deepcopy(usage)
projected_usage.tool_calls += len(tool_calls)
usage_limits.check_before_tool_call(projected_usage)

for call in tool_calls:
yield _messages.FunctionToolCallEvent(call)

Expand Down Expand Up @@ -911,15 +917,15 @@ async def handle_call_or_result(
if tool_manager.should_call_sequentially(tool_calls):
for index, call in enumerate(tool_calls):
if event := await handle_call_or_result(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
index,
):
yield event

else:
tasks = [
asyncio.create_task(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
name=call.tool_name,
)
for call in tool_calls
Expand All @@ -946,15 +952,14 @@ async def _call_tool(
tool_manager: ToolManager[DepsT],
tool_call: _messages.ToolCallPart,
tool_call_result: DeferredToolResult | None,
usage_limits: _usage.UsageLimits | None,
) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
try:
if tool_call_result is None:
tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
tool_result = await tool_manager.handle_call(tool_call)
elif isinstance(tool_call_result, ToolApproved):
if tool_call_result.override_args is not None:
tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
tool_result = await tool_manager.handle_call(tool_call)
elif isinstance(tool_call_result, ToolDenied):
return _messages.ToolReturnPart(
tool_name=tool_call.tool_name,
Expand Down
25 changes: 9 additions & 16 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .messages import ToolCallPart
from .tools import ToolDefinition
from .toolsets.abstract import AbstractToolset, ToolsetTool
from .usage import UsageLimits
from .usage import RunUsage

_sequential_tool_calls_ctx_var: ContextVar[bool] = ContextVar('sequential_tool_calls', default=False)

Expand Down Expand Up @@ -93,7 +93,6 @@ async def handle_call(
call: ToolCallPart,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
usage_limits: UsageLimits | None = None,
) -> Any:
"""Handle a tool call by validating the arguments, calling the tool, and handling retries.

Expand All @@ -108,25 +107,23 @@ async def handle_call(

if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
# Output tool calls are not traced and not counted
return await self._call_tool(call, allow_partial, wrap_validation_errors, count_tool_usage=False)
return await self._call_tool(call, allow_partial, wrap_validation_errors)
else:
return await self._call_tool_traced(
return await self._call_function_tool(
call,
allow_partial,
wrap_validation_errors,
self.ctx.tracer,
self.ctx.trace_include_content,
self.ctx.instrumentation_version,
usage_limits,
self.ctx.usage,
)

async def _call_tool(
self,
call: ToolCallPart,
allow_partial: bool,
wrap_validation_errors: bool,
usage_limits: UsageLimits | None = None,
count_tool_usage: bool = True,
) -> Any:
if self.tools is None or self.ctx is None:
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover
Expand Down Expand Up @@ -159,14 +156,8 @@ async def _call_tool(
else:
args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)

if usage_limits is not None and count_tool_usage:
usage_limits.check_before_tool_call(self.ctx.usage)

result = await self.toolset.call_tool(name, args_dict, ctx, tool)

if count_tool_usage:
self.ctx.usage.tool_calls += 1

return result
except (ValidationError, ModelRetry) as e:
max_retries = tool.max_retries if tool is not None else 1
Expand Down Expand Up @@ -199,15 +190,15 @@ async def _call_tool(

raise e

async def _call_tool_traced(
async def _call_function_tool(
self,
call: ToolCallPart,
allow_partial: bool,
wrap_validation_errors: bool,
tracer: Tracer,
include_content: bool,
instrumentation_version: int,
usage_limits: UsageLimits | None = None,
usage: RunUsage,
) -> Any:
"""See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
instrumentation_names = InstrumentationNames.for_version(instrumentation_version)
Expand Down Expand Up @@ -242,7 +233,9 @@ async def _call_tool_traced(
attributes=span_attributes,
) as span:
try:
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors, usage_limits)
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
usage.tool_calls += 1

except ToolRetryError as e:
part = e.tool_retry
if include_content and span.is_recording():
Expand Down
9 changes: 5 additions & 4 deletions pydantic_ai_slim/pydantic_ai/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,13 @@ def check_tokens(self, usage: RunUsage) -> None:
if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')

def check_before_tool_call(self, usage: RunUsage) -> None:
"""Raises a `UsageLimitExceeded` exception if the next tool call would exceed the tool call limit."""
def check_before_tool_call(self, projected_usage: RunUsage) -> None:
"""Raises a `UsageLimitExceeded` exception if the next tool call(s) would exceed the tool call limit."""
tool_calls_limit = self.tool_calls_limit
if tool_calls_limit is not None and usage.tool_calls >= tool_calls_limit:
tool_calls = projected_usage.tool_calls
if tool_calls_limit is not None and tool_calls > tool_calls_limit:
raise UsageLimitExceeded(
f'The next tool call would exceed the tool_calls_limit of {tool_calls_limit} (tool_calls={usage.tool_calls})'
f'The next tool call(s) would exceed the tool_calls_limit of {tool_calls_limit} ({tool_calls=}).'
)

__repr__ = _utils.dataclasses_no_defaults_repr
5 changes: 4 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,10 @@ async def call_tool(
'The capital of Italy is Rome (Roma, in Italian), which has been a cultural and political center for centuries.'
'Rome is known for its rich history, stunning architecture, and delicious cuisine.'
),
'Please call the tool twice': ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id'),
'Please call the tool twice': [
ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id_1'),
ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id_2'),
],
'Begin infinite retry loop!': ToolCallPart(
tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id'
),
Expand Down
109 changes: 105 additions & 4 deletions tests/test_usage_limits.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import functools
import operator
import re
Expand All @@ -12,6 +13,7 @@

from pydantic_ai import (
Agent,
ModelMessage,
ModelRequest,
ModelResponse,
RunContext,
Expand All @@ -21,6 +23,7 @@
UserPromptPart,
)
from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.output import ToolOutput
from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits
Expand Down Expand Up @@ -253,7 +256,8 @@ async def ret_a(x: str) -> str:
return f'{x}-apple'

with pytest.raises(
UsageLimitExceeded, match=re.escape('The next tool call would exceed the tool_calls_limit of 0 (tool_calls=0)')
UsageLimitExceeded,
match=re.escape('The next tool call(s) would exceed the tool_calls_limit of 0 (tool_calls=1).'),
):
await test_agent.run('Hello', usage_limits=UsageLimits(tool_calls_limit=0))

Expand Down Expand Up @@ -286,8 +290,42 @@ async def another_regular_tool(x: str) -> str:
assert result_output.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=15, tool_calls=1))


async def test_output_tool_allowed_at_limit() -> None:
"""Test that output tools can be called even when at the tool_calls_limit."""

class MyOutput(BaseModel):
result: str

def call_output_after_regular(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
if len(messages) == 1:
return ModelResponse(
parts=[
ToolCallPart('regular_tool', {'x': 'test'}, 'call_1'),
],
usage=RequestUsage(input_tokens=10, output_tokens=5),
)
else:
return ModelResponse(
parts=[
ToolCallPart('final_result', {'result': 'success'}, 'call_2'),
],
usage=RequestUsage(input_tokens=10, output_tokens=5),
)

test_agent = Agent(FunctionModel(call_output_after_regular), output_type=ToolOutput(MyOutput))

@test_agent.tool_plain
async def regular_tool(x: str) -> str:
return f'{x}-processed'

result = await test_agent.run('test', usage_limits=UsageLimits(tool_calls_limit=1))

assert result.output.result == 'success'
assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=20, output_tokens=10, tool_calls=1))


async def test_failed_tool_calls_not_counted() -> None:
"""Test that failed tool calls (raising ModelRetry) are not counted."""
"""Test that failed tool calls (raising ModelRetry) are not counted in usage or against limits."""
test_agent = Agent(TestModel())

call_count = 0
Expand All @@ -300,8 +338,7 @@ async def flaky_tool(x: str) -> str:
raise ModelRetry('Temporary failure, please retry')
return f'{x}-success'

result = await test_agent.run('test')
# The tool was called twice (1 failure + 1 success), but only the successful call should be counted
result = await test_agent.run('test', usage_limits=UsageLimits(tool_calls_limit=1))
assert call_count == 2
assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=176, output_tokens=29, tool_calls=1))

Expand All @@ -316,3 +353,67 @@ def test_deprecated_usage_limits():
snapshot(['DeprecationWarning: `response_tokens_limit` is deprecated, use `output_tokens_limit` instead'])
):
assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore


async def test_parallel_tool_calls_limit_enforced():
"""Parallel tool calls must not exceed the limit and should raise immediately."""
executed_tools: list[str] = []

model_call_count = 0

def test_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
nonlocal model_call_count
model_call_count += 1

if model_call_count == 1:
# First response: 5 parallel tool calls (within limit)
return ModelResponse(
parts=[
ToolCallPart('tool_a', {}, 'call_1'),
ToolCallPart('tool_b', {}, 'call_2'),
ToolCallPart('tool_c', {}, 'call_3'),
ToolCallPart('tool_a', {}, 'call_4'),
ToolCallPart('tool_b', {}, 'call_5'),
]
)
else:
assert model_call_count == 2
# Second response: 3 parallel tool calls (would exceed limit of 6)
return ModelResponse(
parts=[
ToolCallPart('tool_c', {}, 'call_6'),
ToolCallPart('tool_a', {}, 'call_7'),
ToolCallPart('tool_b', {}, 'call_8'),
]
)

test_model = FunctionModel(test_model_function)
agent = Agent(test_model)

@agent.tool_plain
async def tool_a() -> str:
await asyncio.sleep(0.01)
executed_tools.append('a')
return 'result a'

@agent.tool_plain
async def tool_b() -> str:
await asyncio.sleep(0.01)
executed_tools.append('b')
return 'result b'

@agent.tool_plain
async def tool_c() -> str:
await asyncio.sleep(0.01)
executed_tools.append('c')
return 'result c'

# Run with tool call limit of 6; expecting an error when trying to execute 3 more tools
with pytest.raises(
UsageLimitExceeded,
match=re.escape('The next tool call(s) would exceed the tool_calls_limit of 6 (tool_calls=8).'),
):
await agent.run('Use tools', usage_limits=UsageLimits(tool_calls_limit=6))

# Only the first batch of 5 tools should have executed
assert len(executed_tools) == 5