diff --git a/docs/agents.md b/docs/agents.md
index 1a11178d65..442e4330d1 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -630,12 +630,12 @@
 try:
     agent.run_sync('Please call the tool twice', usage_limits=UsageLimits(tool_calls_limit=1))
 except UsageLimitExceeded as e:
     print(e)
-    #> The next tool call would exceed the tool_calls_limit of 1 (tool_calls=1)
+    #> The next tool call(s) would exceed the tool_calls_limit of 1 (tool_calls=2).
 ```
 
 !!! note
     - Usage limits are especially relevant if you've registered many tools. Use `request_limit` to bound the number of model turns, and `tool_calls_limit` to cap the number of successful tool executions within a run.
-    - These limits are enforced at the final stage before the LLM is called. If your limits are stricter than your retry settings, the usage limit will be reached before all retries are attempted.
+    - The `tool_calls_limit` is checked before executing tool calls. If the model returns parallel tool calls that would exceed the limit, no tools will be executed.
 
 #### Model (Run) Settings
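The docs change above reflects the new batch-aware semantics: the limit is checked against the whole set of pending parallel calls, and the error reports the projected total rather than the count executed so far. A minimal sketch of that arithmetic, using only the `RunUsage` and `UsageLimits` API touched by this PR (the projected usage is constructed by hand here, as `_call_tools` does below):

```python
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.usage import RunUsage, UsageLimits

limits = UsageLimits(tool_calls_limit=1)

# 0 calls executed so far + 2 calls in the model's parallel batch = 2 projected
projected = RunUsage(tool_calls=0 + 2)
try:
    limits.check_before_tool_call(projected)
except UsageLimitExceeded as e:
    print(e)
    #> The next tool call(s) would exceed the tool_calls_limit of 1 (tool_calls=2).
```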
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index ba14592aa8..0839dd12aa 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -230,7 +230,6 @@ async def run(  # noqa: C901
 
         # Build the run context after `ctx.deps.prompt` has been updated
         run_context = build_run_context(ctx)
-        parts: list[_messages.ModelRequestPart] = []
 
         if messages:
             await self._reevaluate_dynamic_prompts(messages, run_context)
@@ -821,6 +820,7 @@ async def process_tool_calls(  # noqa: C901
         tool_calls=calls_to_run,
         tool_call_results=calls_to_run_results,
         tracer=ctx.deps.tracer,
+        usage=ctx.state.usage,
         usage_limits=ctx.deps.usage_limits,
         output_parts=output_parts,
         output_deferred_calls=deferred_calls,
@@ -867,7 +867,8 @@ async def _call_tools(
     tool_calls: list[_messages.ToolCallPart],
     tool_call_results: dict[str, DeferredToolResult],
     tracer: Tracer,
-    usage_limits: _usage.UsageLimits | None,
+    usage: _usage.RunUsage,
+    usage_limits: _usage.UsageLimits,
     output_parts: list[_messages.ModelRequestPart],
     output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -875,6 +876,11 @@
     user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
     deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
 
+    if usage_limits.tool_calls_limit is not None:
+        projected_usage = deepcopy(usage)
+        projected_usage.tool_calls += len(tool_calls)
+        usage_limits.check_before_tool_call(projected_usage)
+
     for call in tool_calls:
         yield _messages.FunctionToolCallEvent(call)
 
@@ -911,7 +917,7 @@ async def handle_call_or_result(
     if tool_manager.should_call_sequentially(tool_calls):
         for index, call in enumerate(tool_calls):
             if event := await handle_call_or_result(
-                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
                 index,
             ):
                 yield event
@@ -919,7 +925,7 @@
     else:
         tasks = [
             asyncio.create_task(
-                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
                 name=call.tool_name,
             )
             for call in tool_calls
@@ -946,15 +952,14 @@
 async def _call_tool(
     tool_manager: ToolManager[DepsT],
     tool_call: _messages.ToolCallPart,
     tool_call_result: DeferredToolResult | None,
-    usage_limits: _usage.UsageLimits | None,
 ) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
     try:
         if tool_call_result is None:
-            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
+            tool_result = await tool_manager.handle_call(tool_call)
         elif isinstance(tool_call_result, ToolApproved):
            if tool_call_result.override_args is not None:
                 tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args)
-            tool_result = await tool_manager.handle_call(tool_call, usage_limits=usage_limits)
+            tool_result = await tool_manager.handle_call(tool_call)
         elif isinstance(tool_call_result, ToolDenied):
             return _messages.ToolReturnPart(
                 tool_name=tool_call.tool_name,
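The core of the `_agent_graph.py` change is the pre-flight check added to `_call_tools`: the pending batch is projected onto a deep copy of the current usage, so the live counter only advances when a call actually succeeds (see `_tool_manager.py` below). A standalone sketch of the same logic; the `check_batch` helper is illustrative, not part of the PR:

```python
from copy import deepcopy

from pydantic_ai.usage import RunUsage, UsageLimits


def check_batch(usage: RunUsage, batch_size: int, limits: UsageLimits) -> None:
    """Raise UsageLimitExceeded if running `batch_size` more tool calls would exceed the limit."""
    if limits.tool_calls_limit is not None:
        projected = deepcopy(usage)  # never mutate the live counter
        projected.tool_calls += batch_size
        limits.check_before_tool_call(projected)


# 5 calls already executed + a batch of 3 = 8 > 6, so this raises UsageLimitExceeded
check_batch(RunUsage(tool_calls=5), 3, UsageLimits(tool_calls_limit=6))
```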
diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
index 5cf66b00dd..a5546a4e01 100644
--- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py
+++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -18,7 +18,7 @@
 from .messages import ToolCallPart
 from .tools import ToolDefinition
 from .toolsets.abstract import AbstractToolset, ToolsetTool
-from .usage import UsageLimits
+from .usage import RunUsage
 
 _sequential_tool_calls_ctx_var: ContextVar[bool] = ContextVar('sequential_tool_calls', default=False)
 
@@ -93,7 +93,6 @@ async def handle_call(
         call: ToolCallPart,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
-        usage_limits: UsageLimits | None = None,
     ) -> Any:
         """Handle a tool call by validating the arguments, calling the tool, and handling retries.
 
@@ -108,16 +107,16 @@
 
         if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
             # Output tool calls are not traced and not counted
-            return await self._call_tool(call, allow_partial, wrap_validation_errors, count_tool_usage=False)
+            return await self._call_tool(call, allow_partial, wrap_validation_errors)
         else:
-            return await self._call_tool_traced(
+            return await self._call_function_tool(
                 call,
                 allow_partial,
                 wrap_validation_errors,
                 self.ctx.tracer,
                 self.ctx.trace_include_content,
                 self.ctx.instrumentation_version,
-                usage_limits,
+                self.ctx.usage,
             )
 
     async def _call_tool(
@@ -125,8 +124,6 @@
         call: ToolCallPart,
         allow_partial: bool,
         wrap_validation_errors: bool,
-        usage_limits: UsageLimits | None = None,
-        count_tool_usage: bool = True,
     ) -> Any:
         if self.tools is None or self.ctx is None:
             raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
@@ -159,14 +156,8 @@
             else:
                 args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
 
-            if usage_limits is not None and count_tool_usage:
-                usage_limits.check_before_tool_call(self.ctx.usage)
-
             result = await self.toolset.call_tool(name, args_dict, ctx, tool)
 
-            if count_tool_usage:
-                self.ctx.usage.tool_calls += 1
-
             return result
         except (ValidationError, ModelRetry) as e:
             max_retries = tool.max_retries if tool is not None else 1
@@ -199,7 +190,7 @@
 
             raise e
 
-    async def _call_tool_traced(
+    async def _call_function_tool(
         self,
         call: ToolCallPart,
         allow_partial: bool,
@@ -207,7 +198,7 @@
         tracer: Tracer,
         include_content: bool,
         instrumentation_version: int,
-        usage_limits: UsageLimits | None = None,
+        usage: RunUsage,
     ) -> Any:
         """See ."""
         instrumentation_names = InstrumentationNames.for_version(instrumentation_version)
@@ -242,7 +233,9 @@
                 attributes=span_attributes,
             ) as span:
                 try:
-                    tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors, usage_limits)
+                    tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
+                    usage.tool_calls += 1
+
                 except ToolRetryError as e:
                     part = e.tool_retry
                     if include_content and span.is_recording():
diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py
index 92500c2058..8eae608263 100644
--- a/pydantic_ai_slim/pydantic_ai/usage.py
+++ b/pydantic_ai_slim/pydantic_ai/usage.py
@@ -340,12 +340,13 @@ def check_tokens(self, usage: RunUsage) -> None:
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
 
-    def check_before_tool_call(self, usage: RunUsage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the next tool call would exceed the tool call limit."""
+    def check_before_tool_call(self, projected_usage: RunUsage) -> None:
+        """Raises a `UsageLimitExceeded` exception if the next tool call(s) would exceed the tool call limit."""
         tool_calls_limit = self.tool_calls_limit
-        if tool_calls_limit is not None and usage.tool_calls >= tool_calls_limit:
+        tool_calls = projected_usage.tool_calls
+        if tool_calls_limit is not None and tool_calls > tool_calls_limit:
             raise UsageLimitExceeded(
-                f'The next tool call would exceed the tool_calls_limit of {tool_calls_limit} (tool_calls={usage.tool_calls})'
+                f'The next tool call(s) would exceed the tool_calls_limit of {tool_calls_limit} ({tool_calls=}).'
             )
 
     __repr__ = _utils.dataclasses_no_defaults_repr
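Note the comparison flip in `check_before_tool_call`: the old code rejected once *executed* calls reached the limit (`>=`), while the new code rejects only when the *projected* total strictly exceeds it (`>`), so a batch that lands exactly on the limit is still allowed to run. A quick sketch of the boundary behaviour under the new semantics:

```python
from pydantic_ai.usage import RunUsage, UsageLimits

limits = UsageLimits(tool_calls_limit=2)

limits.check_before_tool_call(RunUsage(tool_calls=2))  # allowed: 2 > 2 is False
limits.check_before_tool_call(RunUsage(tool_calls=3))  # raises UsageLimitExceeded
```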
diff --git a/tests/test_examples.py b/tests/test_examples.py
index df390519d7..4c247b55cc 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -387,7 +387,10 @@ async def call_tool(
         'The capital of Italy is Rome (Roma, in Italian), which has been a cultural and political center for centuries.'
         'Rome is known for its rich history, stunning architecture, and delicious cuisine.'
     ),
-    'Please call the tool twice': ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id'),
+    'Please call the tool twice': [
+        ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id_1'),
+        ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id_2'),
+    ],
     'Begin infinite retry loop!': ToolCallPart(
         tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id'
     ),
diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py
index bec78c2f51..2eeeab15b6 100644
--- a/tests/test_usage_limits.py
+++ b/tests/test_usage_limits.py
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import operator
 import re
@@ -12,6 +13,7 @@
 
 from pydantic_ai import (
     Agent,
+    ModelMessage,
     ModelRequest,
     ModelResponse,
     RunContext,
@@ -21,6 +23,7 @@
     UserPromptPart,
 )
 from pydantic_ai.exceptions import ModelRetry
+from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.output import ToolOutput
 from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits
@@ -253,7 +256,8 @@ async def ret_a(x: str) -> str:
         return f'{x}-apple'
 
     with pytest.raises(
-        UsageLimitExceeded, match=re.escape('The next tool call would exceed the tool_calls_limit of 0 (tool_calls=0)')
+        UsageLimitExceeded,
+        match=re.escape('The next tool call(s) would exceed the tool_calls_limit of 0 (tool_calls=1).'),
     ):
         await test_agent.run('Hello', usage_limits=UsageLimits(tool_calls_limit=0))
 
@@ -286,8 +290,42 @@ async def another_regular_tool(x: str) -> str:
 
     assert result_output.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=15, tool_calls=1))
 
 
+async def test_output_tool_allowed_at_limit() -> None:
+    """Test that output tools can be called even when at the tool_calls_limit."""
+
+    class MyOutput(BaseModel):
+        result: str
+
+    def call_output_after_regular(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if len(messages) == 1:
+            return ModelResponse(
+                parts=[
+                    ToolCallPart('regular_tool', {'x': 'test'}, 'call_1'),
+                ],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+        else:
+            return ModelResponse(
+                parts=[
+                    ToolCallPart('final_result', {'result': 'success'}, 'call_2'),
+                ],
+                usage=RequestUsage(input_tokens=10, output_tokens=5),
+            )
+
+    test_agent = Agent(FunctionModel(call_output_after_regular), output_type=ToolOutput(MyOutput))
+
+    @test_agent.tool_plain
+    async def regular_tool(x: str) -> str:
+        return f'{x}-processed'
+
+    result = await test_agent.run('test', usage_limits=UsageLimits(tool_calls_limit=1))
+
+    assert result.output.result == 'success'
+    assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=20, output_tokens=10, tool_calls=1))
+
+
 async def test_failed_tool_calls_not_counted() -> None:
-    """Test that failed tool calls (raising ModelRetry) are not counted."""
+    """Test that failed tool calls (raising ModelRetry) are not counted in usage or against limits."""
     test_agent = Agent(TestModel())
 
     call_count = 0
@@ -300,8 +338,7 @@ async def flaky_tool(x: str) -> str:
             raise ModelRetry('Temporary failure, please retry')
         return f'{x}-success'
 
-    result = await test_agent.run('test')
-    # The tool was called twice (1 failure + 1 success), but only the successful call should be counted
+    result = await test_agent.run('test', usage_limits=UsageLimits(tool_calls_limit=1))
 
     assert call_count == 2
     assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=176, output_tokens=29, tool_calls=1))
@@ -316,3 +353,67 @@ def test_deprecated_usage_limits():
         snapshot(['DeprecationWarning: `response_tokens_limit` is deprecated, use `output_tokens_limit` instead'])
     ):
         assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100  # type: ignore
+
+
+async def test_parallel_tool_calls_limit_enforced():
+    """Parallel tool calls must not exceed the limit and should raise immediately."""
+    executed_tools: list[str] = []
+
+    model_call_count = 0
+
+    def test_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        nonlocal model_call_count
+        model_call_count += 1
+
+        if model_call_count == 1:
+            # First response: 5 parallel tool calls (within limit)
+            return ModelResponse(
+                parts=[
+                    ToolCallPart('tool_a', {}, 'call_1'),
+                    ToolCallPart('tool_b', {}, 'call_2'),
+                    ToolCallPart('tool_c', {}, 'call_3'),
+                    ToolCallPart('tool_a', {}, 'call_4'),
+                    ToolCallPart('tool_b', {}, 'call_5'),
+                ]
+            )
+        else:
+            assert model_call_count == 2
+            # Second response: 3 parallel tool calls (would exceed limit of 6)
+            return ModelResponse(
+                parts=[
+                    ToolCallPart('tool_c', {}, 'call_6'),
+                    ToolCallPart('tool_a', {}, 'call_7'),
+                    ToolCallPart('tool_b', {}, 'call_8'),
+                ]
+            )
+
+    test_model = FunctionModel(test_model_function)
+    agent = Agent(test_model)
+
+    @agent.tool_plain
+    async def tool_a() -> str:
+        await asyncio.sleep(0.01)
+        executed_tools.append('a')
+        return 'result a'
+
+    @agent.tool_plain
+    async def tool_b() -> str:
+        await asyncio.sleep(0.01)
+        executed_tools.append('b')
+        return 'result b'
+
+    @agent.tool_plain
+    async def tool_c() -> str:
+        await asyncio.sleep(0.01)
+        executed_tools.append('c')
+        return 'result c'
+
+    # Run with tool call limit of 6; expecting an error when trying to execute 3 more tools
+    with pytest.raises(
+        UsageLimitExceeded,
+        match=re.escape('The next tool call(s) would exceed the tool_calls_limit of 6 (tool_calls=8).'),
+    ):
+        await agent.run('Use tools', usage_limits=UsageLimits(tool_calls_limit=6))
+
+    # Only the first batch of 5 tools should have executed
+    assert len(executed_tools) == 5
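Finally, because `usage.tool_calls += 1` now runs only after `self._call_tool(...)` returns successfully, attempts that raise `ModelRetry` consume no budget, as `test_failed_tool_calls_not_counted` verifies. A hedged end-to-end sketch of that interaction, assuming `TestModel` re-invokes the tool after the retry prompt exactly as in that test:

```python
from pydantic_ai import Agent, ModelRetry
from pydantic_ai.usage import UsageLimits

agent = Agent('test')  # the 'test' model string resolves to TestModel
attempts = 0


@agent.tool_plain
async def flaky() -> str:
    global attempts
    attempts += 1
    if attempts == 1:
        raise ModelRetry('temporary failure, please retry')
    return 'ok'


result = agent.run_sync('call flaky', usage_limits=UsageLimits(tool_calls_limit=1))
assert attempts == 2  # one failed attempt + one success
assert result.usage().tool_calls == 1  # only the successful call is counted
```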