Skip to content

Commit 9a459ab

Browse files
committed
refactor: adapt logic to check projected tool calls against limits before execution
1 parent 3688ce7 commit 9a459ab

File tree

5 files changed

+37
-34
lines changed

5 files changed

+37
-34
lines changed

docs/agents.md

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -630,12 +630,14 @@ try:
630630
agent.run_sync('Please call the tool twice', usage_limits=UsageLimits(tool_calls_limit=1))
631631
except UsageLimitExceeded as e:
632632
print(e)
633-
#> The next tool call would exceed the tool_calls_limit of 1 (tool_calls=1)
633+
"""
634+
With the next tool call(s), the projected amount of tool calls (2) would exceed the limit of 1.
635+
"""
634636
```
635637

636638
!!! note
637639
- Usage limits are especially relevant if you've registered many tools. Use `request_limit` to bound the number of model turns, and `tool_calls_limit` to cap the number of successful tool executions within a run.
638-
- These limits are enforced at the final stage before the LLM is called. If your limits are stricter than your retry settings, the usage limit will be reached before all retries are attempted.
640+
- The `tool_calls_limit` is checked before executing tool calls. If the projected total would exceed the limit, no tools from that batch are executed.
639641

640642
#### Model (Run) Settings
641643

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 11 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -863,27 +863,22 @@ async def process_tool_calls( # noqa: C901
863863
output_final_result.append(final_result)
864864

865865

866-
def _enforce_tool_call_limits(
866+
def _check_tool_call_limits(
867867
tool_manager: ToolManager[DepsT],
868868
tool_calls: list[_messages.ToolCallPart],
869869
usage_limits: _usage.UsageLimits | None,
870-
) -> tuple[list[_messages.ToolCallPart], int]:
871-
"""Enforce tool call limits and return limited calls and extra count."""
870+
) -> None:
871+
"""Check if executing the tool calls would exceed the limit."""
872872
if usage_limits is None or usage_limits.tool_calls_limit is None:
873-
return tool_calls, 0
873+
return
874874

875875
current_tool_calls = tool_manager.ctx.usage.tool_calls if tool_manager.ctx is not None else 0
876-
remaining_allowed = usage_limits.tool_calls_limit - current_tool_calls
876+
projected_tool_calls = current_tool_calls + len(tool_calls)
877877

878-
if remaining_allowed <= 0:
879-
usage_limits.check_before_tool_call(tool_manager.ctx.usage if tool_manager.ctx else _usage.RunUsage())
880-
881-
if remaining_allowed < len(tool_calls):
882-
limited_tool_calls = tool_calls[: max(0, remaining_allowed)]
883-
extra_calls_count = len(tool_calls) - len(limited_tool_calls)
884-
return limited_tool_calls, extra_calls_count
885-
886-
return tool_calls, 0
878+
if projected_tool_calls > usage_limits.tool_calls_limit:
879+
projected_usage = deepcopy(tool_manager.ctx.usage) if tool_manager.ctx else _usage.RunUsage()
880+
projected_usage.tool_calls = projected_tool_calls
881+
usage_limits.check_before_tool_call(projected_usage)
887882

888883

889884
async def _call_tools(
@@ -899,6 +894,8 @@ async def _call_tools(
899894
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
900895
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
901896

897+
_check_tool_call_limits(tool_manager, tool_calls, usage_limits)
898+
902899
for call in tool_calls:
903900
yield _messages.FunctionToolCallEvent(call)
904901

@@ -943,8 +940,6 @@ async def handle_call_or_result(
943940
yield event
944941

945942
else:
946-
executed_calls, extra_calls_count = _enforce_tool_call_limits(tool_manager, tool_calls, usage_limits)
947-
948943
tasks = [
949944
asyncio.create_task(
950945
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
@@ -961,10 +956,6 @@ async def handle_call_or_result(
961956
if event := await handle_call_or_result(coro_or_task=task, index=index):
962957
yield event
963958

964-
# If there were extra calls beyond the allowed limit, raise now
965-
if extra_calls_count and usage_limits is not None:
966-
usage_limits.check_before_tool_call(tool_manager.ctx.usage if tool_manager.ctx else _usage.RunUsage())
967-
968959
# We append the results at the end, rather than as they are received, to retain a consistent ordering
969960
# This is mostly just to simplify testing
970961
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -340,12 +340,12 @@ def check_tokens(self, usage: RunUsage) -> None:
340340
if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
341341
raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
342342

343-
def check_before_tool_call(self, usage: RunUsage) -> None:
344-
"""Raises a `UsageLimitExceeded` exception if the next tool call would exceed the tool call limit."""
343+
def check_before_tool_call(self, projected_usage: RunUsage) -> None:
344+
"""Raises a `UsageLimitExceeded` exception if the next tool call(s) would exceed the tool call limit."""
345345
tool_calls_limit = self.tool_calls_limit
346-
if tool_calls_limit is not None and usage.tool_calls >= tool_calls_limit:
346+
if tool_calls_limit is not None and projected_usage.tool_calls > tool_calls_limit:
347347
raise UsageLimitExceeded(
348-
f'The next tool call would exceed the tool_calls_limit of {tool_calls_limit} (tool_calls={usage.tool_calls})'
348+
f'With the next tool call(s), the projected amount of tool calls ({projected_usage.tool_calls}) would exceed the limit of {tool_calls_limit}.'
349349
)
350350

351351
__repr__ = _utils.dataclasses_no_defaults_repr

tests/test_examples.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -387,7 +387,10 @@ async def call_tool(
387387
'The capital of Italy is Rome (Roma, in Italian), which has been a cultural and political center for centuries.'
388388
'Rome is known for its rich history, stunning architecture, and delicious cuisine.'
389389
),
390-
'Please call the tool twice': ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id'),
390+
'Please call the tool twice': [
391+
ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id_1'),
392+
ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id_2'),
393+
],
391394
'Begin infinite retry loop!': ToolCallPart(
392395
tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id'
393396
),

tests/test_usage_limits.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313

1414
from pydantic_ai import (
1515
Agent,
16+
ModelMessage,
1617
ModelRequest,
1718
ModelResponse,
1819
RunContext,
@@ -22,6 +23,7 @@
2223
UserPromptPart,
2324
)
2425
from pydantic_ai.exceptions import ModelRetry
26+
from pydantic_ai.models.function import AgentInfo, FunctionModel
2527
from pydantic_ai.models.test import TestModel
2628
from pydantic_ai.output import ToolOutput
2729
from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits
@@ -254,7 +256,10 @@ async def ret_a(x: str) -> str:
254256
return f'{x}-apple'
255257

256258
with pytest.raises(
257-
UsageLimitExceeded, match=re.escape('The next tool call would exceed the tool_calls_limit of 0 (tool_calls=0)')
259+
UsageLimitExceeded,
260+
match=re.escape(
261+
'With the next tool call(s), the projected amount of tool calls (1) would exceed the limit of 0.'
262+
),
258263
):
259264
await test_agent.run('Hello', usage_limits=UsageLimits(tool_calls_limit=0))
260265

@@ -330,7 +335,7 @@ def test_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelR
330335
model_call_count += 1
331336

332337
if model_call_count == 1:
333-
# First response: 5 parallel tool calls
338+
# First response: 5 parallel tool calls (within limit)
334339
return ModelResponse(
335340
parts=[
336341
ToolCallPart('tool_a', {}, 'call_1'),
@@ -342,7 +347,7 @@ def test_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelR
342347
)
343348
else:
344349
assert model_call_count == 2
345-
# Second response: 3 parallel tool calls (should exceed limit)
350+
# Second response: 3 parallel tool calls (would exceed limit of 6)
346351
return ModelResponse(
347352
parts=[
348353
ToolCallPart('tool_c', {}, 'call_6'),
@@ -372,12 +377,14 @@ async def tool_c() -> str:
372377
executed_tools.append('c')
373378
return 'result c'
374379

375-
# Run with tool call limit of 6; expecting an error once the limit is reached
380+
# Run with tool call limit of 6; expecting an error when trying to execute 3 more tools
376381
with pytest.raises(
377382
UsageLimitExceeded,
378-
match=r'The next tool call would exceed the tool_calls_limit of 6 \(tool_calls=(6)\)',
383+
match=re.escape(
384+
'With the next tool call(s), the projected amount of tool calls (8) would exceed the limit of 6.'
385+
),
379386
):
380387
await agent.run('Use tools', usage_limits=UsageLimits(tool_calls_limit=6))
381388

382-
# Only 6 tool calls should have actually executed
383-
assert len(executed_tools) == 6
389+
# Only the first batch of 5 tools should have executed
390+
assert len(executed_tools) == 5

0 commit comments

Comments
 (0)