From 95b543e379e49d1c95385981c003680f854b0a5f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 30 Jul 2025 19:13:19 +0100 Subject: [PATCH 01/71] new Usage interface to work with genai-prices --- docs/agents.md | 14 +- docs/direct.md | 4 +- docs/message-history.md | 12 +- docs/models/index.md | 2 +- docs/models/openai.md | 4 +- docs/multi-agent-applications.md | 4 +- docs/output.md | 2 +- docs/testing.md | 8 +- docs/tools.md | 6 +- .../pydantic_ai_examples/flight_booking.py | 6 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 4 +- pydantic_ai_slim/pydantic_ai/_run_context.py | 4 +- pydantic_ai_slim/pydantic_ai/ag_ui.py | 6 +- pydantic_ai_slim/pydantic_ai/agent.py | 58 ++-- pydantic_ai_slim/pydantic_ai/direct.py | 8 +- pydantic_ai_slim/pydantic_ai/messages.py | 40 ++- .../pydantic_ai/models/__init__.py | 6 +- .../pydantic_ai/models/anthropic.py | 14 +- .../pydantic_ai/models/bedrock.py | 16 +- pydantic_ai_slim/pydantic_ai/models/cohere.py | 10 +- .../pydantic_ai/models/function.py | 18 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 18 +- pydantic_ai_slim/pydantic_ai/models/google.py | 20 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 16 +- .../pydantic_ai/models/huggingface.py | 12 +- .../pydantic_ai/models/mcp_sampling.py | 2 +- .../pydantic_ai/models/mistral.py | 18 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 41 ++- pydantic_ai_slim/pydantic_ai/models/test.py | 6 +- pydantic_ai_slim/pydantic_ai/result.py | 10 +- pydantic_ai_slim/pydantic_ai/usage.py | 271 ++++++++++++++---- pydantic_ai_slim/pyproject.toml | 9 +- tests/models/test_anthropic.py | 206 ++++++------- tests/models/test_bedrock.py | 24 +- tests/models/test_cohere.py | 62 ++-- tests/models/test_deepseek.py | 10 +- tests/models/test_fallback.py | 26 +- tests/models/test_gemini.py | 218 +++++++------- tests/models/test_gemini_vertex.py | 14 +- tests/models/test_google.py | 178 ++++++------ tests/models/test_groq.py | 56 ++-- tests/models/test_huggingface.py | 58 ++-- tests/models/test_instrumented.py | 6 +- tests/models/test_mcp_sampling.py | 10 +- tests/models/test_mistral.py | 120 ++++---- tests/models/test_model_function.py | 28 +- tests/models/test_model_test.py | 10 +- tests/models/test_openai.py | 248 ++++++++-------- tests/models/test_openai_responses.py | 146 +++++----- tests/test_a2a.py | 4 +- tests/test_agent.py | 138 ++++----- tests/test_direct.py | 8 +- tests/test_history_processor.py | 6 +- tests/test_mcp.py | 220 +++++++------- tests/test_streaming.py | 62 ++-- tests/test_tools.py | 6 +- tests/test_toolsets.py | 4 +- tests/test_usage_limits.py | 38 ++- uv.lock | 16 ++ 59 files changed, 1410 insertions(+), 1181 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 91e2602e3e..d239c88887 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -142,7 +142,7 @@ async def main(): model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( - requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + requests=1, input_tokens=56, output_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), @@ -206,8 +206,8 @@ async def main(): parts=[TextPart(content='Paris')], usage=Usage( requests=1, - request_tokens=56, - response_tokens=1, + input_tokens=56, + output_tokens=1, total_tokens=57, ), model_name='gpt-4o', @@ -401,7 +401,7 @@ result_sync = agent.run_sync( print(result_sync.output) #> Rome print(result_sync.usage()) -#> Usage(requests=1, request_tokens=62, response_tokens=1, total_tokens=63) +#> Usage(requests=1, input_tokens=62, output_tokens=1, 
total_tokens=63) try: result_sync = agent.run_sync( @@ -410,7 +410,7 @@ try: ) except UsageLimitExceeded as e: print(e) - #> Exceeded the response_tokens_limit of 10 (response_tokens=32) + #> Exceeded the response_tokens_limit of 10 (output_tokens=32) ``` Restricting the number of requests can be useful in preventing infinite loops or excessive tool calling: @@ -849,7 +849,7 @@ with capture_run_messages() as messages: # (2)! ) ], usage=Usage( - requests=1, request_tokens=62, response_tokens=4, total_tokens=66 + requests=1, input_tokens=62, output_tokens=4, total_tokens=66 ), model_name='gpt-4o', timestamp=datetime.datetime(...), @@ -873,7 +873,7 @@ with capture_run_messages() as messages: # (2)! ) ], usage=Usage( - requests=1, request_tokens=72, response_tokens=8, total_tokens=80 + requests=1, input_tokens=72, output_tokens=8, total_tokens=80 ), model_name='gpt-4o', timestamp=datetime.datetime(...), diff --git a/docs/direct.md b/docs/direct.md index b6c7179349..d43e31dc11 100644 --- a/docs/direct.md +++ b/docs/direct.md @@ -28,7 +28,7 @@ model_response = model_request_sync( print(model_response.parts[0].content) #> Paris print(model_response.usage) -#> Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57) +#> Usage(requests=1, input_tokens=56, output_tokens=1, total_tokens=57) ``` _(This example is complete, it can be run "as is")_ @@ -83,7 +83,7 @@ async def main(): tool_call_id='pyd_ai_2e0e396768a14fe482df90a29a78dc7b', ) ], - usage=Usage(requests=1, request_tokens=55, response_tokens=7, total_tokens=62), + usage=Usage(requests=1, input_tokens=55, output_tokens=7, total_tokens=62), model_name='gpt-4.1-nano', timestamp=datetime.datetime(...), ) diff --git a/docs/message-history.md b/docs/message-history.md index 3e8cedbd03..23615559dc 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -58,7 +58,7 @@ print(result.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), + usage=Usage(requests=1, input_tokens=60, output_tokens=12, total_tokens=72), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -126,7 +126,7 @@ async def main(): content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(request_tokens=50, response_tokens=12, total_tokens=62), + usage=Usage(input_tokens=50, output_tokens=12, total_tokens=62), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -180,7 +180,7 @@ print(result2.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), + usage=Usage(requests=1, input_tokens=60, output_tokens=12, total_tokens=72), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -198,7 +198,7 @@ print(result2.all_messages()) content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' ) ], - usage=Usage(requests=1, request_tokens=61, response_tokens=26, total_tokens=87), + usage=Usage(requests=1, input_tokens=61, output_tokens=26, total_tokens=87), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -299,7 +299,7 @@ print(result2.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' 
) ], - usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), + usage=Usage(requests=1, input_tokens=60, output_tokens=12, total_tokens=72), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -317,7 +317,7 @@ print(result2.all_messages()) content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' ) ], - usage=Usage(requests=1, request_tokens=61, response_tokens=26, total_tokens=87), + usage=Usage(requests=1, input_tokens=61, output_tokens=26, total_tokens=87), model_name='gemini-1.5-pro', timestamp=datetime.datetime(...), ), diff --git a/docs/models/index.md b/docs/models/index.md index 7c5e871599..24d28ae30d 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -117,7 +117,7 @@ print(response.all_messages()) model_name='claude-3-5-sonnet-latest', timestamp=datetime.datetime(...), kind='response', - vendor_id=None, + provider_request_id=None, ), ] """ diff --git a/docs/models/openai.md b/docs/models/openai.md index 5dd131068d..9d6580e67b 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -272,7 +272,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65) +#> Usage(requests=1, input_tokens=57, output_tokens=8, total_tokens=65) ``` #### Example using a remote server @@ -301,7 +301,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65) +#> Usage(requests=1, input_tokens=57, output_tokens=8, total_tokens=65) ``` 1. The name of the model running on the remote server diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index a97cc98132..761d339dbc 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -53,7 +53,7 @@ result = joke_selection_agent.run_sync( print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) -#> Usage(requests=3, request_tokens=204, response_tokens=24, total_tokens=228) +#> Usage(requests=3, input_tokens=204, output_tokens=24, total_tokens=228) ``` 1. The "parent" or controlling agent. @@ -144,7 +144,7 @@ async def main(): print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) # (6)! - #> Usage(requests=4, request_tokens=309, response_tokens=32, total_tokens=341) + #> Usage(requests=4, input_tokens=309, output_tokens=32, total_tokens=341) ``` 1. Define a dataclass to hold the client and API key dependencies. 
diff --git a/docs/output.md b/docs/output.md index 7e1eb16ede..a7676aae7e 100644 --- a/docs/output.md +++ b/docs/output.md @@ -24,7 +24,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65) +#> Usage(requests=1, input_tokens=57, output_tokens=8, total_tokens=65) ``` _(This example is complete, it can be run "as is")_ diff --git a/docs/testing.md b/docs/testing.md index 49b3eba3ca..68330671b8 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -143,8 +143,8 @@ async def test_forecast(): ], usage=Usage( requests=1, - request_tokens=71, - response_tokens=7, + input_tokens=71, + output_tokens=7, total_tokens=78, details=None, ), @@ -169,8 +169,8 @@ async def test_forecast(): ], usage=Usage( requests=1, - request_tokens=77, - response_tokens=16, + input_tokens=77, + output_tokens=16, total_tokens=93, details=None, ), diff --git a/docs/tools.md b/docs/tools.md index 4b40e78818..84b3e64ef0 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -95,7 +95,7 @@ print(dice_result.all_messages()) tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id' ) ], - usage=Usage(requests=1, request_tokens=90, response_tokens=2, total_tokens=92), + usage=Usage(requests=1, input_tokens=90, output_tokens=2, total_tokens=92), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), @@ -115,7 +115,7 @@ print(dice_result.all_messages()) tool_name='get_player_name', args={}, tool_call_id='pyd_ai_tool_call_id' ) ], - usage=Usage(requests=1, request_tokens=91, response_tokens=4, total_tokens=95), + usage=Usage(requests=1, input_tokens=91, output_tokens=4, total_tokens=95), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), @@ -136,7 +136,7 @@ print(dice_result.all_messages()) ) ], usage=Usage( - requests=1, request_tokens=92, response_tokens=12, total_tokens=104 + requests=1, input_tokens=92, output_tokens=12, total_tokens=104 ), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), diff --git a/examples/pydantic_ai_examples/flight_booking.py b/examples/pydantic_ai_examples/flight_booking.py index 5029c2038d..e7776340e3 100644 --- a/examples/pydantic_ai_examples/flight_booking.py +++ b/examples/pydantic_ai_examples/flight_booking.py @@ -13,7 +13,7 @@ from pydantic_ai import Agent, ModelRetry, RunContext from pydantic_ai.messages import ModelMessage -from pydantic_ai.usage import Usage, UsageLimits +from pydantic_ai.usage import RunUsage, UsageLimits # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') @@ -182,7 +182,7 @@ async def main(): req_date=datetime.date(2025, 1, 10), ) message_history: list[ModelMessage] | None = None - usage: Usage = Usage() + usage: RunUsage = RunUsage() # run the agent until a satisfactory flight is found while True: result = await search_agent.run( @@ -213,7 +213,7 @@ async def main(): ) -async def find_seat(usage: Usage) -> SeatPreference: +async def find_seat(usage: RunUsage) -> SeatPreference: message_history: list[ModelMessage] | None = None while True: answer = Prompt.ask('What seat would you like?') diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 12e6e07fe8..51257cc46b 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -75,7 +75,7 
@@ class GraphAgentState: """State kept across the execution of the agent graph.""" message_history: list[_messages.ModelMessage] - usage: _usage.Usage + usage: _usage.RunUsage retries: int run_step: int @@ -357,7 +357,7 @@ async def _make_request( ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) ) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) - ctx.state.usage.incr(_usage.Usage()) + ctx.state.usage.incr(_usage.RunUsage()) return self._finish_handling(ctx, model_response) diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index afad0e60e6..64694d7768 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from .models import Model - from .result import Usage + from .result import RunUsage AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True) """Type variable for agent dependencies.""" @@ -26,7 +26,7 @@ class RunContext(Generic[AgentDepsT]): """Dependencies for the agent.""" model: Model """The model used in this run.""" - usage: Usage + usage: RunUsage """LLM usage associated with the run.""" prompt: str | Sequence[_messages.UserContent] | None = None """The original user prompt passed to the run.""" diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 447a4ba60d..1b86a3bf3a 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -104,7 +104,7 @@ from .tools import AgentDepsT, ToolDefinition from .toolsets import AbstractToolset from .toolsets.deferred import DeferredToolset -from .usage import Usage, UsageLimits +from .usage import RunUsage, UsageLimits __all__ = [ 'SSE_CONTENT_TYPE', @@ -130,7 +130,7 @@ def __init__( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, - usage: Usage | None = None, + usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, # Starlette parameters. @@ -242,7 +242,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, - usage: Usage | None = None, + usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AsyncGenerator[str, None]: diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 5f22d73294..95b49f0978 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -55,7 +55,6 @@ from .toolsets.combined import CombinedToolset from .toolsets.function import FunctionToolset from .toolsets.prepared import PreparedToolset -from .usage import Usage, UsageLimits # Re-exporting like this improves auto-import behavior in PyCharm capture_run_messages = _agent_graph.capture_run_messages @@ -451,7 +450,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[OutputDataT]: ... 
@@ -467,7 +466,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @@ -484,7 +483,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @@ -499,7 +498,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -576,7 +575,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -593,7 +592,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -611,7 +610,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... 
@@ -627,7 +626,7 @@ async def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -680,7 +679,7 @@ async def main(): model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( - requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + requests=1, input_tokens=56, output_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), @@ -746,7 +745,7 @@ async def main(): ) # Build the initial state - usage = usage or _usage.Usage() + usage = usage or _usage.RunUsage() state = _agent_graph.GraphAgentState( message_history=message_history[:] if message_history else [], usage=usage, @@ -855,15 +854,14 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: finally: try: if instrumentation_settings and run_span.is_recording(): - run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) + run_span.set_attributes(self._run_span_end_attributes(state, instrumentation_settings)) finally: run_span.end() def _run_span_end_attributes( - self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings - ): + self, state: _agent_graph.GraphAgentState, settings: InstrumentationSettings + ) -> dict[str, str | int]: return { - **usage.opentelemetry_attributes(), 'all_messages_events': json.dumps( [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)] ), @@ -888,7 +886,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[OutputDataT]: ... @@ -904,7 +902,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @@ -921,7 +919,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @@ -936,7 +934,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -1009,7 +1007,7 @@ def run_stream( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... 
@@ -1025,7 +1023,7 @@ def run_stream( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @@ -1042,7 +1040,7 @@ def run_stream( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @@ -1058,7 +1056,7 @@ async def run_stream( # noqa C901 deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -1856,8 +1854,8 @@ def to_ag_ui( model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, - usage_limits: UsageLimits | None = None, - usage: Usage | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, # Starlette @@ -2091,7 +2089,7 @@ async def main(): model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( - requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + requests=1, input_tokens=56, output_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), @@ -2229,8 +2227,8 @@ async def main(): parts=[TextPart(content='Paris')], usage=Usage( requests=1, - request_tokens=56, - response_tokens=1, + input_tokens=56, + output_tokens=1, total_tokens=57, ), model_name='gpt-4o', @@ -2259,7 +2257,7 @@ async def main(): assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' return next_node - def usage(self) -> _usage.Usage: + def usage(self) -> _usage.RunUsage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" return self._graph_run.state.usage @@ -2421,6 +2419,6 @@ def new_messages_json( content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages(output_tool_return_content=content)) - def usage(self) -> _usage.Usage: + def usage(self) -> _usage.RunUsage: """Return the usage of the whole run.""" return self._state.usage diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 6735c928e0..aaeae4f501 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -16,7 +16,7 @@ from datetime import datetime from types import TracebackType -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from pydantic_graph._utils import get_event_loop as _get_event_loop from . 
import agent, messages, models, settings @@ -57,7 +57,7 @@ async def main(): ''' ModelResponse( parts=[TextPart(content='Paris')], - usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57), + usage=Usage(requests=1, input_tokens=56, output_tokens=1, total_tokens=57), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) @@ -110,7 +110,7 @@ def model_request_sync( ''' ModelResponse( parts=[TextPart(content='Paris')], - usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57), + usage=Usage(requests=1, input_tokens=56, output_tokens=1, total_tokens=57), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) @@ -366,7 +366,7 @@ def get(self) -> messages.ModelResponse: """Build a ModelResponse from the data received from the stream so far.""" return self._ensure_stream_ready().get() - def usage(self) -> Usage: + def usage(self) -> RunUsage: """Get the usage of the response so far.""" return self._ensure_stream_ready().usage() diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index b5d7be2857..91019f920c 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -10,6 +10,7 @@ import pydantic import pydantic_core +from genai_prices import calc_price_sync, types as genai_types from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage] from typing_extensions import TypeAlias, deprecated @@ -19,7 +20,7 @@ now_utc as _now_utc, ) from .exceptions import UnexpectedModelBehavior -from .usage import Usage +from .usage import RequestUsage if TYPE_CHECKING: from .models.instrumented import InstrumentationSettings @@ -764,7 +765,7 @@ class ModelResponse: parts: list[ModelResponsePart] """The parts of the model message.""" - usage: Usage = field(default_factory=Usage) + usage: RequestUsage = field(default_factory=RequestUsage) """Usage information for the request. This has a default to make tests easier, and to support loading old messages where usage will be missing. @@ -779,18 +780,31 @@ class ModelResponse: If the model provides a timestamp in the response (as OpenAI does) that will be used. """ - kind: Literal['response'] = 'response' - """Message type identifier, this is available on all parts as a discriminator.""" + provider_name: str | None = None + """The name of the LLM provider that generated the response.""" - vendor_details: dict[str, Any] | None = field(default=None) - """Additional vendor-specific details in a serializable format. + provider_details: dict[str, Any] | None = field(default=None) + """Additional provider-specific details in a serializable format. This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields. For OpenAI models, this may include 'logprobs', 'finish_reason', etc. """ - vendor_id: str | None = None - """Vendor ID as specified by the model provider. This can be used to track the specific request to the model.""" + provider_request_id: str | None = None + """request ID as specified by the model provider. 
This can be used to track the specific request to the model.""" + + kind: Literal['response'] = 'response' + """Message type identifier, this is available on all parts as a discriminator.""" + + def price(self) -> genai_types.PriceCalculation: + """Calculate the price of the usage, this doesn't use `auto_update` so won't make any network requests.""" + assert self.model_name, 'Model name is required to calculate price' + return calc_price_sync( + self.usage, + self.model_name, + provider_id=self.provider_name, + genai_request_timestamp=self.timestamp, + ) def otel_events(self, settings: InstrumentationSettings) -> list[Event]: """Return OpenTelemetry events for the response.""" @@ -828,6 +842,16 @@ def new_event_body(): return result + @property + @deprecated('`vendor_details` is deprecated, use `provider_details` instead') + def vendor_details(self) -> dict[str, Any] | None: + return self.provider_details + + @property + @deprecated('`vendor_id` is deprecated, use `provider_request_id` instead') + def vendor_id(self) -> str | None: + return self.provider_request_id + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 6cdcbfbd64..6183ababca 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -29,7 +29,7 @@ from ..profiles._json_schema import JsonSchemaTransformer from ..settings import ModelSettings from ..tools import ToolDefinition -from ..usage import Usage +from ..usage import RequestUsage KnownModelName = TypeAliasType( 'KnownModelName', @@ -502,7 +502,7 @@ class StreamedResponse(ABC): _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - _usage: Usage = field(default_factory=Usage, init=False) + _usage: RequestUsage = field(default_factory=RequestUsage, init=False) def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" @@ -532,7 +532,7 @@ def get(self) -> ModelResponse: usage=self.usage(), ) - def usage(self) -> Usage: + def usage(self) -> RequestUsage: """Get the usage of the response so far. 
This will not be the final usage until the stream is exhausted.""" return self._usage diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 02f9111c2d..78a0aa2d31 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -282,7 +282,9 @@ def _process_response(self, response: BetaMessage) -> ModelResponse: ) ) - return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id) + return ModelResponse( + items, usage=_map_usage(response), model_name=response.model, provider_request_id=response.id + ) async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageStreamEvent]) -> StreamedResponse: peekable_response = _utils.PeekableAsyncStream(response) @@ -424,7 +426,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam: } -def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage: +def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage: if isinstance(message, BetaMessage): response_usage = message.usage elif isinstance(message, BetaRawMessageStartEvent): @@ -437,7 +439,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage: # - RawContentBlockStartEvent # - RawContentBlockDeltaEvent # - RawContentBlockStopEvent - return usage.Usage() + return usage.RequestUsage() # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by # `response_tokens` @@ -454,9 +456,9 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage: + details.get('cache_read_input_tokens', 0) ) - return usage.Usage( - request_tokens=request_tokens or None, - response_tokens=response_usage.output_tokens, + return usage.RequestUsage( + input_tokens=request_tokens or None, + output_tokens=response_usage.output_tokens, total_tokens=request_tokens + response_usage.output_tokens, details=details or None, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index b63ed4e1f9..07c08ba024 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -290,13 +290,13 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes tool_call_id=tool_use['toolUseId'], ), ) - u = usage.Usage( - request_tokens=response['usage']['inputTokens'], - response_tokens=response['usage']['outputTokens'], + u = usage.RequestUsage( + input_tokens=response['usage']['inputTokens'], + output_tokens=response['usage']['outputTokens'], total_tokens=response['usage']['totalTokens'], ) vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None) - return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id) + return ModelResponse(items, usage=u, model_name=self.model_name, provider_request_id=vendor_id) @overload async def _messages_create( @@ -641,10 +641,10 @@ def model_name(self) -> str: """Get the model name of the response.""" return self._model_name - def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.Usage: - return usage.Usage( - request_tokens=metadata['usage']['inputTokens'], - response_tokens=metadata['usage']['outputTokens'], + def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage: + return usage.RequestUsage( + input_tokens=metadata['usage']['inputTokens'], + 
output_tokens=metadata['usage']['outputTokens'], total_tokens=metadata['usage']['totalTokens'], ) diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 4243ef492a..a113b10000 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -294,10 +294,10 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]: assert_never(part) -def _map_usage(response: V2ChatResponse) -> usage.Usage: +def _map_usage(response: V2ChatResponse) -> usage.RequestUsage: u = response.usage if u is None: - return usage.Usage() + return usage.RequestUsage() else: details: dict[str, int] = {} if u.billed_units is not None: @@ -312,9 +312,9 @@ def _map_usage(response: V2ChatResponse) -> usage.Usage: request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else None response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else None - return usage.Usage( - request_tokens=request_tokens, - response_tokens=response_tokens, + return usage.RequestUsage( + input_tokens=request_tokens, + output_tokens=response_tokens, total_tokens=(request_tokens or 0) + (response_tokens or 0), details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index c48873f046..68cb5fa4b4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -263,7 +263,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for item in self._iter: if isinstance(item, str): response_tokens = _estimate_string_tokens(item) - self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) + self._usage += usage.RequestUsage(output_tokens=response_tokens, total_tokens=response_tokens) maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) if maybe_event is not None: # pragma: no branch yield maybe_event @@ -272,7 +272,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if isinstance(delta, DeltaThinkingPart): if delta.content: # pragma: no branch response_tokens = _estimate_string_tokens(delta.content) - self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) + self._usage += usage.RequestUsage( + output_tokens=response_tokens, total_tokens=response_tokens + ) yield self._parts_manager.handle_thinking_delta( vendor_part_id=dtc_index, content=delta.content, @@ -281,7 +283,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(delta, DeltaToolCall): if delta.json_args: response_tokens = _estimate_string_tokens(delta.json_args) - self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) + self._usage += usage.RequestUsage( + output_tokens=response_tokens, total_tokens=response_tokens + ) maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=dtc_index, tool_name=delta.name, @@ -304,7 +308,7 @@ def timestamp(self) -> datetime: return self._timestamp -def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage: +def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage: """Very rough guesstimate of the token usage associated with a series of messages. This is designed to be used solely to give plausible numbers for testing! 
@@ -335,9 +339,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage: assert_never(part) else: assert_never(message) - return usage.Usage( - request_tokens=request_tokens, - response_tokens=response_tokens, + return usage.RequestUsage( + input_tokens=request_tokens, + output_tokens=response_tokens, total_tokens=request_tokens + response_tokens, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4ac07f8ada..9c2ca5b4b8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -281,8 +281,8 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse: parts, response.get('model_version', self._model_name), usage, - vendor_id=vendor_id, - vendor_details=vendor_details, + provider_request_id=vendor_id, + provider_details=vendor_details, ) async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse: @@ -661,7 +661,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart def _process_response_from_parts( parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, - usage: usage.Usage, + usage: usage.RequestUsage, vendor_id: str | None, vendor_details: dict[str, Any] | None = None, ) -> ModelResponse: @@ -681,7 +681,7 @@ def _process_response_from_parts( f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' ) return ModelResponse( - parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details + parts=items, usage=usage, model_name=model_name, provider_request_id=vendor_id, provider_details=vendor_details ) @@ -847,10 +847,10 @@ class _GeminiUsageMetaData(TypedDict, total=False): ] -def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage: +def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage: metadata = response.get('usage_metadata') if metadata is None: - return usage.Usage() # pragma: no cover + return usage.RequestUsage() # pragma: no cover details: dict[str, int] = {} if cached_content_token_count := metadata.get('cached_content_token_count'): details['cached_content_tokens'] = cached_content_token_count # pragma: no cover @@ -868,9 +868,9 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage: for detail in metadata_details: details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count'] - return usage.Usage( - request_tokens=metadata.get('prompt_token_count', 0), - response_tokens=metadata.get('candidates_token_count', 0), + return usage.RequestUsage( + input_tokens=metadata.get('prompt_token_count', 0), + output_tokens=metadata.get('candidates_token_count', 0), total_tokens=metadata.get('total_token_count', 0), details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 082f5ba566..dfc37919f1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -322,7 +322,11 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse: usage = _metadata_as_usage(response) usage.requests = 1 return _process_response_from_parts( - parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details + parts, + response.model_version or self._model_name, + usage, + provider_request_id=vendor_id, + provider_details=vendor_details, ) async def 
_process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse: @@ -505,7 +509,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict: def _process_response_from_parts( parts: list[Part], model_name: GoogleModelName, - usage: usage.Usage, + usage: usage.RequestUsage, vendor_id: str | None, vendor_details: dict[str, Any] | None = None, ) -> ModelResponse: @@ -527,7 +531,7 @@ def _process_response_from_parts( f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' ) return ModelResponse( - parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details + parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details ) @@ -547,10 +551,10 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict: return ToolConfigDict(function_calling_config=function_calling_config) -def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage: +def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage: metadata = response.usage_metadata if metadata is None: - return usage.Usage() # pragma: no cover + return usage.RequestUsage() # pragma: no cover metadata = metadata.model_dump(exclude_defaults=True) details: dict[str, int] = {} @@ -569,9 +573,9 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage: for detail in metadata_details: details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count'] - return usage.Usage( - request_tokens=metadata.get('prompt_token_count', 0), - response_tokens=metadata.get('candidates_token_count', 0), + return usage.RequestUsage( + input_tokens=metadata.get('prompt_token_count', 0), + output_tokens=metadata.get('candidates_token_count', 0), total_tokens=metadata.get('total_token_count', 0), details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index ffca84b447..23c2e395f7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -266,7 +266,11 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: for c in choice.message.tool_calls: items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)) return ModelResponse( - items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id + items, + usage=_map_usage(response), + model_name=response.model, + timestamp=timestamp, + provider_request_id=response.id, ) async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse: @@ -443,7 +447,7 @@ def timestamp(self) -> datetime: return self._timestamp -def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage: +def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage: response_usage = None if isinstance(completion, chat.ChatCompletion): response_usage = completion.usage @@ -451,10 +455,10 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us response_usage = completion.x_groq.usage if response_usage is None: - return usage.Usage() + return usage.RequestUsage() - return usage.Usage( - request_tokens=response_usage.prompt_tokens, - response_tokens=response_usage.completion_tokens, + return usage.RequestUsage( + input_tokens=response_usage.prompt_tokens, + 
output_tokens=response_usage.completion_tokens, total_tokens=response_usage.total_tokens, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 4b3c2ff404..2e1488ac1e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -253,7 +253,7 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: usage=_map_usage(response), model_name=response.model, timestamp=timestamp, - vendor_id=response.id, + provider_request_id=response.id, ) async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse: @@ -454,14 +454,14 @@ def timestamp(self) -> datetime: return self._timestamp -def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage: +def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.RequestUsage: response_usage = response.usage if response_usage is None: - return usage.Usage() + return usage.RequestUsage() - return usage.Usage( - request_tokens=response_usage.prompt_tokens, - response_tokens=response_usage.completion_tokens, + return usage.RequestUsage( + input_tokens=response_usage.prompt_tokens, + output_tokens=response_usage.completion_tokens, total_tokens=response_usage.total_tokens, details=None, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index ebfaac92d0..d7643e68cc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -62,7 +62,7 @@ async def request( if result.role == 'assistant': return ModelResponse( parts=[_mcp.map_from_sampling_content(result.content)], - usage=usage.Usage(requests=1), + usage=usage.RequestUsage(requests=1), model_name=result.model, ) else: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index ca73558bca..94fdfa37e5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -37,7 +37,7 @@ from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition -from ..usage import Usage +from ..usage import RequestUsage from . 
import ( Model, ModelRequestParameters, @@ -341,7 +341,11 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes parts.append(tool) return ModelResponse( - parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id + parts, + usage=_map_usage(response), + model_name=response.model, + timestamp=timestamp, + provider_request_id=response.id, ) async def _process_streamed_response( @@ -689,17 +693,17 @@ def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[ } -def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage: +def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> RequestUsage: """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage.""" if response.usage: - return Usage( - request_tokens=response.usage.prompt_tokens, - response_tokens=response.usage.completion_tokens, + return RequestUsage( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, total_tokens=response.usage.total_tokens, details=None, ) else: - return Usage() # pragma: no cover + return RequestUsage() # pragma: no cover def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 35dca2e03d..fc14549cef 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -418,8 +418,8 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons usage=_map_usage(response), model_name=response.model, timestamp=timestamp, - vendor_details=vendor_details, - vendor_id=response.id, + provider_details=vendor_details, + provider_request_id=response.id, ) async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse: @@ -706,7 +706,7 @@ def _process_response(self, response: responses.Response) -> ModelResponse: items, usage=_map_usage(response), model_name=response.model, - vendor_id=response.id, + provider_request_id=response.id, timestamp=timestamp, ) @@ -1170,10 +1170,10 @@ def timestamp(self) -> datetime: return self._timestamp -def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.Usage: +def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.RequestUsage: response_usage = response.usage if response_usage is None: - return usage.Usage() + return usage.RequestUsage() elif isinstance(response_usage, responses.ResponseUsage): details: dict[str, int] = { key: value @@ -1183,29 +1183,24 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R if isinstance(value, int) } details['reasoning_tokens'] = response_usage.output_tokens_details.reasoning_tokens - details['cached_tokens'] = response_usage.input_tokens_details.cached_tokens - return usage.Usage( - request_tokens=response_usage.input_tokens, - response_tokens=response_usage.output_tokens, - total_tokens=response_usage.total_tokens, + return usage.RequestUsage( + input_tokens=response_usage.input_tokens, + output_tokens=response_usage.output_tokens, + cache_read_tokens=response_usage.input_tokens_details.cached_tokens, details=details, ) else: details = { - key: value - for key, value in response_usage.model_dump( - exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'} - ).items() - if 
isinstance(value, int) + key: value for key, value in response_usage.model_dump(exclude_none=True).items() if isinstance(value, int) } + u = usage.RequestUsage( + input_tokens=response_usage.prompt_tokens, + output_tokens=response_usage.completion_tokens, + details=details, + ) if response_usage.completion_tokens_details is not None: details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True)) if response_usage.prompt_tokens_details is not None: - details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True)) - return usage.Usage( - requests=1, - request_tokens=response_usage.prompt_tokens, - response_tokens=response_usage.completion_tokens, - total_tokens=response_usage.total_tokens, - details=details, - ) + u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens + u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens + return u diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index eebe00d440..562f5bea8c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -27,7 +27,7 @@ from ..profiles import ModelProfileSpec from ..settings import ModelSettings from ..tools import ToolDefinition -from ..usage import Usage +from ..usage import RequestUsage from . import Model, ModelRequestParameters, StreamedResponse from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage] @@ -454,6 +454,6 @@ def _char(self) -> str: return s -def _get_string_usage(text: str) -> Usage: +def _get_string_usage(text: str) -> RequestUsage: response_tokens = _estimate_string_tokens(text) - return Usage(response_tokens=response_tokens, total_tokens=response_tokens) + return RequestUsage(output_tokens=response_tokens, total_tokens=response_tokens) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 2dc3eb8259..c0f32bff0f 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -28,7 +28,7 @@ OutputDataT, ToolOutput, ) -from .usage import Usage, UsageLimits +from .usage import RunUsage, UsageLimits __all__ = ( 'OutputDataT', @@ -53,7 +53,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) _final_result_event: FinalResultEvent | None = field(default=None, init=False) - _initial_run_ctx_usage: Usage = field(init=False) + _initial_run_ctx_usage: RunUsage = field(init=False) def __post_init__(self): self._initial_run_ctx_usage = copy(self._run_ctx.usage) @@ -111,7 +111,7 @@ def get(self) -> _messages.ModelResponse: """Get the current state of the response.""" return self._raw_stream_response.get() - def usage(self) -> Usage: + def usage(self) -> RunUsage: """Return the usage of the whole run. !!! note @@ -464,7 +464,7 @@ async def get_output(self) -> OutputDataT: async def get_data(self) -> OutputDataT: return await self.get_output() - def usage(self) -> Usage: + def usage(self) -> RunUsage: """Return the usage of the whole run. !!! 
note @@ -518,7 +518,7 @@ def data(self) -> OutputDataT: def _get_usage_checking_stream_response( stream_response: AsyncIterable[_messages.ModelResponseStreamEvent], limits: UsageLimits | None, - get_usage: Callable[[], Usage], + get_usage: Callable[[], RunUsage], ) -> AsyncIterable[_messages.ModelResponseStreamEvent]: if limits is not None and limits.has_token_limits(): diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index c3f4c1885b..cce47eaade 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -1,55 +1,79 @@ from __future__ import annotations as _annotations +import dataclasses from copy import copy from dataclasses import dataclass +from typing_extensions import deprecated, overload + from . import _utils from .exceptions import UsageLimitExceeded -__all__ = 'Usage', 'UsageLimits' +__all__ = 'RequestUsage', 'RunUsage', 'UsageLimits' @dataclass(repr=False) -class Usage: - """LLM usage associated with a request or run. +class RequestUsage: + """LLM usage associated with a single request. - Responsibility for calculating usage is on the model; Pydantic AI simply sums the usage information across requests. + This is an implementation of `genai_prices.types.AbstractUsage` so it can be used to calculate the price of the + request. - You'll need to look up the documentation of the model you're using to convert usage to monetary costs. + Prices for LLM requests are calculated using [genai-prices](https://github.com/pydantic/genai-prices). """ - requests: int = 0 - """Number of requests made to the LLM API.""" - request_tokens: int | None = None - """Tokens used in processing requests.""" - response_tokens: int | None = None - """Tokens used in generating responses.""" - total_tokens: int | None = None - """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`.""" + input_tokens: int | None = None + """Number of text input/prompt tokens.""" + + cache_write_tokens: int | None = None + """Number of tokens written to the cache.""" + cache_read_tokens: int | None = None + """Number of tokens read from the cache.""" + + output_tokens: int | None = None + """Number of text output/completion tokens.""" + + input_audio_tokens: int | None = None + """Number of audio input tokens.""" + cache_audio_read_tokens: int | None = None + """Number of audio tokens read from the cache.""" + details: dict[str, int] | None = None """Any extra details returned by the model.""" - def incr(self, incr_usage: Usage) -> None: + # not used but present so RequestUsage is a valid AbstractUsage + output_audio_tokens: None = dataclasses.field(default=None, init=False) + requests: int = dataclasses.field(default=1, init=False) + + @property + @deprecated('`request_tokens` is deprecated, use `input_tokens` instead') + def request_tokens(self) -> int | None: + return self.input_tokens + + @property + @deprecated('`response_tokens` is deprecated, use `output_tokens` instead') + def response_tokens(self) -> int | None: + return self.output_tokens + + @property + @deprecated('`total_tokens` is deprecated, sum the specific fields you need instead') + def total_tokens(self) -> int | None: + return sum(v for k, v in dataclasses.asdict(self).values() if k.endswith('_tokens') and v is not None) + + def incr(self, incr_usage: RequestUsage) -> None: """Increment the usage in place. Args: incr_usage: The usage to increment by. 
""" - for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens': - self_value = getattr(self, f) - other_value = getattr(incr_usage, f) - if self_value is not None or other_value is not None: - setattr(self, f, (self_value or 0) + (other_value or 0)) + return _incr_usage_tokens(self, incr_usage) - if incr_usage.details: - self.details = self.details or {} - for key, value in incr_usage.details.items(): - self.details[key] = self.details.get(key, 0) + value + def __add__(self, other: RequestUsage) -> RequestUsage: + """Add two RequestUsages together. - def __add__(self, other: Usage) -> Usage: - """Add two Usages together. + This is provided so it's trivial to sum usage information from multiple parts of a response. - This is provided so it's trivial to sum usage information from multiple requests and runs. + **WARNING:** this CANNOT be used to sum multiple requests without breaking some pricing calculations. """ new_usage = copy(self) new_usage.incr(other) @@ -58,10 +82,10 @@ def __add__(self, other: Usage) -> Usage: def opentelemetry_attributes(self) -> dict[str, int]: """Get the token limits as OpenTelemetry attributes.""" result: dict[str, int] = {} - if self.request_tokens: - result['gen_ai.usage.input_tokens'] = self.request_tokens - if self.response_tokens: - result['gen_ai.usage.output_tokens'] = self.response_tokens + if self.input_tokens: + result['gen_ai.usage.input_tokens'] = self.input_tokens + if self.output_tokens: + result['gen_ai.usage.output_tokens'] = self.output_tokens details = self.details if details: prefix = 'gen_ai.usage.details.' @@ -71,13 +95,118 @@ def opentelemetry_attributes(self) -> dict[str, int]: result[prefix + key] = value return result + __repr__ = _utils.dataclasses_no_defaults_repr + + +@dataclass(repr=False) +class RunUsage: + """LLM usage associated with an agent run. + + Responsibility for calculating request usage is on the model; Pydantic AI simply sums the usage information across requests. 
+ """ + + requests: int = 0 + """Number of requests made to the LLM API.""" + + input_tokens: int | None = None + """Total number of text input/prompt tokens.""" + + cache_write_tokens: int | None = None + """Total number of tokens written to the cache.""" + cache_read_tokens: int | None = None + """Total number of tokens read from the cache.""" + + input_audio_tokens: int | None = None + """Total number of audio input tokens.""" + cache_audio_read_tokens: int | None = None + """Total number of audio tokens read from the cache.""" + + output_tokens: int | None = None + """Total number of text output/completion tokens.""" + + details: dict[str, int] | None = None + """Any extra details returned by the model.""" + + def input_output_tokens(self) -> int | None: + """Sum of `input_tokens + output_tokens`.""" + if self.input_tokens is None and self.output_tokens is None: + return None + else: + return (self.input_tokens or 0) + (self.output_tokens or 0) + + @property + @deprecated('`request_tokens` is deprecated, use `input_tokens` instead') + def request_tokens(self) -> int | None: + return self.input_tokens + + @property + @deprecated('`response_tokens` is deprecated, use `output_tokens` instead') + def response_tokens(self) -> int | None: + return self.output_tokens + + @property + @deprecated('`total_tokens` is deprecated, sum the specific fields you need or use `input_output_tokens` instead') + def total_tokens(self) -> int | None: + return sum(v for k, v in dataclasses.asdict(self).values() if k.endswith('_tokens') and v is not None) + + def incr(self, incr_usage: RunUsage | RequestUsage) -> None: + """Increment the usage in place. + + Args: + incr_usage: The usage to increment by. + """ + if isinstance(incr_usage, RunUsage): + self.requests += incr_usage.requests + return _incr_usage_tokens(self, incr_usage) + + def __add__(self, other: RunUsage) -> RunUsage: + """Add two RunUsages together. + + This is provided so it's trivial to sum usage information from multiple runs. + """ + new_usage = copy(self) + new_usage.incr(other) + return new_usage + def has_values(self) -> bool: """Whether any values are set and non-zero.""" - return bool(self.requests or self.request_tokens or self.response_tokens or self.details) + return any(dataclasses.asdict(self).values()) __repr__ = _utils.dataclasses_no_defaults_repr +def _incr_usage_tokens(slf: RunUsage | RequestUsage, incr_usage: RunUsage | RequestUsage) -> None: + """Increment the usage in place. + + Args: + slf: The usage to increment. + incr_usage: The usage to increment by. 
+ """ + if incr_usage.input_tokens: + slf.input_tokens = (slf.input_tokens or 0) + incr_usage.input_tokens + if incr_usage.output_tokens: + slf.output_tokens = (slf.output_tokens or 0) + incr_usage.output_tokens + if incr_usage.cache_write_tokens: + slf.cache_write_tokens = (slf.cache_write_tokens or 0) + incr_usage.cache_write_tokens + if incr_usage.cache_read_tokens: + slf.cache_read_tokens = (slf.cache_read_tokens or 0) + incr_usage.cache_read_tokens + if incr_usage.input_audio_tokens: + slf.input_audio_tokens = (slf.input_audio_tokens or 0) + incr_usage.input_audio_tokens + if incr_usage.cache_audio_read_tokens: + slf.cache_audio_read_tokens = (slf.cache_audio_read_tokens or 0) + incr_usage.cache_audio_read_tokens + + if incr_usage.details: + slf.details = slf.details or {} + for key, value in incr_usage.details.items(): + slf.details[key] = slf.details.get(key, 0) + value + + +@dataclass +@deprecated('`Usage` is deprecated, use `RunUsage` instead') +class Usage(RunUsage): + pass + + @dataclass(repr=False) class UsageLimits: """Limits on model usage. @@ -90,12 +219,59 @@ class UsageLimits: request_limit: int | None = 50 """The maximum number of requests allowed to the model.""" - request_tokens_limit: int | None = None - """The maximum number of tokens allowed in requests to the model.""" - response_tokens_limit: int | None = None - """The maximum number of tokens allowed in responses from the model.""" + input_tokens_limit: int | None = None + """The maximum number of input/prompt tokens allowed.""" + output_tokens_limit: int | None = None + """The maximum number of output/response tokens allowed.""" total_tokens_limit: int | None = None - """The maximum number of tokens allowed in requests and responses combined.""" + """The maximum number of combined input and output tokens allowed.""" + + @overload + def __init__( + self, + *, + request_limit: int | None = 50, + input_tokens_limit: int | None = None, + output_tokens_limit: int | None = None, + total_tokens_limit: int | None = None, + ) -> None: + self.request_limit = request_limit + self.input_tokens_limit = input_tokens_limit + self.output_tokens_limit = output_tokens_limit + self.total_tokens_limit = total_tokens_limit + + @overload + @deprecated( + 'Use `input_tokens_limit` instead of `request_tokens_limit` and `output_tokens_limit` and `total_tokens_limit`' + ) + def __init__( + self, + *, + request_limit: int | None = 50, + request_tokens_limit: int | None = None, + response_tokens_limit: int | None = None, + total_tokens_limit: int | None = None, + ) -> None: + self.request_limit = request_limit + self.input_tokens_limit = request_tokens_limit + self.output_tokens_limit = response_tokens_limit + self.total_tokens_limit = total_tokens_limit + + def __init__( + self, + *, + request_limit: int | None = 50, + input_tokens_limit: int | None = None, + output_tokens_limit: int | None = None, + total_tokens_limit: int | None = None, + # deprecated: + request_tokens_limit: int | None = None, + response_tokens_limit: int | None = None, + ): + self.request_limit = request_limit + self.input_tokens_limit = input_tokens_limit or request_tokens_limit + self.output_tokens_limit = output_tokens_limit or response_tokens_limit + self.total_tokens_limit = total_tokens_limit def has_token_limits(self) -> bool: """Returns `True` if this instance places any limits on token counts. @@ -106,31 +282,28 @@ def has_token_limits(self) -> bool: If there are no limits, we can skip that processing in the streaming response iterator. 
""" return any( - limit is not None - for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit) + limit is not None for limit in (self.input_tokens_limit, self.output_tokens_limit, self.total_tokens_limit) ) - def check_before_request(self, usage: Usage) -> None: + def check_before_request(self, usage: RunUsage) -> None: """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit.""" request_limit = self.request_limit if request_limit is not None and usage.requests >= request_limit: raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}') - def check_tokens(self, usage: Usage) -> None: + def check_tokens(self, usage: RunUsage) -> None: """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits.""" - request_tokens = usage.request_tokens or 0 - if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit: - raise UsageLimitExceeded( - f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})' - ) + input_tokens = usage.input_tokens or 0 + if self.input_tokens_limit is not None and input_tokens > self.input_tokens_limit: + raise UsageLimitExceeded(f'Exceeded the input_tokens_limit of {self.input_tokens_limit} ({input_tokens=})') - response_tokens = usage.response_tokens or 0 - if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit: + output_tokens = usage.output_tokens or 0 + if self.output_tokens_limit is not None and output_tokens > self.output_tokens_limit: raise UsageLimitExceeded( - f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})' + f'Exceeded the output_tokens_limit of {self.output_tokens_limit} ({output_tokens=})' ) - total_tokens = usage.total_tokens or 0 + total_tokens = usage.input_output_tokens() or 0 if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})') diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 7407e14e2a..67a1e61ba5 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -46,6 +46,12 @@ classifiers = [ ] requires-python = ">=3.9" +[project.urls] +Homepage = "https://github.com/pydantic/pydantic-ai/tree/main/pydantic_ai_slim" +Source = "https://github.com/pydantic/pydantic-ai/tree/main/pydantic_ai_slim" +Documentation = "https://ai.pydantic.dev/install/#slim-install" +Changelog = "https://github.com/pydantic/pydantic-ai/releases" + [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ "eval-type-backport>=0.2.0", @@ -56,6 +62,7 @@ dependencies = [ "exceptiongroup; python_version < '3.11'", "opentelemetry-api>=1.28.0", "typing-inspection>=0.4.0", + "genai-prices>=0.0.3", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] @@ -111,7 +118,7 @@ dev = [ allow-direct-references = true [project.scripts] -pai = "pydantic_ai._cli:cli_exit" # TODO remove this when clai has been out for a while +pai = "pydantic_ai._cli:cli_exit" # TODO remove this when clai has been out for a while [tool.hatch.build.targets.wheel] packages = ["pydantic_ai"] diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 77857e8821..846a4e84dd 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -35,7 +35,7 @@ UserPromptPart, ) from pydantic_ai.output 
import NativeOutput, PromptedOutput, TextOutput, ToolOutput -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, TestEnv, raise_if_exception, try_import @@ -163,10 +163,10 @@ async def test_sync_request_text_response(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=5, - response_tokens=10, + input_tokens=5, + output_tokens=10, total_tokens=15, details={'input_tokens': 5, 'output_tokens': 10}, ) @@ -177,10 +177,10 @@ async def test_sync_request_text_response(allow_model_requests: None): result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=5, - response_tokens=10, + input_tokens=5, + output_tokens=10, total_tokens=15, details={'input_tokens': 5, 'output_tokens': 10}, ) @@ -190,30 +190,30 @@ async def test_sync_request_text_response(allow_model_requests: None): ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=5, - response_tokens=10, + input_tokens=5, + output_tokens=10, total_tokens=15, details={'input_tokens': 5, 'output_tokens': 10}, ), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=5, - response_tokens=10, + input_tokens=5, + output_tokens=10, total_tokens=15, details={'input_tokens': 5, 'output_tokens': 10}, ), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -236,10 +236,10 @@ async def test_async_request_prompt_caching(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=13, - response_tokens=5, + input_tokens=13, + output_tokens=5, total_tokens=18, details={ 'input_tokens': 3, @@ -263,10 +263,10 @@ async def test_async_request_text_response(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=3, - response_tokens=5, + input_tokens=3, + output_tokens=5, total_tokens=8, details={'input_tokens': 3, 'output_tokens': 5}, ) @@ -295,16 +295,16 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=3, - response_tokens=5, + input_tokens=3, + output_tokens=5, total_tokens=8, details={'input_tokens': 3, 'output_tokens': 5}, ), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -365,16 +365,16 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=2, - response_tokens=1, + input_tokens=2, + output_tokens=1, total_tokens=3, details={'input_tokens': 2, 'output_tokens': 1}, ), 
model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -394,16 +394,16 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=3, - response_tokens=2, + input_tokens=3, + output_tokens=2, total_tokens=5, details={'input_tokens': 3, 'output_tokens': 2}, ), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -417,16 +417,16 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=3, - response_tokens=5, + input_tokens=3, + output_tokens=5, total_tokens=8, details={'input_tokens': 3, 'output_tokens': 5}, ), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -680,10 +680,10 @@ async def my_tool(first: str, second: str) -> int: ) assert result.is_complete assert result.usage() == snapshot( - Usage( + RunUsage( requests=2, - request_tokens=20, - response_tokens=5, + input_tokens=20, + output_tokens=5, total_tokens=25, details={'input_tokens': 20, 'output_tokens': 5}, ) @@ -761,10 +761,10 @@ async def get_image() -> BinaryContent: TextPart(content='Let me get the image and check.'), ToolCallPart(tool_name='get_image', args={}, tool_call_id='toolu_01YJiJ82nETV7aRdJr9f6Np7'), ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=372, - response_tokens=45, + input_tokens=372, + output_tokens=45, total_tokens=417, details={ 'cache_creation_input_tokens': 0, @@ -775,7 +775,7 @@ async def get_image() -> BinaryContent: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01CC59GmUmYXKCV26rHfr32m', + provider_request_id='msg_01CC59GmUmYXKCV26rHfr32m', ), ModelRequest( parts=[ @@ -800,10 +800,10 @@ async def get_image() -> BinaryContent: content="The image shows a kiwi fruit that has been cut in half, displaying its characteristic bright green flesh with small black seeds arranged in a circular pattern around a white center core. The fruit's thin, fuzzy brown skin is visible around the edges of the slice." 
) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=2021, - response_tokens=57, + input_tokens=2021, + output_tokens=57, total_tokens=2078, details={ 'cache_creation_input_tokens': 0, @@ -814,7 +814,7 @@ async def get_image() -> BinaryContent: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_014MJqSsWD1pUC23Vvi57LoY', + provider_request_id='msg_014MJqSsWD1pUC23Vvi57LoY', ), ] ) @@ -933,10 +933,10 @@ def simple_instructions(): ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=20, - response_tokens=10, + input_tokens=20, + output_tokens=10, total_tokens=30, details={ 'cache_creation_input_tokens': 0, @@ -947,7 +947,7 @@ def simple_instructions(): ), model_name='claude-3-opus-20240229', timestamp=IsDatetime(), - vendor_id='msg_01BznVNBje2zyfpCfNQCD5en', + provider_request_id='msg_01BznVNBje2zyfpCfNQCD5en', ), ] ) @@ -970,10 +970,10 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), TextPart(content=IsStr()), ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=42, - response_tokens=302, + input_tokens=42, + output_tokens=302, total_tokens=344, details={ 'cache_creation_input_tokens': 0, @@ -984,7 +984,7 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), model_name='claude-3-7-sonnet-20250219', timestamp=IsDatetime(), - vendor_id='msg_01FWiSVNCRHvHUYU21BRandY', + provider_request_id='msg_01FWiSVNCRHvHUYU21BRandY', ), ] ) @@ -998,10 +998,10 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), ModelResponse( parts=[IsInstance(ThinkingPart), IsInstance(TextPart)], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=42, - response_tokens=302, + input_tokens=42, + output_tokens=302, total_tokens=344, details={ 'cache_creation_input_tokens': 0, @@ -1012,7 +1012,7 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), model_name='claude-3-7-sonnet-20250219', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1041,10 +1041,10 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), TextPart(content=IsStr()), ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=303, - response_tokens=486, + input_tokens=303, + output_tokens=486, total_tokens=789, details={ 'cache_creation_input_tokens': 0, @@ -1055,7 +1055,7 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), model_name='claude-3-7-sonnet-20250219', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id=IsStr(), ), ] ) @@ -1227,8 +1227,8 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: pytest.param( lambda: anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), snapshot( - Usage( - request_tokens=1, response_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1} + RunUsage( + input_tokens=1, output_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1} ) ), id='AnthropicMessage', @@ -1238,9 +1238,9 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: BetaUsage(input_tokens=1, output_tokens=1, cache_creation_input_tokens=2, cache_read_input_tokens=3) ), snapshot( - Usage( - request_tokens=6, - response_tokens=1, + RunUsage( + input_tokens=6, + output_tokens=1, total_tokens=7, details={ 
'cache_creation_input_tokens': 2, @@ -1257,8 +1257,8 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start' ), snapshot( - Usage( - request_tokens=1, response_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1} + RunUsage( + input_tokens=1, output_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1} ) ), id='RawMessageStartEvent', @@ -1269,13 +1269,15 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta', ), - snapshot(Usage(response_tokens=5, total_tokens=5, details={'output_tokens': 5})), + snapshot(RunUsage(output_tokens=5, total_tokens=5, details={'output_tokens': 5})), id='RawMessageDeltaEvent', ), - pytest.param(lambda: BetaRawMessageStopEvent(type='message_stop'), snapshot(Usage()), id='RawMessageStopEvent'), + pytest.param( + lambda: BetaRawMessageStopEvent(type='message_stop'), snapshot(RunUsage()), id='RawMessageStopEvent' + ), ], ) -def test_usage(message_callback: Callable[[], BetaMessage | BetaRawMessageStreamEvent], usage: Usage): +def test_usage(message_callback: Callable[[], BetaMessage | BetaRawMessageStreamEvent], usage: RunUsage): assert _map_usage(message_callback()) == usage @@ -1386,10 +1388,10 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_019pMboNVRg5jkw4PKkofQ6Y') ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=445, - response_tokens=23, + input_tokens=445, + output_tokens=23, total_tokens=468, details={ 'cache_creation_input_tokens': 0, @@ -1400,7 +1402,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01EnfsDTixCmHjqvk9QarBj4', + provider_request_id='msg_01EnfsDTixCmHjqvk9QarBj4', ), ModelRequest( parts=[ @@ -1420,10 +1422,10 @@ async def get_user_country() -> str: tool_call_id='toolu_01V4d2H4EWp5LDM2aXaeyR6W', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=497, - response_tokens=56, + input_tokens=497, + output_tokens=56, total_tokens=553, details={ 'cache_creation_input_tokens': 0, @@ -1434,7 +1436,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01Hbm5BtKzfVtWs8Eb7rCNNx', + provider_request_id='msg_01Hbm5BtKzfVtWs8Eb7rCNNx', ), ModelRequest( parts=[ @@ -1486,10 +1488,10 @@ async def get_user_country() -> str: ), ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01EZuxfc6MsPsPgrAKQohw3e'), ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=383, - response_tokens=66, + input_tokens=383, + output_tokens=66, total_tokens=449, details={ 'cache_creation_input_tokens': 0, @@ -1500,7 +1502,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_014NE4yfV1Yz2vLAJzapxxef', + provider_request_id='msg_014NE4yfV1Yz2vLAJzapxxef', ), ModelRequest( parts=[ @@ -1518,10 +1520,10 @@ async def get_user_country() -> str: content="Based on the result, you are located in Mexico. The largest city in Mexico is Mexico City (Ciudad de México), which is also the nation's capital. Mexico City has a population of approximately 9.2 million people in the city proper, and over 21 million people in its metropolitan area, making it one of the largest urban agglomerations in the world. 
It is both the political and economic center of Mexico, located in the Valley of Mexico in the central part of the country." ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=461, - response_tokens=107, + input_tokens=461, + output_tokens=107, total_tokens=568, details={ 'cache_creation_input_tokens': 0, @@ -1532,7 +1534,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_0193srwo7TCx49h97wDwc7K7', + provider_request_id='msg_0193srwo7TCx49h97wDwc7K7', ), ] ) @@ -1577,10 +1579,10 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_017UryVwtsKsjonhFV3cgV3X') ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=459, - response_tokens=38, + input_tokens=459, + output_tokens=38, total_tokens=497, details={ 'cache_creation_input_tokens': 0, @@ -1591,7 +1593,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_014CpBKzioMqUyLWrMihpvsz', + provider_request_id='msg_014CpBKzioMqUyLWrMihpvsz', ), ModelRequest( parts=[ @@ -1612,10 +1614,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=510, - response_tokens=17, + input_tokens=510, + output_tokens=17, total_tokens=527, details={ 'cache_creation_input_tokens': 0, @@ -1626,7 +1628,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_014JeWCouH6DpdqzMTaBdkpJ', + provider_request_id='msg_014JeWCouH6DpdqzMTaBdkpJ', ), ] ) @@ -1671,10 +1673,10 @@ class CountryLanguage(BaseModel): content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=281, - response_tokens=31, + input_tokens=281, + output_tokens=31, total_tokens=312, details={ 'cache_creation_input_tokens': 0, @@ -1685,7 +1687,7 @@ class CountryLanguage(BaseModel): ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_013ttUi3HCcKt7PkJpoWs5FT', + provider_request_id='msg_013ttUi3HCcKt7PkJpoWs5FT', ), ] ) diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index fad3530758..154173d502 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -34,7 +34,7 @@ ) from pydantic_ai.models import ModelRequestParameters from pydantic_ai.tools import ToolDefinition -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -58,7 +58,7 @@ async def test_bedrock_model(allow_model_requests: None, bedrock_provider: Bedro assert result.output == snapshot( "Hello! How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help." ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=7, response_tokens=30, total_tokens=37)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=7, output_tokens=30, total_tokens=37)) assert result.all_messages() == snapshot( [ ModelRequest( @@ -79,7 +79,7 @@ async def test_bedrock_model(allow_model_requests: None, bedrock_provider: Bedro content="Hello! How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help." 
) ], - usage=Usage(requests=1, request_tokens=7, response_tokens=30, total_tokens=37), + usage=RunUsage(requests=1, input_tokens=7, output_tokens=30, total_tokens=37), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -111,7 +111,7 @@ async def temperature(city: str, date: datetime.date) -> str: result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response) assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'}) - assert result.usage() == snapshot(Usage(requests=2, request_tokens=1236, response_tokens=298, total_tokens=1534)) + assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=1236, output_tokens=298, total_tokens=1534)) assert result.all_messages() == snapshot( [ ModelRequest( @@ -137,7 +137,7 @@ async def temperature(city: str, date: datetime.date) -> str: tool_call_id='tooluse_5WEci1UmQ8ifMFkUcy2gHQ', ), ], - usage=Usage(requests=1, request_tokens=551, response_tokens=132, total_tokens=683), + usage=RunUsage(requests=1, input_tokens=551, output_tokens=132, total_tokens=683), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -162,7 +162,7 @@ async def temperature(city: str, date: datetime.date) -> str: tool_call_id='tooluse_9AjloJSaQDKmpPFff-2Clg', ), ], - usage=Usage(requests=1, request_tokens=685, response_tokens=166, total_tokens=851), + usage=RunUsage(requests=1, input_tokens=685, output_tokens=166, total_tokens=851), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -262,7 +262,7 @@ async def get_capital(country: str) -> str: tool_call_id='tooluse_F8LnaCMtQ0-chKTnPhNH2g', ), ], - usage=Usage(requests=1, request_tokens=417, response_tokens=69, total_tokens=486), + usage=RunUsage(requests=1, input_tokens=417, output_tokens=69, total_tokens=486), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -286,7 +286,7 @@ async def get_capital(country: str) -> str: """ ) ], - usage=Usage(requests=1, request_tokens=509, response_tokens=108, total_tokens=617), + usage=RunUsage(requests=1, input_tokens=509, output_tokens=108, total_tokens=617), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -553,7 +553,7 @@ def instructions() -> str: content='The capital of France is Paris. Paris is not only the political and economic hub of the country but also a major center for culture, fashion, art, and tourism. It is renowned for its rich history, iconic landmarks such as the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum, as well as its influence on global culture and cuisine.' 
) ], - usage=Usage(requests=1, request_tokens=13, response_tokens=71, total_tokens=84), + usage=RunUsage(requests=1, input_tokens=13, output_tokens=71, total_tokens=84), model_name='us.amazon.nova-pro-v1:0', timestamp=IsDatetime(), ), @@ -603,7 +603,7 @@ async def test_bedrock_model_thinking_part(allow_model_requests: None, bedrock_p ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content=IsStr()), ThinkingPart(content=IsStr())], - usage=Usage(requests=1, request_tokens=12, response_tokens=882, total_tokens=894), + usage=RunUsage(requests=1, input_tokens=12, output_tokens=882, total_tokens=894), model_name='us.deepseek.r1-v1:0', timestamp=IsDatetime(), ), @@ -626,7 +626,7 @@ async def test_bedrock_model_thinking_part(allow_model_requests: None, bedrock_p ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), ModelResponse( parts=[IsInstance(TextPart), IsInstance(ThinkingPart)], - usage=Usage(requests=1, request_tokens=12, response_tokens=882, total_tokens=894), + usage=RunUsage(requests=1, input_tokens=12, output_tokens=882, total_tokens=894), model_name='us.deepseek.r1-v1:0', timestamp=IsDatetime(), ), @@ -646,7 +646,7 @@ async def test_bedrock_model_thinking_part(allow_model_requests: None, bedrock_p ), IsInstance(TextPart), ], - usage=Usage(requests=1, request_tokens=636, response_tokens=690, total_tokens=1326), + usage=RunUsage(requests=1, input_tokens=636, output_tokens=690, total_tokens=1326), model_name='us.anthropic.claude-3-7-sonnet-20250219-v1:0', timestamp=IsDatetime(), ), diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index c69a8ab261..48f1e4a6c3 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -23,7 +23,7 @@ UserPromptPart, ) from pydantic_ai.tools import RunContext -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsDatetime, IsInstance, IsNow, raise_if_exception, try_import @@ -102,27 +102,27 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), @@ -148,10 +148,10 @@ async def test_request_simple_usage(allow_model_requests: None): result = await agent.run('Hello') assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=1, - response_tokens=1, + input_tokens=1, + output_tokens=1, total_tokens=2, details={ 'input_tokens': 1, @@ 
-192,7 +192,7 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), @@ -279,7 +279,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, total_tokens=0, details={}), + usage=RunUsage(requests=1, total_tokens=0, details={}), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), @@ -301,10 +301,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=5, - response_tokens=3, + input_tokens=5, + output_tokens=3, total_tokens=8, details={'input_tokens': 4, 'output_tokens': 2}, ), @@ -323,17 +323,17 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), ] ) assert result.usage() == snapshot( - Usage( + RunUsage( requests=3, - request_tokens=5, - response_tokens=3, + input_tokens=5, + output_tokens=3, total_tokens=8, details={'input_tokens': 4, 'output_tokens': 2}, ) @@ -401,10 +401,10 @@ def simple_instructions(ctx: RunContext): content="The capital of France is Paris. It is the country's largest city and serves as the economic, cultural, and political center of France. Paris is known for its rich history, iconic landmarks such as the Eiffel Tower and the Louvre Museum, and its significant influence on fashion, cuisine, and the arts." ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=542, - response_tokens=63, + input_tokens=542, + output_tokens=63, total_tokens=605, details={'input_tokens': 13, 'output_tokens': 61}, ), @@ -447,15 +447,15 @@ async def test_cohere_model_thinking_part(allow_model_requests: None, co_api_key IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=1909, + usage=RunUsage( + input_tokens=13, + output_tokens=1909, total_tokens=1922, details={'reasoning_tokens': 1472, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', + provider_request_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', ), ] ) @@ -476,15 +476,15 @@ async def test_cohere_model_thinking_part(allow_model_requests: None, co_api_key IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=1909, + usage=RunUsage( + input_tokens=13, + output_tokens=1909, total_tokens=1922, details={'reasoning_tokens': 1472, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', + provider_request_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', ), ModelRequest( parts=[ @@ -496,10 +496,10 @@ async def test_cohere_model_thinking_part(allow_model_requests: None, co_api_key ), ModelResponse( parts=[IsInstance(TextPart)], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=1457, - response_tokens=807, + input_tokens=1457, + output_tokens=807, total_tokens=2264, details={'input_tokens': 954, 'output_tokens': 805}, ), diff --git a/tests/models/test_deepseek.py b/tests/models/test_deepseek.py index 64c73aea1a..b6efc5b1af 100644 --- a/tests/models/test_deepseek.py +++ b/tests/models/test_deepseek.py @@ -19,7 +19,7 
@@ ThinkingPartDelta, UserPromptPart, ) -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsDatetime, IsStr, try_import @@ -44,10 +44,10 @@ async def test_deepseek_model_thinking_part(allow_model_requests: None, deepseek ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), ModelResponse( parts=[ThinkingPart(content=IsStr()), TextPart(content=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=12, - response_tokens=789, + input_tokens=12, + output_tokens=789, total_tokens=801, details={ 'prompt_cache_hit_tokens': 0, @@ -58,7 +58,7 @@ async def test_deepseek_model_thinking_part(allow_model_requests: None, deepseek ), model_name='deepseek-reasoner', timestamp=IsDatetime(), - vendor_id='181d9669-2b3a-445e-bd13-2ebff2c378f6', + provider_request_id='181d9669-2b3a-445e-bd13-2ebff2c378f6', ), ] ) diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 89709d4a2e..b894c005fc 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -11,7 +11,7 @@ from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, FunctionModel -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsNow, try_import @@ -60,7 +60,7 @@ def test_first_successful() -> None: ), ModelResponse( parts=[TextPart(content='success')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=1, total_tokens=52), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), ), @@ -85,7 +85,7 @@ def test_first_failed() -> None: ), ModelResponse( parts=[TextPart(content='success')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=1, total_tokens=52), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), ), @@ -111,7 +111,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: ), ModelResponse( parts=[TextPart(content='success')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=1, total_tokens=52), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), ), @@ -170,19 +170,19 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None [ ModelResponse( parts=[TextPart(content='hello ')], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RunUsage(input_tokens=50, output_tokens=1, total_tokens=51), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RunUsage(input_tokens=50, output_tokens=2, total_tokens=52), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RunUsage(input_tokens=50, output_tokens=2, total_tokens=52), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), @@ -346,19 +346,19 @@ async def test_first_success_streaming() -> None: [ 
ModelResponse( parts=[TextPart(content='hello ')], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RunUsage(input_tokens=50, output_tokens=1, total_tokens=51), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RunUsage(input_tokens=50, output_tokens=2, total_tokens=52), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RunUsage(input_tokens=50, output_tokens=2, total_tokens=52), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), @@ -375,19 +375,19 @@ async def test_first_failed_streaming() -> None: [ ModelResponse( parts=[TextPart(content='hello ')], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RunUsage(input_tokens=50, output_tokens=1, total_tokens=51), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RunUsage(input_tokens=50, output_tokens=2, total_tokens=52), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RunUsage(input_tokens=50, output_tokens=2, total_tokens=52), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 7d628df31a..2723dd5692 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -55,7 +55,7 @@ ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.tools import ToolDefinition from ..conftest import ClientWithHandler, IsDatetime, IsInstance, IsNow, IsStr, TestEnv, try_import @@ -612,14 +612,14 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='Hello world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3)) result = await agent.run('Hello', message_history=result.new_messages()) assert result.output == 'Hello world' @@ -628,18 +628,18 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='Hello world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + 
usage=RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='Hello world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -660,10 +660,10 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2, 123]}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -723,10 +723,10 @@ async def get_location(loc_name: str) -> str: parts=[ ToolCallPart(tool_name='get_location', args={'loc_name': 'San Fransisco'}, tool_call_id=IsStr()) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -743,10 +743,10 @@ async def get_location(loc_name: str) -> str: ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()), ToolCallPart(tool_name='get_location', args={'loc_name': 'New York'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -766,14 +766,14 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) - assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6, total_tokens=9)) async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): @@ -814,12 +814,12 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): 'Hello world', ] ) - assert result.usage() == 
snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3)) async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)] assert chunks == snapshot(['Hello ', 'world']) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3)) async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient): @@ -851,7 +851,7 @@ async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot(['abc', 'abc€def', 'abc€def']) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3)) async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): @@ -881,7 +881,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot([(1, 2), (1, 2)]) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3)) async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): @@ -922,7 +922,7 @@ async def bar(y: str) -> str: async with agent.run_stream('Hello') as result: response = await result.get_output() assert response == snapshot((1, 2)) - assert result.usage() == snapshot(Usage(requests=2, request_tokens=2, response_tokens=4, total_tokens=6)) + assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=2, output_tokens=4, total_tokens=6)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), @@ -931,7 +931,7 @@ async def bar(y: str) -> str: ToolCallPart(tool_name='foo', args={'x': 'a'}, tool_call_id=IsStr()), ToolCallPart(tool_name='bar', args={'y': 'b'}, tool_call_id=IsStr()), ], - usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash', timestamp=IsNow(tz=timezone.utc), ), @@ -947,7 +947,7 @@ async def bar(y: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2]}, tool_call_id=IsStr())], - usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash', timestamp=IsNow(tz=timezone.utc), ), @@ -1015,7 +1015,7 @@ def get_location(loc_name: str) -> str: tool_call_id=IsStr(), ), ], - usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RunUsage(input_tokens=1, output_tokens=2, total_tokens=3, details={}), model_name='gemini-1.5-flash', timestamp=IsDatetime(), ), @@ -1216,16 +1216,16 @@ async def get_image() -> BinaryContent: ), ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr()), ], - 
usage=Usage( + usage=RunUsage( requests=1, - request_tokens=38, - response_tokens=28, + input_tokens=38, + output_tokens=28, total_tokens=427, details={'thoughts_tokens': 361, 'text_prompt_tokens': 38}, ), model_name='gemini-2.5-pro-preview-03-25', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1246,16 +1246,16 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The image shows a kiwi fruit, sliced in half.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=360, - response_tokens=11, + input_tokens=360, + output_tokens=11, total_tokens=572, details={'thoughts_tokens': 201, 'text_prompt_tokens': 102, 'image_prompt_tokens': 258}, ), model_name='gemini-2.5-pro-preview-03-25', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1375,16 +1375,16 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_ ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.\n')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=13, - response_tokens=8, + input_tokens=13, + output_tokens=8, total_tokens=21, details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8}, ), model_name='gemini-1.5-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1484,15 +1484,15 @@ async def test_gemini_model_thinking_part(allow_model_requests: None, gemini_api """ ), ], - usage=Usage( - request_tokens=13, - response_tokens=2028, + usage=RunUsage( + input_tokens=13, + output_tokens=2028, total_tokens=2041, details={'reasoning_tokens': 1664, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', + provider_request_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', ), ] ) @@ -1515,15 +1515,15 @@ async def test_gemini_model_thinking_part(allow_model_requests: None, gemini_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=2028, + usage=RunUsage( + input_tokens=13, + output_tokens=2028, total_tokens=2041, details={'reasoning_tokens': 1664, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', + provider_request_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', ), ModelRequest( parts=[ @@ -1574,16 +1574,16 @@ async def test_gemini_model_thinking_part(allow_model_requests: None, gemini_api """ ), ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=801, - response_tokens=1519, + input_tokens=801, + output_tokens=1519, total_tokens=2320, details={'thoughts_tokens': 794, 'text_prompt_tokens': 801}, ), model_name='gemini-2.5-flash-preview-04-17', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1613,10 +1613,10 @@ async def test_gemini_youtube_video_url_input(allow_model_requests: None, gemini content='The main content of the URL is an analysis of recent 404 HTTP responses. The analysis identifies several patterns, including the most common endpoints with 404 errors, request patterns (such as all requests being GET requests), timeline-related issues, and configuration/authentication problems. 
The analysis also provides recommendations for addressing the 404 errors.' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=9, - response_tokens=72, + input_tokens=9, + output_tokens=72, total_tokens=81, details={ 'text_prompt_tokens': 9, @@ -1627,7 +1627,7 @@ async def test_gemini_youtube_video_url_input(allow_model_requests: None, gemini ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1665,7 +1665,7 @@ async def test_response_with_thought_part(get_gemini_client: GetGeminiClient): result = await agent.run('Test with thought') assert result.output == 'Hello from thought test' - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3)) @pytest.mark.vcr() @@ -1693,17 +1693,17 @@ async def bar() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='bar', args={}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=21, - response_tokens=1, + input_tokens=21, + output_tokens=1, total_tokens=22, details={'text_candidates_tokens': 1, 'text_prompt_tokens': 21}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1723,17 +1723,17 @@ async def bar() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=27, - response_tokens=5, + input_tokens=27, + output_tokens=5, total_tokens=32, details={'text_candidates_tokens': 5, 'text_prompt_tokens': 27}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1778,17 +1778,17 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=32, - response_tokens=5, + input_tokens=32, + output_tokens=5, total_tokens=37, details={'text_prompt_tokens': 32, 'text_candidates_tokens': 5}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1808,17 +1808,17 @@ async def get_user_country() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=46, - response_tokens=8, + input_tokens=46, + output_tokens=8, total_tokens=54, details={'text_prompt_tokens': 46, 'text_candidates_tokens': 8}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1870,17 +1870,17 @@ def upcase(text: str) -> str: """ ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=9, - response_tokens=44, + input_tokens=9, + output_tokens=44, total_tokens=598, details={'thoughts_tokens': 545, 'text_prompt_tokens': 9}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - 
vendor_id='TT9IaNfGN_DmqtsPzKnE4AE', + provider_details={'finish_reason': 'STOP'}, + provider_request_id='TT9IaNfGN_DmqtsPzKnE4AE', ), ] ) @@ -1940,17 +1940,17 @@ class CityLocation(BaseModel): """ ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=17, - response_tokens=20, + input_tokens=17, + output_tokens=20, total_tokens=37, details={'text_prompt_tokens': 17, 'text_candidates_tokens': 20}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -1999,17 +1999,17 @@ class CountryLanguage(BaseModel): """ ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=46, - response_tokens=46, + input_tokens=46, + output_tokens=46, total_tokens=92, details={'text_prompt_tokens': 46, 'text_candidates_tokens': 46}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -2051,17 +2051,17 @@ class CityLocation(BaseModel): content='{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=80, - response_tokens=56, + input_tokens=80, + output_tokens=56, total_tokens=136, details={'text_prompt_tokens': 80, 'text_candidates_tokens': 56}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -2105,17 +2105,17 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=123, - response_tokens=12, + input_tokens=123, + output_tokens=12, total_tokens=453, details={'thoughts_tokens': 318, 'text_prompt_tokens': 123}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -2136,17 +2136,17 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=154, - response_tokens=13, + input_tokens=154, + output_tokens=13, total_tokens=261, details={'thoughts_tokens': 94, 'text_prompt_tokens': 154}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -2192,17 +2192,17 @@ class CountryLanguage(BaseModel): content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=253, - response_tokens=27, + input_tokens=253, + output_tokens=27, total_tokens=280, details={'text_prompt_tokens': 253, 'text_candidates_tokens': 27}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + 
provider_request_id=IsStr(), ), ] ) diff --git a/tests/models/test_gemini_vertex.py b/tests/models/test_gemini_vertex.py index c437dfb1d9..732dd60bae 100644 --- a/tests/models/test_gemini_vertex.py +++ b/tests/models/test_gemini_vertex.py @@ -19,7 +19,7 @@ VideoUrl, ) from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -142,11 +142,11 @@ async def test_url_input( ), ModelResponse( parts=[TextPart(content=Is(expected_output))], - usage=IsInstance(Usage), + usage=IsInstance(RunUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -179,11 +179,11 @@ async def test_url_input_force_download(allow_model_requests: None) -> None: # ), ModelResponse( parts=[TextPart(content=Is(output))], - usage=IsInstance(Usage), + usage=IsInstance(RunUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 7e1f372bcc..49e8e198ee 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -36,7 +36,7 @@ VideoUrl, ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -67,10 +67,10 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP result = await agent.run('Hello!') assert result.output == snapshot('Hello there! How can I help you today?\n') assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=7, - response_tokens=11, + input_tokens=7, + output_tokens=11, total_tokens=18, details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11}, ) @@ -91,16 +91,16 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP ), ModelResponse( parts=[TextPart(content='Hello there! 
How can I help you today?\n')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=7, - response_tokens=11, + input_tokens=7, + output_tokens=11, total_tokens=18, details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11}, ), model_name='gemini-1.5-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -131,10 +131,10 @@ async def temperature(city: str, date: datetime.date) -> str: result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response) assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'}) assert result.usage() == snapshot( - Usage( + RunUsage( requests=2, - request_tokens=224, - response_tokens=35, + input_tokens=224, + output_tokens=35, total_tokens=259, details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35}, ) @@ -159,16 +159,16 @@ async def temperature(city: str, date: datetime.date) -> str: tool_name='temperature', args={'date': '2022-01-01', 'city': 'London'}, tool_call_id=IsStr() ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=101, - response_tokens=14, + input_tokens=101, + output_tokens=14, total_tokens=115, details={'text_prompt_tokens': 101, 'text_candidates_tokens': 14}, ), model_name='gemini-1.5-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -185,16 +185,16 @@ async def temperature(city: str, date: datetime.date) -> str: tool_call_id=IsStr(), ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=123, - response_tokens=21, + input_tokens=123, + output_tokens=21, total_tokens=144, details={'text_prompt_tokens': 123, 'text_candidates_tokens': 21}, ), model_name='gemini-1.5-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -244,16 +244,16 @@ async def get_capital(country: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_capital', args={'country': 'France'}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=57, - response_tokens=15, + input_tokens=57, + output_tokens=15, total_tokens=227, details={'thoughts_tokens': 155, 'text_prompt_tokens': 57}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -271,16 +271,16 @@ async def get_capital(country: str) -> str: content='I am sorry, I cannot fulfill this request. The country "France" is not supported by my system.' 
) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=104, - response_tokens=22, + input_tokens=104, + output_tokens=22, total_tokens=304, details={'thoughts_tokens': 178, 'text_prompt_tokens': 104}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -546,16 +546,16 @@ def instructions() -> str: ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.\n')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=13, - response_tokens=8, + input_tokens=13, + output_tokens=8, total_tokens=21, details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -634,16 +634,16 @@ async def test_google_model_thinking_part(allow_model_requests: None, google_pro ), ModelResponse( parts=[IsInstance(ThinkingPart), IsInstance(TextPart)], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=15, - response_tokens=1041, + input_tokens=15, + output_tokens=1041, total_tokens=2703, details={'thoughts_tokens': 1647, 'text_prompt_tokens': 15}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -792,11 +792,11 @@ async def test_google_url_input( ), ModelResponse( parts=[TextPart(content=Is(expected_output))], - usage=IsInstance(Usage), + usage=IsInstance(RunUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -829,11 +829,11 @@ async def test_google_url_input_force_download(allow_model_requests: None) -> No ), ModelResponse( parts=[TextPart(content=Is(output))], - usage=IsInstance(Usage), + usage=IsInstance(RunUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -875,16 +875,16 @@ async def bar() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='bar', args={}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=21, - response_tokens=1, + input_tokens=21, + output_tokens=1, total_tokens=22, details={'text_candidates_tokens': 1, 'text_prompt_tokens': 21}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -904,16 +904,16 @@ async def bar() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=27, - response_tokens=5, + input_tokens=27, + output_tokens=5, total_tokens=32, details={'text_candidates_tokens': 5, 'text_prompt_tokens': 27}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -968,16 +968,16 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=33, - response_tokens=5, + input_tokens=33, + output_tokens=5, total_tokens=38, details={'text_candidates_tokens': 5, 'text_prompt_tokens': 33}, ), 
model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -997,16 +997,16 @@ async def get_user_country() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=47, - response_tokens=8, + input_tokens=47, + output_tokens=8, total_tokens=55, details={'text_candidates_tokens': 8, 'text_prompt_tokens': 47}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1051,16 +1051,16 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=49, - response_tokens=12, + input_tokens=49, + output_tokens=12, total_tokens=325, details={'thoughts_tokens': 264, 'text_prompt_tokens': 49}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1074,16 +1074,16 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='The largest city in Mexico is Mexico City.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=80, - response_tokens=9, + input_tokens=80, + output_tokens=9, total_tokens=239, details={'thoughts_tokens': 150, 'text_prompt_tokens': 80}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1141,16 +1141,16 @@ class CityLocation(BaseModel): """ ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=25, - response_tokens=20, + input_tokens=25, + output_tokens=20, total_tokens=45, details={'text_candidates_tokens': 20, 'text_prompt_tokens': 25}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1198,16 +1198,16 @@ class CountryLanguage(BaseModel): """ ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=50, - response_tokens=46, + input_tokens=50, + output_tokens=46, total_tokens=96, details={'text_candidates_tokens': 46, 'text_prompt_tokens': 50}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1244,16 +1244,16 @@ class CityLocation(BaseModel): ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=80, - response_tokens=13, + input_tokens=80, + output_tokens=13, total_tokens=93, details={'text_candidates_tokens': 13, 'text_prompt_tokens': 80}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1296,16 +1296,16 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=123, - response_tokens=12, + input_tokens=123, + output_tokens=12, total_tokens=267, details={'thoughts_tokens': 132, 'text_prompt_tokens': 123}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + 
provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1326,16 +1326,16 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=154, - response_tokens=13, + input_tokens=154, + output_tokens=13, total_tokens=320, details={'thoughts_tokens': 153, 'text_prompt_tokens': 154}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1380,16 +1380,16 @@ class CountryLanguage(BaseModel): content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=240, - response_tokens=27, + input_tokens=240, + output_tokens=27, total_tokens=267, details={'text_candidates_tokens': 27, 'text_prompt_tokens': 240}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 4c5de38ee7..b7e7547f56 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -34,7 +34,7 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import from .mock_async_stream import MockAsyncStream @@ -139,31 +139,31 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -213,10 +213,10 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -301,10 +301,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RunUsage(requests=1, input_tokens=2, output_tokens=1, total_tokens=3), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, 
tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -324,10 +324,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6), + usage=RunUsage(requests=1, input_tokens=3, output_tokens=2, total_tokens=6), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -341,10 +341,10 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -449,7 +449,7 @@ async def test_stream_structured(allow_model_requests: None): ) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]), @@ -579,10 +579,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='call_wkpd')], - usage=Usage(requests=1, request_tokens=192, response_tokens=8, total_tokens=200), + usage=RunUsage(requests=1, input_tokens=192, output_tokens=8, total_tokens=200), model_name='meta-llama/llama-4-scout-17b-16e-instruct', timestamp=IsDatetime(), - vendor_id='chatcmpl-3c327c89-e9f5-4aac-a5d5-190e6f6f25c9', + provider_request_id='chatcmpl-3c327c89-e9f5-4aac-a5d5-190e6f6f25c9', ), ModelRequest( parts=[ @@ -603,10 +603,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The fruit in the image is a kiwi.')], - usage=Usage(requests=1, request_tokens=2552, response_tokens=11, total_tokens=2563), + usage=RunUsage(requests=1, input_tokens=2552, output_tokens=11, total_tokens=2563), model_name='meta-llama/llama-4-scout-17b-16e-instruct', timestamp=IsDatetime(), - vendor_id='chatcmpl-82dfad42-6a28-4089-82c3-c8633f626c0d', + provider_request_id='chatcmpl-82dfad42-6a28-4089-82c3-c8633f626c0d', ), ] ) @@ -681,10 +681,10 @@ async def test_groq_model_instructions(allow_model_requests: None, groq_api_key: ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage(requests=1, request_tokens=48, response_tokens=8, total_tokens=56), + usage=RunUsage(requests=1, input_tokens=48, output_tokens=8, total_tokens=56), model_name='llama-3.3-70b-versatile', timestamp=IsDatetime(), - vendor_id='chatcmpl-7586b6a9-fb4b-4ec7-86a0-59f0a77844cf', + provider_request_id='chatcmpl-7586b6a9-fb4b-4ec7-86a0-59f0a77844cf', ), ] ) @@ -761,10 +761,10 @@ async def test_groq_model_thinking_part(allow_model_requests: None, groq_api_key """ ), ], - usage=Usage(requests=1, request_tokens=21, response_tokens=1414, total_tokens=1435), + usage=RunUsage(requests=1, input_tokens=21, output_tokens=1414, total_tokens=1435), model_name='deepseek-r1-distill-llama-70b', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id=IsStr(), ), ] ) @@ -782,10 +782,10 @@ async def test_groq_model_thinking_part(allow_model_requests: None, groq_api_key ), ModelResponse( parts=[IsInstance(ThinkingPart), IsInstance(TextPart)], - usage=Usage(requests=1, request_tokens=21, response_tokens=1414, total_tokens=1435), + 
usage=RunUsage(requests=1, input_tokens=21, output_tokens=1414, total_tokens=1435), model_name='deepseek-r1-distill-llama-70b', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -899,10 +899,10 @@ async def test_groq_model_thinking_part(allow_model_requests: None, groq_api_key """ ), ], - usage=Usage(requests=1, request_tokens=524, response_tokens=1590, total_tokens=2114), + usage=RunUsage(requests=1, input_tokens=524, output_tokens=1590, total_tokens=2114), model_name='deepseek-r1-distill-llama-70b', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id=IsStr(), ), ] ) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index b7351d574c..1c0e4111bf 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -36,7 +36,7 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import RunContext @@ -166,10 +166,10 @@ async def test_simple_completion(allow_model_requests: None, huggingface_api_key content='Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' ) ], - usage=Usage(requests=1, request_tokens=30, response_tokens=29, total_tokens=59), + usage=RunUsage(requests=1, input_tokens=30, output_tokens=29, total_tokens=59), model_name='Qwen/Qwen2.5-72B-Instruct-fast', timestamp=datetime(2025, 7, 8, 13, 42, 33, tzinfo=timezone.utc), - vendor_id='chatcmpl-d445c0d473a84791af2acf356cc00df7', + provider_request_id='chatcmpl-d445c0d473a84791af2acf356cc00df7', ) @@ -186,7 +186,7 @@ async def test_request_simple_usage(allow_model_requests: None, huggingface_api_ result.output == "Hello! It's great to meet you. How can I assist you today? Whether you have any questions, need some advice, or just want to chat, feel free to let me know!" 
) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=30, response_tokens=40, total_tokens=70)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=30, output_tokens=40, total_tokens=70)) async def test_request_structured_response( @@ -232,10 +232,10 @@ async def test_request_structured_response( tool_call_id='123', ) ], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='hf-model', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ) @@ -363,10 +363,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=2), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=2), model_name='hf-model', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -386,10 +386,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RunUsage(requests=1, input_tokens=2, output_tokens=1, total_tokens=3), model_name='hf-model', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -403,10 +403,10 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='hf-model', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -447,7 +447,7 @@ async def test_stream_text(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3, total_tokens=9)) async def test_stream_text_finish_reason(allow_model_requests: None): @@ -557,7 +557,7 @@ async def test_stream_structured(allow_model_requests: None): ] ) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=20, output_tokens=10, total_tokens=30)) # double check usage matches stream count assert result.usage().response_tokens == len(stream) @@ -616,7 +616,7 @@ async def test_no_delta(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3, total_tokens=9)) @pytest.mark.vcr() @@ -650,10 +650,10 @@ async def test_image_url_input(allow_model_requests: None, huggingface_api_key: ), ModelResponse( parts=[TextPart(content='Hello! 
How can I assist you with this image of a potato?')], - usage=Usage(requests=1, request_tokens=269, response_tokens=15, total_tokens=284), + usage=RunUsage(requests=1, input_tokens=269, output_tokens=15, total_tokens=284), model_name='Qwen/Qwen2.5-VL-72B-Instruct', timestamp=datetime(2025, 7, 8, 14, 4, 39, tzinfo=timezone.utc), - vendor_id='chatcmpl-49aa100effab4ca28514d5ccc00d7944', + provider_request_id='chatcmpl-49aa100effab4ca28514d5ccc00d7944', ), ] ) @@ -716,10 +716,10 @@ def simple_instructions(ctx: RunContext): ), ModelResponse( parts=[TextPart(content='Paris')], - usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28), + usage=RunUsage(requests=1, input_tokens=26, output_tokens=2, total_tokens=28), model_name='Qwen/Qwen2.5-72B-Instruct-fast', timestamp=IsDatetime(), - vendor_id='chatcmpl-b3936940372c481b8d886e596dc75524', + provider_request_id='chatcmpl-b3936940372c481b8d886e596dc75524', ), ] ) @@ -808,10 +808,10 @@ def response_validator(value: str) -> str: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='invalid-response')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='hf-model', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -825,10 +825,10 @@ def response_validator(value: str) -> str: ), ModelResponse( parts=[TextPart(content='final-response')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='hf-model', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -955,10 +955,10 @@ async def test_hf_model_thinking_part(allow_model_requests: None, huggingface_ap IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage(requests=1, request_tokens=15, response_tokens=1090, total_tokens=1105), + usage=RunUsage(requests=1, input_tokens=15, output_tokens=1090, total_tokens=1105), model_name='Qwen/Qwen3-235B-A22B', timestamp=IsDatetime(), - vendor_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9', + provider_request_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9', ), ] ) @@ -978,10 +978,10 @@ async def test_hf_model_thinking_part(allow_model_requests: None, huggingface_ap IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage(requests=1, request_tokens=15, response_tokens=1090, total_tokens=1105), + usage=RunUsage(requests=1, input_tokens=15, output_tokens=1090, total_tokens=1105), model_name='Qwen/Qwen3-235B-A22B', timestamp=IsDatetime(), - vendor_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9', + provider_request_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9', ), ModelRequest( parts=[ @@ -996,10 +996,10 @@ async def test_hf_model_thinking_part(allow_model_requests: None, huggingface_ap IsInstance(ThinkingPart), TextPart(content=IsStr()), ], - usage=Usage(requests=1, request_tokens=691, response_tokens=1860, total_tokens=2551), + usage=RunUsage(requests=1, input_tokens=691, output_tokens=1860, total_tokens=2551), model_name='Qwen/Qwen3-235B-A22B', timestamp=IsDatetime(), - vendor_id='chatcmpl-35fdec1307634f94a39f7e26f52e12a7', + provider_request_id='chatcmpl-35fdec1307634f94a39f7e26f52e12a7', ), ] ) diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index a156bb7fa8..f4fa4e1114 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -35,7 +35,7 @@ from pydantic_ai.models import Model, ModelRequestParameters, 
StreamedResponse from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel from pydantic_ai.settings import ModelSettings -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsStr, try_import @@ -80,7 +80,7 @@ async def request( TextPart('text2'), {}, # test unexpected parts # type: ignore ], - usage=Usage(request_tokens=100, response_tokens=200), + usage=RunUsage(input_tokens=100, output_tokens=200), model_name='my_model_123', ) @@ -96,7 +96,7 @@ async def request_stream( class MyResponseStream(StreamedResponse): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - self._usage = Usage(request_tokens=300, response_tokens=400) + self._usage = RunUsage(input_tokens=300, output_tokens=400) maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1') if maybe_event is not None: # pragma: no branch yield maybe_event diff --git a/tests/models/test_mcp_sampling.py b/tests/models/test_mcp_sampling.py index 0336ce2f73..1d6ca30a85 100644 --- a/tests/models/test_mcp_sampling.py +++ b/tests/models/test_mcp_sampling.py @@ -10,7 +10,7 @@ from pydantic_ai.agent import Agent from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import BinaryContent, ModelRequest, ModelResponse, SystemPromptPart, TextPart, UserPromptPart -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsNow, try_import @@ -59,7 +59,7 @@ def test_assistant_text(): ), ModelResponse( parts=[TextPart(content='text content')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='test-model', timestamp=IsNow(tz=timezone.utc), ), @@ -93,14 +93,14 @@ def test_assistant_text_history(): ModelRequest(parts=[UserPromptPart(content='1', timestamp=IsNow(tz=timezone.utc))], instructions='testing'), ModelResponse( parts=[TextPart(content='text content')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='test-model', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='2', timestamp=IsNow(tz=timezone.utc))], instructions='testing'), ModelResponse( parts=[TextPart(content='text content')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='test-model', timestamp=IsNow(tz=timezone.utc), ), @@ -121,7 +121,7 @@ def test_assistant_text_history_complex(): ), ModelResponse( parts=[TextPart(content='text content')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='test-model', ), ] diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 6736959726..311c88300f 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -29,7 +29,7 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import from .mock_async_stream import MockAsyncStream @@ -219,18 +219,18 @@ async def test_multiple_completions(allow_model_requests: None): ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello again', 
timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='hello again')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -272,26 +272,26 @@ async def test_three_completions(allow_model_requests: None): ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='hello again')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='final message')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -407,10 +407,10 @@ class CityLocation(BaseModel): tool_call_id='123', ) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=2, total_tokens=3), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -468,10 +468,10 @@ class CityLocation(BaseModel): tool_call_id='123', ) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -528,10 +528,10 @@ async def test_request_output_type_with_arguments_str_response(allow_model_reque tool_call_id='123', ) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1096,10 +1096,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RunUsage(requests=1, input_tokens=2, output_tokens=1, total_tokens=3), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1119,10 +1119,10 @@ async def 
get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6), + usage=RunUsage(requests=1, input_tokens=3, output_tokens=2, total_tokens=6), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1136,10 +1136,10 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -1239,10 +1239,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RunUsage(requests=1, input_tokens=2, output_tokens=1, total_tokens=3), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1262,10 +1262,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6), + usage=RunUsage(requests=1, input_tokens=3, output_tokens=2, total_tokens=6), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1285,10 +1285,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RunUsage(requests=1, input_tokens=2, output_tokens=1, total_tokens=3), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1385,7 +1385,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(request_tokens=2, response_tokens=2, total_tokens=2), + usage=RunUsage(input_tokens=2, output_tokens=2, total_tokens=2), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1401,7 +1401,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"won": true}', tool_call_id='1')], - usage=Usage(request_tokens=2, response_tokens=2, total_tokens=2), + usage=RunUsage(input_tokens=2, output_tokens=2, total_tokens=2), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1487,7 +1487,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(request_tokens=2, response_tokens=2, total_tokens=2), + usage=RunUsage(input_tokens=2, output_tokens=2, total_tokens=2), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1503,7 +1503,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(request_tokens=4, response_tokens=4, total_tokens=4), + usage=RunUsage(input_tokens=4, output_tokens=4, total_tokens=4), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1592,7 +1592,7 @@ async def get_location(loc_name: str) -> str: 
tool_call_id='1', ) ], - usage=Usage(request_tokens=2, response_tokens=2, total_tokens=2), + usage=RunUsage(input_tokens=2, output_tokens=2, total_tokens=2), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1614,7 +1614,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1630,7 +1630,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(request_tokens=4, response_tokens=4, total_tokens=4), + usage=RunUsage(input_tokens=4, output_tokens=4, total_tokens=4), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1804,10 +1804,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='utZJMAZN4')], - usage=Usage(requests=1, request_tokens=65, response_tokens=16, total_tokens=81), + usage=RunUsage(requests=1, input_tokens=65, output_tokens=16, total_tokens=81), model_name='pixtral-12b-latest', timestamp=IsDatetime(), - vendor_id='fce6d16a4e5940edb24ae16dd0369947', + provider_request_id='fce6d16a4e5940edb24ae16dd0369947', ), ModelRequest( parts=[ @@ -1832,10 +1832,10 @@ async def get_image() -> BinaryContent: content='The image you\'re referring to, labeled as "file 1c8566," shows a kiwi. Kiwis are small, brown, oval-shaped fruits with a bright green flesh inside that is dotted with tiny black seeds. They have a sweet and tangy flavor and are known for being rich in vitamin C and fiber.' 
) ], - usage=Usage(requests=1, request_tokens=2931, response_tokens=70, total_tokens=3001), + usage=RunUsage(requests=1, input_tokens=2931, output_tokens=70, total_tokens=3001), model_name='pixtral-12b-latest', timestamp=IsDatetime(), - vendor_id='26e7de193646460e8904f8e604a60dc1', + provider_request_id='26e7de193646460e8904f8e604a60dc1', ), ] ) @@ -1870,10 +1870,10 @@ async def test_image_url_input(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -1903,10 +1903,10 @@ async def test_image_as_binary_content_input(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -1939,10 +1939,10 @@ async def test_pdf_url_input(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -1969,10 +1969,10 @@ async def test_pdf_as_binary_content_input(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -2045,10 +2045,10 @@ async def test_mistral_model_instructions(allow_model_requests: None, mistral_ap ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RunUsage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -2072,15 +2072,15 @@ async def test_mistral_model_thinking_part(allow_model_requests: None, openai_ap ThinkingPart(content=IsStr(), id='rs_68079ad7f0588191af64f067e7314d840493b22e4095129c'), TextPart(content=IsStr()), ], - usage=Usage( - request_tokens=13, - response_tokens=1789, + usage=RunUsage( + input_tokens=13, + output_tokens=1789, total_tokens=1802, details={'reasoning_tokens': 1344, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', + provider_request_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', ), ] ) @@ -2135,15 +2135,15 @@ async def test_mistral_model_thinking_part(allow_model_requests: None, openai_ap """ ), ], - usage=Usage( - request_tokens=13, - response_tokens=1789, + usage=RunUsage( + input_tokens=13, + output_tokens=1789, total_tokens=1802, details={'reasoning_tokens': 1344, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', + 
provider_request_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', ), ModelRequest( parts=[ @@ -2155,10 +2155,10 @@ async def test_mistral_model_thinking_part(allow_model_requests: None, openai_ap ), ModelResponse( parts=[TextPart(content=IsStr())], - usage=Usage(requests=1, request_tokens=1036, response_tokens=691, total_tokens=1727), + usage=RunUsage(requests=1, input_tokens=1036, output_tokens=691, total_tokens=1727), model_name='mistral-large-latest', timestamp=IsDatetime(), - vendor_id='a088e80a476e44edaaa959a1ff08f358', + provider_request_id='a088e80a476e44edaaa959a1ff08f358', ), ] ) diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 7e1e820af9..0673eae0d0 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -22,7 +22,7 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from ..conftest import IsNow, IsStr @@ -66,7 +66,7 @@ def test_simple(): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content="content='Hello' part_kind='user-prompt' message_count=1")], - usage=Usage(requests=1, request_tokens=51, response_tokens=3, total_tokens=54), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=3, total_tokens=54), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), ), @@ -80,14 +80,14 @@ def test_simple(): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content="content='Hello' part_kind='user-prompt' message_count=1")], - usage=Usage(requests=1, request_tokens=51, response_tokens=3, total_tokens=54), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=3, total_tokens=54), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content="content='World' part_kind='user-prompt' message_count=3")], - usage=Usage(requests=1, request_tokens=52, response_tokens=6, total_tokens=58), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=6, total_tokens=58), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), ), @@ -157,7 +157,7 @@ def test_weather(): tool_name='get_location', args='{"location_description": "London"}', tool_call_id=IsStr() ) ], - usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=5, total_tokens=56), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -173,7 +173,7 @@ def test_weather(): ), ModelResponse( parts=[ToolCallPart(tool_name='get_weather', args='{"lat": 51, "lng": 0}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=56, response_tokens=11, total_tokens=67), + usage=RunUsage(requests=1, input_tokens=56, output_tokens=11, total_tokens=67), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -189,7 +189,7 @@ def test_weather(): ), ModelResponse( parts=[TextPart(content='Raining in London')], - usage=Usage(requests=1, request_tokens=57, response_tokens=14, total_tokens=71), + usage=RunUsage(requests=1, input_tokens=57, output_tokens=14, total_tokens=71), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), 
), @@ -357,7 +357,7 @@ def test_call_all(): ToolCallPart(tool_name='qux', args={'x': 0}, tool_call_id=IsStr()), ToolCallPart(tool_name='quz', args={'x': 'a'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=52, response_tokens=21, total_tokens=73), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=21, total_tokens=73), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -382,7 +382,7 @@ def test_call_all(): ), ModelResponse( parts=[TextPart(content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')], - usage=Usage(requests=1, request_tokens=57, response_tokens=33, total_tokens=90), + usage=RunUsage(requests=1, input_tokens=57, output_tokens=33, total_tokens=90), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -451,13 +451,13 @@ async def test_stream_text(): ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RunUsage(input_tokens=50, output_tokens=2, total_tokens=52), model_name='function::stream_text_function', timestamp=IsNow(tz=timezone.utc), ), ] ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=50, response_tokens=2, total_tokens=52)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=50, output_tokens=2, total_tokens=52)) class Foo(BaseModel): @@ -480,10 +480,10 @@ async def stream_structured_function( async with agent.run_stream('') as result: assert await result.get_output() == snapshot(Foo(x=1)) assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=50, - response_tokens=4, + input_tokens=50, + output_tokens=4, total_tokens=54, ) ) diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 02aafd259f..5c5d8e42d2 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -30,7 +30,7 @@ VideoUrl, ) from pydantic_ai.models.test import TestModel, _chars, _JsonSchemaTestData # pyright: ignore[reportPrivateUsage] -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsNow, IsStr @@ -106,7 +106,7 @@ async def my_ret(x: int) -> str: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='my_ret', args={'x': 0}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=4, total_tokens=55), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -122,7 +122,7 @@ async def my_ret(x: int) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='my_ret', args={'x': 0}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=61, response_tokens=8, total_tokens=69), + usage=RunUsage(requests=1, input_tokens=61, output_tokens=8, total_tokens=69), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -135,7 +135,7 @@ async def my_ret(x: int) -> str: ), ModelResponse( parts=[TextPart(content='{"my_ret":"1"}')], - usage=Usage(requests=1, request_tokens=62, response_tokens=12, total_tokens=74), + usage=RunUsage(requests=1, input_tokens=62, output_tokens=12, total_tokens=74), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -339,4 +339,4 @@ def test_different_content_input(content: AudioUrl | VideoUrl | ImageUrl | Binar agent = Agent() result = agent.run_sync(['x', content], 
model=TestModel(custom_output_text='custom')) assert result.output == snapshot('custom') - assert result.usage() == snapshot(Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=1, total_tokens=52)) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 880c70eff3..f27a918244 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -43,7 +43,7 @@ from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition @@ -172,31 +172,31 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -231,7 +231,7 @@ async def test_request_simple_usage(allow_model_requests: None): result = await agent.run('Hello') assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1, total_tokens=3)) async def test_request_structured_response(allow_model_requests: None): @@ -265,10 +265,10 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -355,12 +355,12 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage( - requests=1, request_tokens=2, response_tokens=1, total_tokens=3, details={'cached_tokens': 1} + usage=RunUsage( + requests=1, input_tokens=2, output_tokens=1, total_tokens=3, details={'cached_tokens': 1} ), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -380,12 +380,12 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage( - requests=1, request_tokens=3, 
response_tokens=2, total_tokens=6, details={'cached_tokens': 2} + usage=RunUsage( + requests=1, input_tokens=3, output_tokens=2, total_tokens=6, details={'cached_tokens': 2} ), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -399,18 +399,18 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1), + usage=RunUsage(requests=1), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) assert result.usage() == snapshot( - Usage( + RunUsage( requests=3, - request_tokens=5, - response_tokens=3, + input_tokens=5, + output_tokens=3, total_tokens=9, details={'cached_tokens': 3}, ) @@ -447,7 +447,7 @@ async def test_stream_text(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=4, request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=4, input_tokens=6, output_tokens=3, total_tokens=9)) async def test_stream_text_finish_reason(allow_model_requests: None): @@ -519,7 +519,7 @@ async def test_stream_structured(allow_model_requests: None): ] ) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=11, request_tokens=20, response_tokens=10, total_tokens=30)) + assert result.usage() == snapshot(RunUsage(requests=11, input_tokens=20, output_tokens=10, total_tokens=30)) # double check usage matches stream count assert result.usage().response_tokens == len(stream) @@ -668,7 +668,7 @@ async def test_no_delta(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=4, request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=4, input_tokens=6, output_tokens=3, total_tokens=9)) @pytest.mark.parametrize('system_prompt_role', ['system', 'developer', 'user', None]) @@ -819,10 +819,10 @@ async def get_image() -> ImageUrl: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='call_4hrT4QP9jfojtK69vGiFCFjG')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=46, - response_tokens=11, + input_tokens=46, + output_tokens=11, total_tokens=57, details={ 'accepted_prediction_tokens': 0, @@ -834,7 +834,7 @@ async def get_image() -> ImageUrl: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRmTHlrARTzAHK1na9s80xDlQGYPX', + provider_request_id='chatcmpl-BRmTHlrARTzAHK1na9s80xDlQGYPX', ), ModelRequest( parts=[ @@ -857,10 +857,10 @@ async def get_image() -> ImageUrl: ), ModelResponse( parts=[TextPart(content='The image shows a potato.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=503, - response_tokens=8, + input_tokens=503, + output_tokens=8, total_tokens=511, details={ 'accepted_prediction_tokens': 0, @@ -872,7 +872,7 @@ async def get_image() -> ImageUrl: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRmTI0Y2zmkGw27kLarhsmiFQTGxR', + provider_request_id='chatcmpl-BRmTI0Y2zmkGw27kLarhsmiFQTGxR', ), ] ) @@ -902,10 +902,10 @@ async def get_image() -> 
BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='call_Btn0GIzGr4ugNlLmkQghQUMY')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=46, - response_tokens=11, + input_tokens=46, + output_tokens=11, total_tokens=57, details={ 'accepted_prediction_tokens': 0, @@ -917,7 +917,7 @@ async def get_image() -> BinaryContent: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlkLhPc87BdohVobEJJCGq3rUAG2', + provider_request_id='chatcmpl-BRlkLhPc87BdohVobEJJCGq3rUAG2', ), ModelRequest( parts=[ @@ -938,10 +938,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The image shows a kiwi fruit.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=1185, - response_tokens=9, + input_tokens=1185, + output_tokens=9, total_tokens=1194, details={ 'accepted_prediction_tokens': 0, @@ -953,7 +953,7 @@ async def get_image() -> BinaryContent: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlkORPA5rXMV3uzcOcgK4eQFKCVW', + provider_request_id='chatcmpl-BRlkORPA5rXMV3uzcOcgK4eQFKCVW', ), ] ) @@ -1583,10 +1583,10 @@ async def test_openai_instructions(allow_model_requests: None, openai_api_key: s ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=24, - response_tokens=8, + input_tokens=24, + output_tokens=8, total_tokens=32, details={ 'accepted_prediction_tokens': 0, @@ -1598,7 +1598,7 @@ async def test_openai_instructions(allow_model_requests: None, openai_api_key: s ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BJjf61mLb9z5H45ClJzbx0UWKwjo1', + provider_request_id='chatcmpl-BJjf61mLb9z5H45ClJzbx0UWKwjo1', ), ] ) @@ -1630,10 +1630,10 @@ async def get_temperature(city: str) -> float: ), ModelResponse( parts=[ToolCallPart(tool_name='get_temperature', args='{"city":"Tokyo"}', tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=50, - response_tokens=15, + input_tokens=50, + output_tokens=15, total_tokens=65, details={ 'accepted_prediction_tokens': 0, @@ -1645,7 +1645,7 @@ async def get_temperature(city: str) -> float: ), model_name='gpt-4.1-mini-2025-04-14', timestamp=IsDatetime(), - vendor_id='chatcmpl-BMxEwRA0p0gJ52oKS7806KAlfMhqq', + provider_request_id='chatcmpl-BMxEwRA0p0gJ52oKS7806KAlfMhqq', ), ModelRequest( parts=[ @@ -1657,10 +1657,10 @@ async def get_temperature(city: str) -> float: ), ModelResponse( parts=[TextPart(content='The temperature in Tokyo is currently 20.0 degrees Celsius.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=75, - response_tokens=15, + input_tokens=75, + output_tokens=15, total_tokens=90, details={ 'accepted_prediction_tokens': 0, @@ -1672,7 +1672,7 @@ async def get_temperature(city: str) -> float: ), model_name='gpt-4.1-mini-2025-04-14', timestamp=IsDatetime(), - vendor_id='chatcmpl-BMxEx6B8JEj6oDC45MOWKp0phg8UP', + provider_request_id='chatcmpl-BMxEx6B8JEj6oDC45MOWKp0phg8UP', ), ] ) @@ -1696,15 +1696,15 @@ async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034841ab2881918a8c210e3d988b9208c845d2be9bcdd8'), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=2050, + usage=RunUsage( + input_tokens=13, + output_tokens=2050, total_tokens=2063, details={'reasoning_tokens': 1664, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', 
timestamp=IsDatetime(), - vendor_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', + provider_request_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', ), ] ) @@ -1724,15 +1724,15 @@ async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034841ab2881918a8c210e3d988b9208c845d2be9bcdd8'), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=2050, + usage=RunUsage( + input_tokens=13, + output_tokens=2050, total_tokens=2063, details={'reasoning_tokens': 1664, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', + provider_request_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', ), ModelRequest( parts=[ @@ -1749,15 +1749,15 @@ async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034858dc588191bc3a6801c23e728f08c845d2be9bcdd8'), IsInstance(TextPart), ], - usage=Usage( - request_tokens=424, - response_tokens=2033, + usage=RunUsage( + input_tokens=424, + output_tokens=2033, total_tokens=2457, details={'reasoning_tokens': 1408, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_6803484f19a88191b9ea975d7cfbbe8408c845d2be9bcdd8', + provider_request_id='resp_6803484f19a88191b9ea975d7cfbbe8408c845d2be9bcdd8', ), ] ) @@ -1782,15 +1782,15 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=1900, + usage=RunUsage( + input_tokens=13, + output_tokens=1900, total_tokens=1913, details={'reasoning_tokens': 1536, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', + provider_request_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', ), ] ) @@ -1811,15 +1811,15 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=1900, + usage=RunUsage( + input_tokens=13, + output_tokens=1900, total_tokens=1913, details={'reasoning_tokens': 1536, 'cached_tokens': 0}, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', + provider_request_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', ), ModelRequest( parts=[ @@ -1831,10 +1831,10 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api ), ModelResponse( parts=[TextPart(content=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=822, - response_tokens=2437, + input_tokens=822, + output_tokens=2437, total_tokens=3259, details={ 'accepted_prediction_tokens': 0, @@ -1846,7 +1846,7 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='chatcmpl-BP7ocN6qxho4C1UzUJWnU5tPJno55', + provider_request_id='chatcmpl-BP7ocN6qxho4C1UzUJWnU5tPJno55', ), ] ) @@ -2094,10 +2094,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=68, - response_tokens=12, + input_tokens=68, + output_tokens=12, total_tokens=80, 
details={ 'accepted_prediction_tokens': 0, @@ -2109,7 +2109,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I', + provider_request_id='chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I', ), ModelRequest( parts=[ @@ -2129,10 +2129,10 @@ async def get_user_country() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=89, - response_tokens=36, + input_tokens=89, + output_tokens=36, total_tokens=125, details={ 'accepted_prediction_tokens': 0, @@ -2144,7 +2144,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s', + provider_request_id='chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s', ), ModelRequest( parts=[ @@ -2190,10 +2190,10 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_J1YabdC7G7kzEZNbbZopwenH') ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=42, - response_tokens=11, + input_tokens=42, + output_tokens=11, total_tokens=53, details={ 'accepted_prediction_tokens': 0, @@ -2205,7 +2205,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BgeDFS85bfHosRFEEAvq8reaCPCZ8', + provider_request_id='chatcmpl-BgeDFS85bfHosRFEEAvq8reaCPCZ8', ), ModelRequest( parts=[ @@ -2219,10 +2219,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='The largest city in Mexico is Mexico City.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=63, - response_tokens=10, + input_tokens=63, + output_tokens=10, total_tokens=73, details={ 'accepted_prediction_tokens': 0, @@ -2234,7 +2234,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BgeDGX9eDyVrEI56aP2vtIHahBzFH', + provider_request_id='chatcmpl-BgeDGX9eDyVrEI56aP2vtIHahBzFH', ), ] ) @@ -2273,10 +2273,10 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR') ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=71, - response_tokens=12, + input_tokens=71, + output_tokens=12, total_tokens=83, details={ 'accepted_prediction_tokens': 0, @@ -2288,7 +2288,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3', + provider_request_id='chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3', ), ModelRequest( parts=[ @@ -2302,10 +2302,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=92, - response_tokens=15, + input_tokens=92, + output_tokens=15, total_tokens=107, details={ 'accepted_prediction_tokens': 0, @@ -2317,7 +2317,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO', + provider_request_id='chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO', ), ] ) @@ -2358,10 +2358,10 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_SIttSeiOistt33Htj4oiHOOX') ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=160, - response_tokens=11, + input_tokens=160, + output_tokens=11, 
total_tokens=171, details={ 'accepted_prediction_tokens': 0, @@ -2373,7 +2373,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD', + provider_request_id='chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD', ), ModelRequest( parts=[ @@ -2391,10 +2391,10 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=181, - response_tokens=25, + input_tokens=181, + output_tokens=25, total_tokens=206, details={ 'accepted_prediction_tokens': 0, @@ -2406,7 +2406,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd', + provider_request_id='chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd', ), ] ) @@ -2450,10 +2450,10 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB') ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=109, - response_tokens=11, + input_tokens=109, + output_tokens=11, total_tokens=120, details={ 'accepted_prediction_tokens': 0, @@ -2465,7 +2465,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw', + provider_request_id='chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw', ), ModelRequest( parts=[ @@ -2486,10 +2486,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=130, - response_tokens=11, + input_tokens=130, + output_tokens=11, total_tokens=141, details={ 'accepted_prediction_tokens': 0, @@ -2501,7 +2501,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0', + provider_request_id='chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0', ), ] ) @@ -2549,10 +2549,10 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09') ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=273, - response_tokens=11, + input_tokens=273, + output_tokens=11, total_tokens=284, details={ 'accepted_prediction_tokens': 0, @@ -2564,7 +2564,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR', + provider_request_id='chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR', ), ModelRequest( parts=[ @@ -2589,10 +2589,10 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=294, - response_tokens=21, + input_tokens=294, + output_tokens=21, total_tokens=315, details={ 'accepted_prediction_tokens': 0, @@ -2604,7 +2604,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC', + provider_request_id='chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC', ), ] ) diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index 30e98d70fe..836f876757 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ 
-23,7 +23,7 @@ from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.profiles.openai import openai_model_profile from pydantic_ai.tools import ToolDefinition -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from ..conftest import IsDatetime, IsStr, TestEnv, try_import @@ -190,15 +190,15 @@ async def get_location(loc_name: str) -> str: tool_call_id=IsStr(), ), ], - usage=Usage( - request_tokens=0, - response_tokens=0, + usage=RunUsage( + input_tokens=0, + output_tokens=0, total_tokens=0, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_67e547c48c9481918c5c4394464ce0c60ae6111e84dd5c08', + provider_request_id='resp_67e547c48c9481918c5c4394464ce0c60ae6111e84dd5c08', ), ModelRequest( parts=[ @@ -226,15 +226,15 @@ async def get_location(loc_name: str) -> str: """ ) ], - usage=Usage( - request_tokens=335, - response_tokens=44, + usage=RunUsage( + input_tokens=335, + output_tokens=44, total_tokens=379, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_67e547c5a2f08191802a1f43620f348503a2086afed73b47', + provider_request_id='resp_67e547c5a2f08191802a1f43620f348503a2086afed73b47', ), ] ) @@ -264,15 +264,15 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=40, - response_tokens=11, + usage=RunUsage( + input_tokens=40, + output_tokens=11, total_tokens=51, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_681134d3aa3481919ca581a267db1e510fe7a5a4e2123dc3', + provider_request_id='resp_681134d3aa3481919ca581a267db1e510fe7a5a4e2123dc3', ), ModelRequest( parts=[ @@ -293,15 +293,15 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The fruit in the image is a kiwi.')], - usage=Usage( - request_tokens=1185, - response_tokens=11, + usage=RunUsage( + input_tokens=1185, + output_tokens=11, total_tokens=1196, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_681134d53c48819198ce7b89db78dffd02cbfeaababb040c', + provider_request_id='resp_681134d53c48819198ce7b89db78dffd02cbfeaababb040c', ), ] ) @@ -429,15 +429,15 @@ async def test_openai_responses_model_builtin_tools(allow_model_requests: None, """ ) ], - usage=Usage( - request_tokens=320, - response_tokens=200, + usage=RunUsage( + input_tokens=320, + output_tokens=200, total_tokens=520, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_67ebcbb93728819197f923ff16e98bce04f5055a2a33abc3', + provider_request_id='resp_67ebcbb93728819197f923ff16e98bce04f5055a2a33abc3', ), ] ) @@ -457,15 +457,15 @@ async def test_openai_responses_model_instructions(allow_model_requests: None, o ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - request_tokens=24, - response_tokens=8, + usage=RunUsage( + input_tokens=24, + output_tokens=8, total_tokens=32, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_67f3fdfd9fa08191a3d5825db81b8df6003bc73febb56d77', + provider_request_id='resp_67f3fdfd9fa08191a3d5825db81b8df6003bc73febb56d77', ), ] ) @@ 
-550,15 +550,15 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=62, - response_tokens=12, + usage=RunUsage( + input_tokens=62, + output_tokens=12, total_tokens=74, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0b40a8819cb8d55594bc2c232a001fd29e2d5573f7', + provider_request_id='resp_68477f0b40a8819cb8d55594bc2c232a001fd29e2d5573f7', ), ModelRequest( parts=[ @@ -578,15 +578,15 @@ async def get_user_country() -> str: tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', ) ], - usage=Usage( - request_tokens=85, - response_tokens=20, + usage=RunUsage( + input_tokens=85, + output_tokens=20, total_tokens=105, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0bfda8819ea65458cd7cc389b801dc81d4bc91f560', + provider_request_id='resp_68477f0bfda8819ea65458cd7cc389b801dc81d4bc91f560', ), ModelRequest( parts=[ @@ -632,15 +632,15 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr') ], - usage=Usage( - request_tokens=36, - response_tokens=12, + usage=RunUsage( + input_tokens=36, + output_tokens=12, total_tokens=48, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0d9494819ea4f123bba707c9ee0356a60c98816d6a', + provider_request_id='resp_68477f0d9494819ea4f123bba707c9ee0356a60c98816d6a', ), ModelRequest( parts=[ @@ -654,15 +654,15 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='The largest city in Mexico is Mexico City.')], - usage=Usage( - request_tokens=59, - response_tokens=11, + usage=RunUsage( + input_tokens=59, + output_tokens=11, total_tokens=70, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0e2b28819d9c828ef4ee526d6a03434b607c02582d', + provider_request_id='resp_68477f0e2b28819d9c828ef4ee526d6a03434b607c02582d', ), ] ) @@ -699,15 +699,15 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=66, - response_tokens=12, + usage=RunUsage( + input_tokens=66, + output_tokens=12, total_tokens=78, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329', + provider_request_id='resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329', ), ModelRequest( parts=[ @@ -721,15 +721,15 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage( - request_tokens=89, - response_tokens=16, + usage=RunUsage( + input_tokens=89, + output_tokens=16, total_tokens=105, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0fde708192989000a62809c6e5020197534e39cc1f', + provider_request_id='resp_68477f0fde708192989000a62809c6e5020197534e39cc1f', ), ] ) @@ -768,15 +768,15 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - 
request_tokens=153, - response_tokens=12, + usage=RunUsage( + input_tokens=153, + output_tokens=12, total_tokens=165, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563', + provider_request_id='resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563', ), ModelRequest( parts=[ @@ -794,15 +794,15 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=Usage( - request_tokens=176, - response_tokens=26, + usage=RunUsage( + input_tokens=176, + output_tokens=26, total_tokens=202, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f119830819da162aa6e10552035061ad97e2eef7871', + provider_request_id='resp_68477f119830819da162aa6e10552035061ad97e2eef7871', ), ] ) @@ -844,15 +844,15 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=107, - response_tokens=12, + usage=RunUsage( + input_tokens=107, + output_tokens=12, total_tokens=119, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50', + provider_request_id='resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50', ), ModelRequest( parts=[ @@ -873,15 +873,15 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage( - request_tokens=130, - response_tokens=12, + usage=RunUsage( + input_tokens=130, + output_tokens=12, total_tokens=142, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115', + provider_request_id='resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115', ), ] ) @@ -927,15 +927,15 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=283, - response_tokens=12, + usage=RunUsage( + input_tokens=283, + output_tokens=12, total_tokens=295, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee', + provider_request_id='resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee', ), ModelRequest( parts=[ @@ -960,15 +960,15 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=Usage( - request_tokens=306, - response_tokens=22, + usage=RunUsage( + input_tokens=306, + output_tokens=22, total_tokens=328, details={'reasoning_tokens': 0, 'cached_tokens': 0}, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024', + provider_request_id='resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024', ), ] ) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index c13d1da54d..6637add13e 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -19,7 +19,7 @@ UserPromptPart, ) from pydantic_ai.models.function import AgentInfo, FunctionModel -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from 
.conftest import IsDatetime, IsStr, try_import @@ -612,7 +612,7 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr() ) ], - usage=Usage(requests=1, request_tokens=52, response_tokens=7, total_tokens=59), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=7, total_tokens=59), model_name='function:track_messages:', timestamp=IsDatetime(), ), diff --git a/tests/test_agent.py b/tests/test_agent.py index a5a9285d52..a9c6831715 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -44,7 +44,7 @@ from pydantic_ai.models.test import TestModel from pydantic_ai.output import StructuredDict, ToolOutput from pydantic_ai.profiles import ModelProfile -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets.combined import CombinedToolset from pydantic_ai.toolsets.function import FunctionToolset @@ -107,7 +107,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=7, total_tokens=58), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -130,7 +130,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=87, response_tokens=14, total_tokens=101), + usage=RunUsage(requests=1, input_tokens=87, output_tokens=14, total_tokens=101), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -230,7 +230,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"a": 41, "b": "foo"}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=7, total_tokens=58), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -246,7 +246,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=63, response_tokens=14, total_tokens=77), + usage=RunUsage(requests=1, input_tokens=63, output_tokens=14, total_tokens=77), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -288,7 +288,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='hello')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=1, total_tokens=52), model_name='function:return_tuple:', timestamp=IsNow(tz=timezone.utc), ), @@ -305,7 +305,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: parts=[ 
ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr()) ], - usage=Usage(requests=1, request_tokens=74, response_tokens=8, total_tokens=82), + usage=RunUsage(requests=1, input_tokens=74, output_tokens=8, total_tokens=82), model_name='function:return_tuple:', timestamp=IsNow(tz=timezone.utc), ), @@ -853,7 +853,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=53, response_tokens=7, total_tokens=60), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=7, total_tokens=60), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -875,7 +875,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=68, response_tokens=13, total_tokens=81), + usage=RunUsage(requests=1, input_tokens=68, output_tokens=13, total_tokens=81), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -929,7 +929,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='New York City')], - usage=Usage(requests=1, request_tokens=53, response_tokens=3, total_tokens=56), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=3, total_tokens=56), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -944,7 +944,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='Mexico City')], - usage=Usage(requests=1, request_tokens=70, response_tokens=5, total_tokens=75), + usage=RunUsage(requests=1, input_tokens=70, output_tokens=5, total_tokens=75), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -1117,7 +1117,7 @@ def say_world(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=1, total_tokens=52), model_name='function:say_world:', timestamp=IsDatetime(), ), @@ -1177,7 +1177,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=52, response_tokens=6, total_tokens=58), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=6, total_tokens=58), model_name='function:call_handoff_tool:', timestamp=IsDatetime(), ), @@ -1212,7 +1212,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=57, response_tokens=6, total_tokens=63), + usage=RunUsage(requests=1, input_tokens=57, output_tokens=6, total_tokens=63), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -1425,7 +1425,7 @@ class CityLocation(BaseModel): ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63), + usage=RunUsage(requests=1, input_tokens=56, output_tokens=7, total_tokens=63), model_name='function:return_city_location:', timestamp=IsDatetime(), ), @@ -1464,7 +1464,7 @@ class Foo(BaseModel): ), ModelResponse( parts=[TextPart(content='{"bar":"baz"}')], - usage=Usage(requests=1, request_tokens=56, response_tokens=4, total_tokens=60), + usage=RunUsage(requests=1, input_tokens=56, output_tokens=4, total_tokens=60), 
model_name='function:return_foo:', timestamp=IsDatetime(), ), @@ -1538,7 +1538,7 @@ def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: content='{"result": {"kind": "FooBar", "data": {"foo": {"foo": "foo"}, "bar": {"bar": "bar"}}}}' ) ], - usage=Usage(requests=1, request_tokens=53, response_tokens=17, total_tokens=70), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=17, total_tokens=70), model_name='function:return_foo_bar:', timestamp=IsDatetime(), ), @@ -1579,7 +1579,7 @@ class CityLocation(BaseModel): ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City"}')], - usage=Usage(requests=1, request_tokens=56, response_tokens=5, total_tokens=61), + usage=RunUsage(requests=1, input_tokens=56, output_tokens=5, total_tokens=61), model_name='function:return_city_location:', timestamp=IsDatetime(), ), @@ -1601,7 +1601,7 @@ class CityLocation(BaseModel): ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage(requests=1, request_tokens=85, response_tokens=12, total_tokens=97), + usage=RunUsage(requests=1, input_tokens=85, output_tokens=12, total_tokens=97), model_name='function:return_city_location:', timestamp=IsDatetime(), ), @@ -1662,7 +1662,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='{"city": "New York City"}')], - usage=Usage(requests=1, request_tokens=53, response_tokens=6, total_tokens=59), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=6, total_tokens=59), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -1677,7 +1677,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City"}')], - usage=Usage(requests=1, request_tokens=70, response_tokens=11, total_tokens=81), + usage=RunUsage(requests=1, input_tokens=70, output_tokens=11, total_tokens=81), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -1705,7 +1705,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=5, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1718,7 +1718,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=9, total_tokens=62), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1737,7 +1737,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=5, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1750,14 +1750,14 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=9, total_tokens=62), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), 
ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68), + usage=RunUsage(requests=1, input_tokens=55, output_tokens=13, total_tokens=68), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1767,7 +1767,7 @@ async def ret_a(x: str) -> str: assert result2.output == snapshot('{"ret_a":"a-apple"}') assert result2._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] assert result2.usage() == snapshot( - Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) + RunUsage(requests=1, input_tokens=55, output_tokens=13, total_tokens=68, details=None) ) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( @@ -1796,7 +1796,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=5, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1809,14 +1809,14 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=9, total_tokens=62), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68), + usage=RunUsage(requests=1, input_tokens=55, output_tokens=13, total_tokens=68), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1826,7 +1826,7 @@ async def ret_a(x: str) -> str: assert result3.output == snapshot('{"ret_a":"a-apple"}') assert result3._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] assert result3.usage() == snapshot( - Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) + RunUsage(requests=1, input_tokens=55, output_tokens=13, total_tokens=68, details=None) ) @@ -1853,7 +1853,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=5, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1872,7 +1872,7 @@ async def ret_a(x: str) -> str: tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=9, total_tokens=62), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1900,7 +1900,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=5, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1913,7 +1913,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())], - 
usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=9, total_tokens=62), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1935,7 +1935,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72), + usage=RunUsage(requests=1, input_tokens=59, output_tokens=13, total_tokens=72), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1955,7 +1955,7 @@ async def ret_a(x: str) -> str: assert result2._new_message_index == snapshot(5) # pyright: ignore[reportPrivateUsage] assert result2._output_tool_name == snapshot('final_result') # pyright: ignore[reportPrivateUsage] assert result2.usage() == snapshot( - Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None) + RunUsage(requests=1, input_tokens=59, output_tokens=13, total_tokens=72, details=None) ) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( @@ -1995,7 +1995,7 @@ def test_run_with_history_and_no_user_prompt(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=4, total_tokens=55), model_name='test', timestamp=IsDatetime(), ), @@ -2027,7 +2027,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=2, total_tokens=53), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=2, total_tokens=53), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), @@ -2043,7 +2043,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=65, response_tokens=4, total_tokens=69), + usage=RunUsage(requests=1, input_tokens=65, output_tokens=4, total_tokens=69), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), @@ -2067,7 +2067,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=2, total_tokens=53), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=2, total_tokens=53), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), @@ -2083,7 +2083,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='success')], - usage=Usage(requests=1, request_tokens=65, response_tokens=3, total_tokens=68), + usage=RunUsage(requests=1, input_tokens=65, output_tokens=3, total_tokens=68), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), @@ -2354,7 +2354,7 @@ def another_tool(y: int) -> int: ToolCallPart(tool_name='final_result', args={'value': 'second'}, tool_call_id=IsStr()), ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()), 
], - usage=Usage(requests=1, request_tokens=53, response_tokens=23, total_tokens=76), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=23, total_tokens=76), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -2443,7 +2443,7 @@ def another_tool(y: int) -> int: # pragma: no cover ToolCallPart(tool_name='another_tool', args={'y': 2}, tool_call_id=IsStr()), ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=58, response_tokens=18, total_tokens=76), + usage=RunUsage(requests=1, input_tokens=58, output_tokens=18, total_tokens=76), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -2595,7 +2595,7 @@ async def get_location(loc_name: str) -> str: TextPart(content='foo'), ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=51, response_tokens=6, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=6, total_tokens=57), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -2611,7 +2611,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1, request_tokens=56, response_tokens=8, total_tokens=64), + usage=RunUsage(requests=1, input_tokens=56, output_tokens=8, total_tokens=64), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -2642,7 +2642,7 @@ def test_nested_capture_run_messages() -> None: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=4, total_tokens=55), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -2665,7 +2665,7 @@ def test_double_capture_run_messages() -> None: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=4, total_tokens=55), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -2715,7 +2715,7 @@ async def func() -> str: ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=4, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2743,7 +2743,7 @@ async def func() -> str: ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=4, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2754,7 +2754,7 @@ async def func() -> str: ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=54, response_tokens=8, total_tokens=62), + usage=RunUsage(requests=1, input_tokens=54, output_tokens=8, total_tokens=62), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2796,7 +2796,7 @@ async def func(): ), 
ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=4, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2825,7 +2825,7 @@ async def func(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=4, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2836,7 +2836,7 @@ async def func(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=54, response_tokens=8, total_tokens=62), + usage=RunUsage(requests=1, input_tokens=54, output_tokens=8, total_tokens=62), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2863,7 +2863,7 @@ async def foobar(x: str) -> str: ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=5, total_tokens=56), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -2879,7 +2879,7 @@ async def foobar(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"foobar":"inner agent result"}')], - usage=Usage(requests=1, request_tokens=54, response_tokens=11, total_tokens=65), + usage=RunUsage(requests=1, input_tokens=54, output_tokens=11, total_tokens=65), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -3168,7 +3168,7 @@ def test_instructions_with_message_history(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=Usage(requests=1, request_tokens=56, response_tokens=4, total_tokens=60), + usage=RunUsage(requests=1, input_tokens=56, output_tokens=4, total_tokens=60), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -3223,7 +3223,7 @@ def my_tool(x: int) -> int: TextPart(content='foo'), ToolCallPart(tool_name='my_tool', args={'x': 1}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=5, total_tokens=56), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3239,7 +3239,7 @@ def my_tool(x: int) -> int: TextPart(content='bar'), ToolCallPart(tool_name='my_tool', args={'x': 2}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=52, response_tokens=10, total_tokens=62), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=10, total_tokens=62), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3252,7 +3252,7 @@ def my_tool(x: int) -> int: ), ModelResponse( parts=[], - usage=Usage(requests=1, request_tokens=53, response_tokens=10, total_tokens=63), + usage=RunUsage(requests=1, input_tokens=53, output_tokens=10, total_tokens=63), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3391,7 +3391,7 @@ def analyze_data() -> ToolReturn: tool_call_id=IsStr(), ), ], - usage=Usage(requests=1, request_tokens=54, response_tokens=4, total_tokens=58), + usage=RunUsage(requests=1, input_tokens=54, output_tokens=4, total_tokens=58), 
model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3416,7 +3416,7 @@ def analyze_data() -> ToolReturn: ), ModelResponse( parts=[TextPart(content='Analysis completed')], - usage=Usage(requests=1, request_tokens=70, response_tokens=6, total_tokens=76), + usage=RunUsage(requests=1, input_tokens=70, output_tokens=6, total_tokens=76), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3466,7 +3466,7 @@ def analyze_data() -> ToolReturn: tool_call_id=IsStr(), ), ], - usage=Usage(requests=1, request_tokens=54, response_tokens=4, total_tokens=58), + usage=RunUsage(requests=1, input_tokens=54, output_tokens=4, total_tokens=58), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3483,7 +3483,7 @@ def analyze_data() -> ToolReturn: ), ModelResponse( parts=[TextPart(content='Analysis completed')], - usage=Usage(requests=1, request_tokens=58, response_tokens=6, total_tokens=64), + usage=RunUsage(requests=1, input_tokens=58, output_tokens=6, total_tokens=64), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3735,7 +3735,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[ToolCallPart(tool_name='add_foo_tool', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=57, response_tokens=2, total_tokens=59), + usage=RunUsage(requests=1, input_tokens=57, output_tokens=2, total_tokens=59), model_name='function:respond:', timestamp=IsDatetime(), ), @@ -3751,7 +3751,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[ToolCallPart(tool_name='foo', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=60, response_tokens=4, total_tokens=64), + usage=RunUsage(requests=1, input_tokens=60, output_tokens=4, total_tokens=64), model_name='function:respond:', timestamp=IsDatetime(), ), @@ -3767,7 +3767,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='Done')], - usage=Usage(requests=1, request_tokens=63, response_tokens=5, total_tokens=68), + usage=RunUsage(requests=1, input_tokens=63, output_tokens=5, total_tokens=68), model_name='function:respond:', timestamp=IsDatetime(), ), @@ -3826,7 +3826,7 @@ async def only_if_plan_presented( tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=5, total_tokens=56), model_name='test', timestamp=IsDatetime(), ), @@ -3848,7 +3848,7 @@ async def only_if_plan_presented( tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=52, response_tokens=12, total_tokens=64), + usage=RunUsage(requests=1, input_tokens=52, output_tokens=12, total_tokens=64), model_name='test', timestamp=IsDatetime(), ), diff --git a/tests/test_direct.py b/tests/test_direct.py index e9a131ea33..cc45f71b15 100644 --- a/tests/test_direct.py +++ b/tests/test_direct.py @@ -30,7 +30,7 @@ from pydantic_ai.models.instrumented import InstrumentedModel from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from .conftest import IsNow, IsStr @@ -44,7 +44,7 @@ async def test_model_request(): parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RunUsage(requests=1, 
input_tokens=51, output_tokens=4, total_tokens=55), ) ) @@ -63,7 +63,7 @@ async def test_model_request_tool_call(): parts=[ToolCallPart(tool_name='tool_name', args={}, tool_call_id=IsStr(regex='pyd_ai_.*'))], model_name='test', timestamp=IsNow(tz=timezone.utc), - usage=Usage(requests=1, request_tokens=51, response_tokens=2, total_tokens=53), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=2, total_tokens=53), ) ) @@ -75,7 +75,7 @@ def test_model_request_sync(): parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RunUsage(requests=1, input_tokens=51, output_tokens=4, total_tokens=55), ) ) diff --git a/tests/test_history_processor.py b/tests/test_history_processor.py index 1a1e2ffa37..1cd6265750 100644 --- a/tests/test_history_processor.py +++ b/tests/test_history_processor.py @@ -8,7 +8,7 @@ from pydantic_ai.messages import ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, TextPart, UserPromptPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.tools import RunContext -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from .conftest import IsDatetime @@ -55,7 +55,7 @@ def no_op_history_processor(messages: list[ModelMessage]) -> list[ModelMessage]: ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='Provider response')], - usage=Usage(requests=1, request_tokens=54, response_tokens=4, total_tokens=58), + usage=RunUsage(requests=1, input_tokens=54, output_tokens=4, total_tokens=58), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), ), @@ -94,7 +94,7 @@ def capture_messages_processor(messages: list[ModelMessage]) -> list[ModelMessag ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='Provider response')], - usage=Usage(requests=1, request_tokens=54, response_tokens=2, total_tokens=56), + usage=RunUsage(requests=1, input_tokens=54, output_tokens=2, total_tokens=56), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), ), diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 1021b31512..93d0430588 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -28,7 +28,7 @@ from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from .conftest import IsDatetime, IsNow, IsStr, try_import @@ -67,7 +67,7 @@ def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: @pytest.fixture def run_context(model: Model) -> RunContext[int]: - return RunContext(deps=0, model=model, usage=Usage()) + return RunContext(deps=0, model=model, usage=RunUsage()) async def test_stdio_server(run_context: RunContext[int]): @@ -201,10 +201,10 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) tool_call_id='call_QssdxTGkPblTYHmyVES1tKBj', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=195, - response_tokens=19, + input_tokens=195, + output_tokens=19, total_tokens=214, details={ 'accepted_prediction_tokens': 0, @@ -216,7 +216,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) ), model_name='gpt-4o-2024-08-06', 
timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlnvvqIPFofAtKqtQKMWZkgXhzlT', + provider_request_id='chatcmpl-BRlnvvqIPFofAtKqtQKMWZkgXhzlT', ), ModelRequest( parts=[ @@ -230,10 +230,10 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) ), ModelResponse( parts=[TextPart(content='0 degrees Celsius is equal to 32 degrees Fahrenheit.')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=227, - response_tokens=13, + input_tokens=227, + output_tokens=13, total_tokens=240, details={ 'accepted_prediction_tokens': 0, @@ -245,7 +245,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlnyjUo5wlyqvdNdM5I8vIWjo1qF', + provider_request_id='chatcmpl-BRlnyjUo5wlyqvdNdM5I8vIWjo1qF', ), ] ) @@ -337,10 +337,10 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): tool_call_id='call_m9goNwaHBbU926w47V7RtWPt', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=194, - response_tokens=18, + input_tokens=194, + output_tokens=18, total_tokens=212, details={ 'accepted_prediction_tokens': 0, @@ -352,7 +352,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlo3e1Ud2lnvkddMilmwC7LAemiy', + provider_request_id='chatcmpl-BRlo3e1Ud2lnvkddMilmwC7LAemiy', ), ModelRequest( parts=[ @@ -370,10 +370,10 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): content='The weather in Mexico City is currently sunny with a temperature of 26 degrees Celsius.' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=234, - response_tokens=19, + input_tokens=234, + output_tokens=19, total_tokens=253, details={ 'accepted_prediction_tokens': 0, @@ -385,7 +385,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlo41LxqBYgGKWgGrQn67fQacOLp', + provider_request_id='chatcmpl-BRlo41LxqBYgGKWgGrQn67fQacOLp', ), ] ) @@ -414,10 +414,10 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A tool_call_id='call_LaiWltzI39sdquflqeuF0EyE', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=200, - response_tokens=12, + input_tokens=200, + output_tokens=12, total_tokens=212, details={ 'accepted_prediction_tokens': 0, @@ -429,7 +429,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRmhyweJVYonarb7s9ckIMSHf2vHo', + provider_request_id='chatcmpl-BRmhyweJVYonarb7s9ckIMSHf2vHo', ), ModelRequest( parts=[ @@ -443,10 +443,10 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A ), ModelResponse( parts=[TextPart(content='The product name is "Pydantic AI".')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=224, - response_tokens=12, + input_tokens=224, + output_tokens=12, total_tokens=236, details={ 'accepted_prediction_tokens': 0, @@ -458,7 +458,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRmhzqXFObpYwSzREMpJvX9kbDikR', + provider_request_id='chatcmpl-BRmhzqXFObpYwSzREMpJvX9kbDikR', ), ] ) @@ -487,10 +487,10 @@ async def 
test_tool_returning_text_resource_link(allow_model_requests: None, age tool_call_id='call_qi5GtBeIEyT7Y3yJvVFIi062', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=305, - response_tokens=12, + input_tokens=305, + output_tokens=12, total_tokens=317, details={ 'accepted_prediction_tokens': 0, @@ -502,7 +502,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BwdHSFe0EykAOpf0LWZzsWAodIQzb', + provider_request_id='chatcmpl-BwdHSFe0EykAOpf0LWZzsWAodIQzb', ), ModelRequest( parts=[ @@ -516,10 +516,10 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age ), ModelResponse( parts=[TextPart(content='The product name is "Pydantic AI".')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=332, - response_tokens=11, + input_tokens=332, + output_tokens=11, total_tokens=343, details={ 'accepted_prediction_tokens': 0, @@ -531,7 +531,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BwdHTIlBZWzXJPBR8VTOdC4O57ZQA', + provider_request_id='chatcmpl-BwdHTIlBZWzXJPBR8VTOdC4O57ZQA', ), ] ) @@ -562,10 +562,10 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: tool_call_id='call_nFsDHYDZigO0rOHqmChZ3pmt', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=191, - response_tokens=12, + input_tokens=191, + output_tokens=12, total_tokens=203, details={ 'accepted_prediction_tokens': 0, @@ -577,7 +577,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlo7KYJVXuNZ5lLLdYcKZDsX2CHb', + provider_request_id='chatcmpl-BRlo7KYJVXuNZ5lLLdYcKZDsX2CHb', ), ModelRequest( parts=[ @@ -596,10 +596,10 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: content='This is an image of a sliced kiwi with a vibrant green interior and black seeds.' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=1332, - response_tokens=19, + input_tokens=1332, + output_tokens=19, total_tokens=1351, details={ 'accepted_prediction_tokens': 0, @@ -611,7 +611,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloBGHh27w3fQKwxq4fX2cPuZJa9', + provider_request_id='chatcmpl-BRloBGHh27w3fQKwxq4fX2cPuZJa9', ), ] ) @@ -644,10 +644,10 @@ async def test_tool_returning_image_resource_link( tool_call_id='call_eVFgn54V9Nuh8Y4zvuzkYjUp', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=305, - response_tokens=12, + input_tokens=305, + output_tokens=12, total_tokens=317, details={ 'accepted_prediction_tokens': 0, @@ -659,7 +659,7 @@ async def test_tool_returning_image_resource_link( ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BwdHygYePH1mZgHo2Xxzib0Y7sId7', + provider_request_id='chatcmpl-BwdHygYePH1mZgHo2Xxzib0Y7sId7', ), ModelRequest( parts=[ @@ -678,10 +678,10 @@ async def test_tool_returning_image_resource_link( content='This is an image of a sliced kiwi fruit. It shows the green, seed-speckled interior with fuzzy brown skin around the edges.' 
) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=1452, - response_tokens=29, + input_tokens=1452, + output_tokens=29, total_tokens=1481, details={ 'accepted_prediction_tokens': 0, @@ -693,7 +693,7 @@ async def test_tool_returning_image_resource_link( ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BwdI2D2r9dvqq3pbsA0qgwKDEdTtD', + provider_request_id='chatcmpl-BwdI2D2r9dvqq3pbsA0qgwKDEdTtD', ), ] ) @@ -714,16 +714,16 @@ async def test_tool_returning_audio_resource( ), ModelResponse( parts=[ToolCallPart(tool_name='get_audio_resource', args={}, tool_call_id=IsStr())], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=383, - response_tokens=12, + input_tokens=383, + output_tokens=12, total_tokens=520, details={'thoughts_tokens': 125, 'text_prompt_tokens': 383}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -738,16 +738,16 @@ async def test_tool_returning_audio_resource( ), ModelResponse( parts=[TextPart(content='The audio resource contains a voice saying "Hello, my name is Marcelo."')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=575, - response_tokens=15, + input_tokens=575, + output_tokens=15, total_tokens=590, details={'text_prompt_tokens': 431, 'audio_prompt_tokens': 144}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -778,16 +778,16 @@ async def test_tool_returning_audio_resource_link( ), ToolCallPart(tool_name='get_audio_resource_link', args={}, tool_call_id=IsStr()), ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=561, - response_tokens=41, + input_tokens=561, + output_tokens=41, total_tokens=797, details={'thoughts_tokens': 195, 'text_prompt_tokens': 561}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -802,16 +802,16 @@ async def test_tool_returning_audio_resource_link( ), ModelResponse( parts=[TextPart(content='00:05')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=784, - response_tokens=5, + input_tokens=784, + output_tokens=5, total_tokens=789, details={'text_prompt_tokens': 640, 'audio_prompt_tokens': 144}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -840,10 +840,10 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im tool_call_id='call_Q7xG8CCG0dyevVfUS0ubsDdN', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=190, - response_tokens=11, + input_tokens=190, + output_tokens=11, total_tokens=201, details={ 'accepted_prediction_tokens': 0, @@ -855,7 +855,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloGQJWIX0Qk7gtNzF4s2Fez0O29', + provider_request_id='chatcmpl-BRloGQJWIX0Qk7gtNzF4s2Fez0O29', ), ModelRequest( parts=[ @@ -876,10 +876,10 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im ), ModelResponse( parts=[TextPart(content='Here is an image of a sliced kiwi on a white background.')], - usage=Usage( + usage=RunUsage( requests=1, - 
request_tokens=1329, - response_tokens=15, + input_tokens=1329, + output_tokens=15, total_tokens=1344, details={ 'accepted_prediction_tokens': 0, @@ -891,7 +891,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloJHR654fSD0fcvLWZxtKtn0pag', + provider_request_id='chatcmpl-BRloJHR654fSD0fcvLWZxtKtn0pag', ), ] ) @@ -914,10 +914,10 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): ), ModelResponse( parts=[ToolCallPart(tool_name='get_dict', args='{}', tool_call_id='call_oqKviITBj8PwpQjGyUu4Zu5x')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=195, - response_tokens=11, + input_tokens=195, + output_tokens=11, total_tokens=206, details={ 'accepted_prediction_tokens': 0, @@ -929,7 +929,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloOs7Bb2tq8wJyy9Rv7SQ7L65a7', + provider_request_id='chatcmpl-BRloOs7Bb2tq8wJyy9Rv7SQ7L65a7', ), ModelRequest( parts=[ @@ -943,10 +943,10 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): ), ModelResponse( parts=[TextPart(content='{"foo":"bar","baz":123}')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=222, - response_tokens=11, + input_tokens=222, + output_tokens=11, total_tokens=233, details={ 'accepted_prediction_tokens': 0, @@ -958,7 +958,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloPczU1HSCWnreyo21DdNtdOM7L', + provider_request_id='chatcmpl-BRloPczU1HSCWnreyo21DdNtdOM7L', ), ] ) @@ -989,10 +989,10 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): tool_call_id='call_rETXZWddAGZSHyVHAxptPGgc', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=203, - response_tokens=15, + input_tokens=203, + output_tokens=15, total_tokens=218, details={ 'accepted_prediction_tokens': 0, @@ -1004,7 +1004,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloSNg7aGSp1rXDkhInjMIUHKd7A', + provider_request_id='chatcmpl-BRloSNg7aGSp1rXDkhInjMIUHKd7A', ), ModelRequest( parts=[ @@ -1024,10 +1024,10 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): tool_call_id='call_4xGyvdghYKHN8x19KWkRtA5N', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=250, - response_tokens=15, + input_tokens=250, + output_tokens=15, total_tokens=265, details={ 'accepted_prediction_tokens': 0, @@ -1039,7 +1039,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloTvSkFeX4DZKQLqfH9KbQkWlpt', + provider_request_id='chatcmpl-BRloTvSkFeX4DZKQLqfH9KbQkWlpt', ), ModelRequest( parts=[ @@ -1057,10 +1057,10 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): content='I called the tool with the correct parameter, and it returned: "This is not an error."' ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=277, - response_tokens=22, + input_tokens=277, + output_tokens=22, total_tokens=299, details={ 'accepted_prediction_tokens': 0, @@ -1072,7 +1072,7 @@ async def test_tool_returning_error(allow_model_requests: 
None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloU3MhnqNEqujs28a3ofRbs7VPF', + provider_request_id='chatcmpl-BRloU3MhnqNEqujs28a3ofRbs7VPF', ), ] ) @@ -1095,10 +1095,10 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): ), ModelResponse( parts=[ToolCallPart(tool_name='get_none', args='{}', tool_call_id='call_mJTuQ2Cl5SaHPTJbIILEUhJC')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=193, - response_tokens=11, + input_tokens=193, + output_tokens=11, total_tokens=204, details={ 'accepted_prediction_tokens': 0, @@ -1110,7 +1110,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloX2RokWc9j9PAXAuNXGR73WNqY', + provider_request_id='chatcmpl-BRloX2RokWc9j9PAXAuNXGR73WNqY', ), ModelRequest( parts=[ @@ -1124,10 +1124,10 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): ), ModelResponse( parts=[TextPart(content='Hello! How can I assist you today?')], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=212, - response_tokens=11, + input_tokens=212, + output_tokens=11, total_tokens=223, details={ 'accepted_prediction_tokens': 0, @@ -1139,7 +1139,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloYWGujk8yE94gfVSsM1T1Ol2Ej', + provider_request_id='chatcmpl-BRloYWGujk8yE94gfVSsM1T1Ol2Ej', ), ] ) @@ -1170,10 +1170,10 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: tool_call_id='call_kL0TvjEVQBDGZrn1Zv7iNYOW', ) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=195, - response_tokens=12, + input_tokens=195, + output_tokens=12, total_tokens=207, details={ 'accepted_prediction_tokens': 0, @@ -1185,7 +1185,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlobKLgm6vf79c9O8sloZaYx3coC', + provider_request_id='chatcmpl-BRlobKLgm6vf79c9O8sloZaYx3coC', ), ModelRequest( parts=[ @@ -1215,10 +1215,10 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: content='The data includes two strings, a dictionary with a key-value pair, and an image of a sliced kiwi.' 
) ], - usage=Usage( + usage=RunUsage( requests=1, - request_tokens=1355, - response_tokens=24, + input_tokens=1355, + output_tokens=24, total_tokens=1379, details={ 'accepted_prediction_tokens': 0, @@ -1230,7 +1230,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloepWR5NJpTgSqFBGTSPeM1SWm8', + provider_request_id='chatcmpl-BRloepWR5NJpTgSqFBGTSPeM1SWm8', ), ] ) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index e8861a0e01..ed0bb5edf3 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -32,7 +32,7 @@ from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.output import DeferredToolCalls, PromptedOutput, TextOutput -from pydantic_ai.result import AgentStream, FinalResult, Usage +from pydantic_ai.result import AgentStream, FinalResult, RunUsage from pydantic_ai.tools import ToolDefinition from pydantic_graph import End @@ -59,7 +59,7 @@ async def ret_a(x: str) -> str: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(request_tokens=51, response_tokens=0, total_tokens=51), + usage=RunUsage(input_tokens=51, output_tokens=0, total_tokens=51), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -73,10 +73,10 @@ async def ret_a(x: str) -> str: ] ) assert result.usage() == snapshot( - Usage( + RunUsage( requests=2, - request_tokens=103, - response_tokens=5, + input_tokens=103, + output_tokens=5, total_tokens=108, ) ) @@ -89,7 +89,7 @@ async def ret_a(x: str) -> str: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(request_tokens=51, response_tokens=0, total_tokens=51), + usage=RunUsage(input_tokens=51, output_tokens=0, total_tokens=51), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -102,17 +102,17 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(request_tokens=52, response_tokens=11, total_tokens=63), + usage=RunUsage(input_tokens=52, output_tokens=11, total_tokens=63), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ) assert result.usage() == snapshot( - Usage( + RunUsage( requests=2, - request_tokens=103, - response_tokens=11, + input_tokens=103, + output_tokens=11, total_tokens=114, ) ) @@ -220,43 +220,43 @@ def upcase(text: str) -> str: [ ModelResponse( parts=[TextPart(content='The ')], - usage=Usage(request_tokens=51, response_tokens=1, total_tokens=52), + usage=RunUsage(input_tokens=51, output_tokens=1, total_tokens=52), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='The cat ')], - usage=Usage(request_tokens=51, response_tokens=2, total_tokens=53), + usage=RunUsage(input_tokens=51, output_tokens=2, total_tokens=53), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='The cat sat ')], - usage=Usage(request_tokens=51, response_tokens=3, total_tokens=54), + usage=RunUsage(input_tokens=51, output_tokens=3, total_tokens=54), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='The cat sat on ')], - 
usage=Usage(request_tokens=51, response_tokens=4, total_tokens=55), + usage=RunUsage(input_tokens=51, output_tokens=4, total_tokens=55), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='The cat sat on the ')], - usage=Usage(request_tokens=51, response_tokens=5, total_tokens=56), + usage=RunUsage(input_tokens=51, output_tokens=5, total_tokens=56), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='The cat sat on the mat.')], - usage=Usage(request_tokens=51, response_tokens=7, total_tokens=58), + usage=RunUsage(input_tokens=51, output_tokens=7, total_tokens=58), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='The cat sat on the mat.')], - usage=Usage(request_tokens=51, response_tokens=7, total_tokens=58), + usage=RunUsage(input_tokens=51, output_tokens=7, total_tokens=58), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -323,7 +323,7 @@ async def ret_a(x: str) -> str: ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args='{"x": "hello"}', tool_call_id=IsStr())], - usage=Usage(request_tokens=50, response_tokens=5, total_tokens=55), + usage=RunUsage(input_tokens=50, output_tokens=5, total_tokens=55), model_name='function::stream_structured_function', timestamp=IsNow(tz=timezone.utc), ), @@ -345,7 +345,7 @@ async def ret_a(x: str) -> str: ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args='{"x": "hello"}', tool_call_id=IsStr())], - usage=Usage(request_tokens=50, response_tokens=5, total_tokens=55), + usage=RunUsage(input_tokens=50, output_tokens=5, total_tokens=55), model_name='function::stream_structured_function', timestamp=IsNow(tz=timezone.utc), ), @@ -367,7 +367,7 @@ async def ret_a(x: str) -> str: tool_call_id=IsStr(), ) ], - usage=Usage(request_tokens=50, response_tokens=7, total_tokens=57), + usage=RunUsage(input_tokens=50, output_tokens=7, total_tokens=57), model_name='function::stream_structured_function', timestamp=IsNow(tz=timezone.utc), ), @@ -420,7 +420,7 @@ async def ret_a(x: str) -> str: # pragma: no cover ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RunUsage(input_tokens=50, output_tokens=1, total_tokens=51), model_name='function::stream_structured_function', timestamp=IsNow(tz=timezone.utc), ), @@ -476,7 +476,7 @@ def another_tool(y: int) -> int: # pragma: no cover ToolCallPart(tool_name='regular_tool', args='{"x": 1}', tool_call_id=IsStr()), ToolCallPart(tool_name='another_tool', args='{"y": 2}', tool_call_id=IsStr()), ], - usage=Usage(request_tokens=50, response_tokens=10, total_tokens=60), + usage=RunUsage(input_tokens=50, output_tokens=10, total_tokens=60), model_name='function::sf', timestamp=IsNow(tz=timezone.utc), ), @@ -532,7 +532,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt ToolCallPart(tool_name='final_result', args='{"value": "first"}', tool_call_id=IsStr()), ToolCallPart(tool_name='final_result', args='{"value": "second"}', tool_call_id=IsStr()), ], - usage=Usage(request_tokens=50, response_tokens=8, total_tokens=58), + usage=RunUsage(input_tokens=50, output_tokens=8, 
total_tokens=58), model_name='function::sf', timestamp=IsNow(tz=timezone.utc), ), @@ -599,7 +599,7 @@ def another_tool(y: int) -> int: ToolCallPart(tool_name='final_result', args='{"value": "second"}', tool_call_id=IsStr()), ToolCallPart(tool_name='unknown_tool', args='{"value": "???"}', tool_call_id=IsStr()), ], - usage=Usage(request_tokens=50, response_tokens=18, total_tokens=68), + usage=RunUsage(input_tokens=50, output_tokens=18, total_tokens=68), model_name='function::sf', timestamp=IsNow(tz=timezone.utc), ), @@ -708,7 +708,7 @@ def another_tool(y: int) -> int: # pragma: no cover part_kind='tool-call', ), ], - usage=Usage(request_tokens=50, response_tokens=14, total_tokens=64), + usage=RunUsage(input_tokens=50, output_tokens=14, total_tokens=64), model_name='function::sf', timestamp=IsNow(tz=datetime.timezone.utc), kind='response', @@ -779,7 +779,7 @@ def regular_tool(x: int) -> int: ), ModelResponse( parts=[ToolCallPart(tool_name='regular_tool', args={'x': 0}, tool_call_id=IsStr())], - usage=Usage(request_tokens=57, response_tokens=0, total_tokens=57), + usage=RunUsage(input_tokens=57, output_tokens=0, total_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -792,7 +792,7 @@ def regular_tool(x: int) -> int: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'value': 'a'}, tool_call_id=IsStr())], - usage=Usage(request_tokens=58, response_tokens=4, total_tokens=62), + usage=RunUsage(input_tokens=58, output_tokens=4, total_tokens=62), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -848,7 +848,7 @@ def output_validator_simple(data: str) -> str: stream: AgentStream messages: list[str] = [] - stream_usage: Usage | None = None + stream_usage: RunUsage | None = None async with agent.iter('Hello') as run: async for node in run: if agent.is_model_request_node(node): @@ -860,7 +860,7 @@ def output_validator_simple(data: str) -> str: assert ( run.usage() == stream_usage - == Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58, details=None) + == RunUsage(requests=1, input_tokens=51, output_tokens=7, total_tokens=58, details=None) ) assert messages == [ @@ -898,9 +898,7 @@ def output_validator_simple(data: str) -> str: assert messages == [ ModelResponse( parts=[TextPart(content=text, part_kind='text')], - usage=Usage( - requests=0, request_tokens=IsInt(), response_tokens=IsInt(), total_tokens=IsInt(), details=None - ), + usage=RunUsage(requests=0, input_tokens=IsInt(), output_tokens=IsInt(), total_tokens=IsInt(), details=None), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', diff --git a/tests/test_tools.py b/tests/test_tools.py index 7f4a45804b..fc60d8a175 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -30,7 +30,7 @@ from pydantic_ai.toolsets.deferred import DeferredToolset from pydantic_ai.toolsets.function import FunctionToolset from pydantic_ai.toolsets.prefixed import PrefixedToolset -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage from .conftest import IsDatetime, IsStr @@ -1381,7 +1381,7 @@ def get_price(fruit: str) -> ToolReturn: tool_call_id=IsStr(), ), ], - usage=Usage(requests=1, request_tokens=58, response_tokens=10, total_tokens=68), + usage=RunUsage(requests=1, input_tokens=58, output_tokens=10, total_tokens=68), model_name='function:llm:', timestamp=IsDatetime(), ), @@ -1413,7 +1413,7 @@ def get_price(fruit: str) -> ToolReturn: ), ModelResponse( parts=[TextPart(content='Done!')], - usage=Usage(requests=1, request_tokens=76, response_tokens=11, 
total_tokens=87), + usage=RunUsage(requests=1, input_tokens=76, output_tokens=11, total_tokens=87), model_name='function:llm:', timestamp=IsDatetime(), ), diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index f217b34f4e..e53e65096c 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -20,7 +20,7 @@ from pydantic_ai.toolsets.function import FunctionToolset from pydantic_ai.toolsets.prefixed import PrefixedToolset from pydantic_ai.toolsets.prepared import PreparedToolset -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RunUsage pytestmark = pytest.mark.anyio @@ -31,7 +31,7 @@ def build_run_context(deps: T) -> RunContext[T]: return RunContext( deps=deps, model=TestModel(), - usage=Usage(), + usage=RunUsage(), prompt=None, messages=[], run_step=0, diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 7fb9bba485..d2f05fce61 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -15,7 +15,7 @@ UserPromptPart, ) from pydantic_ai.models.test import TestModel -from pydantic_ai.usage import Usage, UsageLimits +from pydantic_ai.usage import RunUsage, UsageLimits from .conftest import IsNow, IsStr @@ -25,9 +25,7 @@ def test_request_token_limit() -> None: test_agent = Agent(TestModel()) - with pytest.raises( - UsageLimitExceeded, match=re.escape('Exceeded the request_tokens_limit of 5 (request_tokens=59)') - ): + with pytest.raises(UsageLimitExceeded, match=re.escape('Exceeded the request_tokens_limit of 5 (input_tokens=59)')): test_agent.run_sync( 'Hello, this prompt exceeds the request tokens limit.', usage_limits=UsageLimits(request_tokens_limit=5) ) @@ -39,7 +37,7 @@ def test_response_token_limit() -> None: ) with pytest.raises( - UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 5 (response_tokens=11)') + UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 5 (output_tokens=11)') ): test_agent.run_sync('Hello', usage_limits=UsageLimits(response_tokens_limit=5)) @@ -79,7 +77,7 @@ async def ret_a(x: str) -> str: succeeded = False with pytest.raises( - UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)') + UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (output_tokens=11)') ): async with test_agent.run_stream('Hello', usage_limits=UsageLimits(response_tokens_limit=10)) as result: assert test_agent.name == 'test_agent' @@ -95,9 +93,9 @@ async def ret_a(x: str) -> str: tool_call_id=IsStr(), ) ], - usage=Usage( - request_tokens=51, - response_tokens=0, + usage=RunUsage( + input_tokens=51, + output_tokens=0, total_tokens=51, ), model_name='test', @@ -116,10 +114,10 @@ async def ret_a(x: str) -> str: ] ) assert result.usage() == snapshot( - Usage( + RunUsage( requests=2, - request_tokens=103, - response_tokens=5, + input_tokens=103, + output_tokens=5, total_tokens=108, ) ) @@ -137,7 +135,7 @@ def test_usage_so_far() -> None: test_agent.run_sync( 'Hello, this prompt exceeds the request tokens limit.', usage_limits=UsageLimits(total_tokens_limit=105), - usage=Usage(total_tokens=100), + usage=RunUsage(total_tokens=100), ) @@ -145,20 +143,20 @@ async def test_multi_agent_usage_no_incr(): delegate_agent = Agent(TestModel(), output_type=int) controller_agent1 = Agent(TestModel()) - run_1_usages: list[Usage] = [] + run_1_usages: list[RunUsage] = [] @controller_agent1.tool async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int: delegate_result = await 
delegate_agent.run(sentence) delegate_usage = delegate_result.usage() run_1_usages.append(delegate_usage) - assert delegate_usage == snapshot(Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55)) + assert delegate_usage == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=4, total_tokens=55)) return delegate_result.output result1 = await controller_agent1.run('foobar') assert result1.output == snapshot('{"delegate_to_other_agent1":0}') run_1_usages.append(result1.usage()) - assert result1.usage() == snapshot(Usage(requests=2, request_tokens=103, response_tokens=13, total_tokens=116)) + assert result1.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=13, total_tokens=116)) controller_agent2 = Agent(TestModel()) @@ -166,12 +164,12 @@ async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int: async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int: delegate_result = await delegate_agent.run(sentence, usage=ctx.usage) delegate_usage = delegate_result.usage() - assert delegate_usage == snapshot(Usage(requests=2, request_tokens=102, response_tokens=9, total_tokens=111)) + assert delegate_usage == snapshot(RunUsage(requests=2, input_tokens=102, output_tokens=9, total_tokens=111)) return delegate_result.output result2 = await controller_agent2.run('foobar') assert result2.output == snapshot('{"delegate_to_other_agent2":0}') - assert result2.usage() == snapshot(Usage(requests=3, request_tokens=154, response_tokens=17, total_tokens=171)) + assert result2.usage() == snapshot(RunUsage(requests=3, input_tokens=154, output_tokens=17, total_tokens=171)) # confirm the usage from result2 is the sum of the usage from result1 assert result2.usage() == functools.reduce(operator.add, run_1_usages) @@ -192,10 +190,10 @@ async def test_multi_agent_usage_sync(): @controller_agent.tool def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int: - new_usage = Usage(requests=5, request_tokens=2, response_tokens=3, total_tokens=4) + new_usage = RunUsage(requests=5, input_tokens=2, output_tokens=3, total_tokens=4) ctx.usage.incr(new_usage) return 0 result = await controller_agent.run('foobar') assert result.output == snapshot('{"delegate_to_other_agent":0}') - assert result.usage() == snapshot(Usage(requests=7, request_tokens=105, response_tokens=16, total_tokens=120)) + assert result.usage() == snapshot(RunUsage(requests=7, input_tokens=105, output_tokens=16, total_tokens=120)) diff --git a/uv.lock b/uv.lock index 73e1ef90c1..d997ef0daf 100644 --- a/uv.lock +++ b/uv.lock @@ -1180,6 +1180,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/94/758680531a00d06e471ef649e4ec2ed6bf185356a7f9fbfbb7368a40bd49/fsspec-2025.2.0-py3-none-any.whl", hash = "sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b", size = 184484, upload-time = "2025-02-01T18:30:19.802Z" }, ] +[[package]] +name = "genai-prices" +version = "0.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "eval-type-backport", marker = "python_full_version < '3.11'" }, + { name = "httpx" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4f/06/398237fceebab444e8f91c974cf88e138e3a9b43d98835df3d2997751ef7/genai_prices-0.0.3.tar.gz", hash = "sha256:9a6a11f64d51e825223613f40dd4ed3312d4b5b3ddec8c030eb814e52ea0f54d", size = 40002, upload-time = "2025-07-13T03:29:42.32Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c8/ee/7266e6b3b19c9544af4069b94301df2c484fd65e8c1bad4334c9a9b9e1c8/genai_prices-0.0.3-py3-none-any.whl", hash = "sha256:dcae66dec82ccc609e2d4365c3b8b6f192115a39f378d9e5fd5728f21edad80d", size = 41959, upload-time = "2025-07-13T03:29:41.157Z" }, +] + [[package]] name = "ghp-import" version = "2.1.0" @@ -3124,6 +3138,7 @@ source = { editable = "pydantic_ai_slim" } dependencies = [ { name = "eval-type-backport" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "genai-prices" }, { name = "griffe" }, { name = "httpx" }, { name = "opentelemetry-api" }, @@ -3224,6 +3239,7 @@ requires-dist = [ { name = "eval-type-backport", specifier = ">=0.2.0" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, + { name = "genai-prices", specifier = ">=0.0.3" }, { name = "google-auth", marker = "extra == 'vertexai'", specifier = ">=2.36.0" }, { name = "google-genai", marker = "extra == 'google'", specifier = ">=1.24.0" }, { name = "griffe", specifier = ">=1.3.2" }, From 0b325c3d823454068d081f23dba91ae04d25a963 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 17:18:35 +0200 Subject: [PATCH 02/71] fix --- pydantic_ai_slim/pydantic_ai/agent.py | 2424 ----------------- .../pydantic_ai/agent/__init__.py | 8 +- 2 files changed, 4 insertions(+), 2428 deletions(-) delete mode 100644 pydantic_ai_slim/pydantic_ai/agent.py diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py deleted file mode 100644 index 95b49f0978..0000000000 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ /dev/null @@ -1,2424 +0,0 @@ -from __future__ import annotations as _annotations - -import dataclasses -import inspect -import json -import warnings -from asyncio import Lock -from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager -from contextvars import ContextVar -from copy import deepcopy -from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload - -from opentelemetry.trace import NoOpTracer, use_span -from pydantic.json_schema import GenerateJsonSchema -from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated - -from pydantic_graph import End, Graph, GraphRun, GraphRunContext -from pydantic_graph._utils import get_event_loop - -from . 
import ( - _agent_graph, - _output, - _system_prompt, - _utils, - exceptions, - messages as _messages, - models, - result, - usage as _usage, -) -from ._agent_graph import HistoryProcessor -from ._output import OutputToolset -from ._tool_manager import ToolManager -from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model -from .output import OutputDataT, OutputSpec -from .profiles import ModelProfile -from .result import AgentStream, FinalResult, StreamedRunResult -from .settings import ModelSettings, merge_model_settings -from .tools import ( - AgentDepsT, - DocstringFormat, - GenerateToolJsonSchema, - RunContext, - Tool, - ToolFuncContext, - ToolFuncEither, - ToolFuncPlain, - ToolParams, - ToolPrepareFunc, - ToolsPrepareFunc, -) -from .toolsets import AbstractToolset -from .toolsets.combined import CombinedToolset -from .toolsets.function import FunctionToolset -from .toolsets.prepared import PreparedToolset - -# Re-exporting like this improves auto-import behavior in PyCharm -capture_run_messages = _agent_graph.capture_run_messages -EndStrategy = _agent_graph.EndStrategy -CallToolsNode = _agent_graph.CallToolsNode -ModelRequestNode = _agent_graph.ModelRequestNode -UserPromptNode = _agent_graph.UserPromptNode - -if TYPE_CHECKING: - from fasta2a.applications import FastA2A - from fasta2a.broker import Broker - from fasta2a.schema import AgentProvider, Skill - from fasta2a.storage import Storage - from starlette.middleware import Middleware - from starlette.routing import BaseRoute, Route - from starlette.types import ExceptionHandler, Lifespan - - from pydantic_ai.mcp import MCPServer - - from .ag_ui import AGUIApp - -__all__ = ( - 'Agent', - 'AgentRun', - 'AgentRunResult', - 'capture_run_messages', - 'EndStrategy', - 'CallToolsNode', - 'ModelRequestNode', - 'UserPromptNode', - 'InstrumentationSettings', -) - - -T = TypeVar('T') -S = TypeVar('S') -NoneType = type(None) -RunOutputDataT = TypeVar('RunOutputDataT') -"""Type variable for the result data of a run where `output_type` was customized on the run call.""" - - -@final -@dataclasses.dataclass(init=False) -class Agent(Generic[AgentDepsT, OutputDataT]): - """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. - - Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] - and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. - - By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. - - Minimal usage example: - - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - result = agent.run_sync('What is the capital of France?') - print(result.output) - #> Paris - ``` - """ - - model: models.Model | models.KnownModelName | str | None - """The default model configured for this agent. - - We allow `str` here since the actual list of allowed models changes frequently. - """ - - name: str | None - """The name of the agent, used for logging. - - If `None`, we try to infer the agent name from the call frame when the agent is first run. - """ - end_strategy: EndStrategy - """Strategy for handling tool calls when a final result is found.""" - - model_settings: ModelSettings | None - """Optional model request settings to use for this agents's runs, by default. - - Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will - be merged with this value, with the runtime argument taking priority. 
- """ - - output_type: OutputSpec[OutputDataT] - """ - The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. - """ - - instrument: InstrumentationSettings | bool | None - """Options to automatically instrument with OpenTelemetry.""" - - _instrument_default: ClassVar[InstrumentationSettings | bool] = False - - _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) - _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) - _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) - _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) - _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) - _instructions: str | None = dataclasses.field(repr=False) - _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) - _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) - _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) - _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( - repr=False - ) - _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) - _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) - _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) - _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _max_result_retries: int = dataclasses.field(repr=False) - - _enter_lock: Lock = dataclasses.field(repr=False) - _entered_count: int = dataclasses.field(repr=False) - _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) - - @overload - def __init__( - self, - model: models.Model | models.KnownModelName | str | None = None, - *, - output_type: OutputSpec[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, - model_settings: ModelSettings | None = None, - retries: int = 1, - output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - ) -> None: ... - - @overload - @deprecated( - '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' 
- ) - def __init__( - self, - model: models.Model | models.KnownModelName | str | None = None, - *, - result_type: type[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, - model_settings: ModelSettings | None = None, - retries: int = 1, - result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, - result_tool_description: str | None = None, - result_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - ) -> None: ... - - @overload - @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') - def __init__( - self, - model: models.Model | models.KnownModelName | str | None = None, - *, - result_type: type[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, - model_settings: ModelSettings | None = None, - retries: int = 1, - result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, - result_tool_description: str | None = None, - result_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - ) -> None: ... - - def __init__( - self, - model: models.Model | models.KnownModelName | str | None = None, - *, - # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads - output_type: Any = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, - model_settings: ModelSettings | None = None, - retries: int = 1, - output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - **_deprecated_kwargs: Any, - ): - """Create an agent. - - Args: - model: The default model to use for this agent, if not provide, - you must provide the model when calling it. 
We allow `str` here since the actual list of allowed models changes frequently. - output_type: The type of the output data, used to validate the data returned by the model, - defaults to `str`. - instructions: Instructions to use for this agent, you can also register instructions via a function with - [`instructions`][pydantic_ai.Agent.instructions]. - system_prompt: Static system prompts to use for this agent, you can also register system - prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt]. - deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully - parameterize the agent, and therefore get the best out of static type checking. - If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright - or add a type hint `: Agent[None, ]`. - name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame - when the agent is first run. - model_settings: Optional model request settings to use for this agent's runs, by default. - retries: The default number of retries to allow before raising an error. - output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. - tools: Tools to register with the agent, you can also register tools via the decorators - [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. - prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. - This is useful if you want to customize the definition of multiple tools or you want to register - a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. - This is useful if you want to customize the definition of multiple output tools or you want to register - a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - toolsets: Toolsets to register with the agent, including MCP servers. - defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, - it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, - which checks for the necessary environment variables. Set this to `false` - to defer the evaluation until the first run. Useful if you want to - [override the model][pydantic_ai.Agent.override] for testing. - end_strategy: Strategy for handling tool calls that are requested alongside a final result. - See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information. - instrument: Set to True to automatically instrument with OpenTelemetry, - which will use Logfire if it's configured. - Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize. - If this isn't set, then the last value set by - [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all] - will be used, which defaults to False. - See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. - history_processors: Optional list of callables to process the message history before sending it to the model. - Each processor takes a list of messages and returns a modified list of messages. - Processors can be sync or async and are applied in sequence. 
- """ - if model is None or defer_model_check: - self.model = model - else: - self.model = models.infer_model(model) - - self.end_strategy = end_strategy - self.name = name - self.model_settings = model_settings - - if 'result_type' in _deprecated_kwargs: - if output_type is not str: # pragma: no cover - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') - - self.output_type = output_type - - self.instrument = instrument - - self._deps_type = deps_type - - self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None) - if self._deprecated_result_tool_name is not None: - warnings.warn( - '`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead', - DeprecationWarning, - stacklevel=2, - ) - - self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None) - if self._deprecated_result_tool_description is not None: - warnings.warn( - '`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead', - DeprecationWarning, - stacklevel=2, - ) - result_retries = _deprecated_kwargs.pop('result_retries', None) - if result_retries is not None: - if output_retries is not None: # pragma: no cover - raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.') - warnings.warn( - '`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning, stacklevel=2 - ) - output_retries = result_retries - - if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): - if toolsets is not None: # pragma: no cover - raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') - warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) - toolsets = mcp_servers - - _utils.validate_empty_kwargs(_deprecated_kwargs) - - default_output_mode = ( - self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None - ) - - self._output_schema = _output.OutputSchema[OutputDataT].build( - output_type, - default_mode=default_output_mode, - name=self._deprecated_result_tool_name, - description=self._deprecated_result_tool_description, - ) - self._output_validators = [] - - self._instructions = '' - self._instructions_functions = [] - if isinstance(instructions, (str, Callable)): - instructions = [instructions] - for instruction in instructions or []: - if isinstance(instruction, str): - self._instructions += instruction + '\n' - else: - self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) - self._instructions = self._instructions.strip() or None - - self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) - self._system_prompt_functions = [] - self._system_prompt_dynamic_functions = {} - - self._max_result_retries = output_retries if output_retries is not None else retries - self._prepare_tools = prepare_tools - self._prepare_output_tools = prepare_output_tools - - self._output_toolset = self._output_schema.toolset - if self._output_toolset: - self._output_toolset.max_retries = self._max_result_retries - - self._function_toolset = FunctionToolset(tools, max_retries=retries) - self._user_toolsets = toolsets or () - - self.history_processors = history_processors or [] - - self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = 
ContextVar('_override_deps', default=None) - self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) - self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( - '_override_toolsets', default=None - ) - - self._enter_lock = _utils.get_async_lock() - self._entered_count = 0 - self._exit_stack = None - - @staticmethod - def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: - """Set the instrumentation options for all agents where `instrument` is not set.""" - Agent._instrument_default = instrument - - @overload - async def run( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[OutputDataT]: ... - - @overload - async def run( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... - - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - async def run( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... - - async def run( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - **_deprecated_kwargs: Never, - ) -> AgentRunResult[Any]: - """Run the agent with a user prompt in async mode. - - This method builds an internal agent graph (using system prompts, tools and result schemas) and then - runs the graph to completion. The result of the run is returned. - - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - - async def main(): - agent_run = await agent.run('What is the capital of France?') - print(agent_run.output) - #> Paris - ``` - - Args: - user_prompt: User input to start/continue the conversation. 
- output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no - output validators since output validators would expect an argument that matches the agent's output type. - message_history: History of the conversation so far. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. - - Returns: - The result of the run. - """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - - if 'result_type' in _deprecated_kwargs: # pragma: no cover - if output_type is not str: - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') - - _utils.validate_empty_kwargs(_deprecated_kwargs) - - async with self.iter( - user_prompt=user_prompt, - output_type=output_type, - message_history=message_history, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - toolsets=toolsets, - ) as agent_run: - async for _ in agent_run: - pass - - assert agent_run.result is not None, 'The graph run did not finish properly' - return agent_run.result - - @overload - def iter( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - **_deprecated_kwargs: Never, - ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... - - @overload - def iter( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - **_deprecated_kwargs: Never, - ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... 
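The `usage` and `usage_limits` arguments described above compose well: a shared `RunUsage` lets several related runs accumulate into a single total, while `UsageLimits` caps an individual run. A minimal sketch, assuming `RunUsage` and `UsageLimits` are importable from `pydantic_ai.usage` as in this patch and that `UsageLimits` keeps its `request_limit` field:

```python
from pydantic_ai import Agent
from pydantic_ai.usage import RunUsage, UsageLimits

agent = Agent('openai:gpt-4o')


async def main():
    shared_usage = RunUsage()  # accumulates across both runs below

    await agent.run(
        'What is the capital of France?',
        usage=shared_usage,
        usage_limits=UsageLimits(request_limit=5),  # cap model requests for this run
    )
    await agent.run(
        'What is the capital of Italy?',
        usage=shared_usage,
    )
    print(shared_usage)  # totals cover both runs
```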
- - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def iter( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... - - @asynccontextmanager - async def iter( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - **_deprecated_kwargs: Never, - ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: - """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. - - This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an - `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are - executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the - stream of events coming from the execution of tools. - - The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, - and the final result of the run once it has completed. - - For more details, see the documentation of `AgentRun`. - - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - - async def main(): - nodes = [] - async with agent.iter('What is the capital of France?') as agent_run: - async for node in agent_run: - nodes.append(node) - print(nodes) - ''' - [ - UserPromptNode( - user_prompt='What is the capital of France?', - instructions=None, - instructions_functions=[], - system_prompts=(), - system_prompt_functions=[], - system_prompt_dynamic_functions={}, - ), - ModelRequestNode( - request=ModelRequest( - parts=[ - UserPromptPart( - content='What is the capital of France?', - timestamp=datetime.datetime(...), - ) - ] - ) - ), - CallToolsNode( - model_response=ModelResponse( - parts=[TextPart(content='Paris')], - usage=Usage( - requests=1, input_tokens=56, output_tokens=1, total_tokens=57 - ), - model_name='gpt-4o', - timestamp=datetime.datetime(...), - ) - ), - End(data=FinalResult(output='Paris')), - ] - ''' - print(agent_run.result.output) - #> Paris - ``` - - Args: - user_prompt: User input to start/continue the conversation. - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no - output validators since output validators would expect an argument that matches the agent's output type. - message_history: History of the conversation so far. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. 
- model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. - - Returns: - The result of the run. - """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - model_used = self._get_model(model) - del model - - if 'result_type' in _deprecated_kwargs: # pragma: no cover - if output_type is not str: - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') - - _utils.validate_empty_kwargs(_deprecated_kwargs) - - deps = self._get_deps(deps) - new_message_index = len(message_history) if message_history else 0 - output_schema = self._prepare_output_schema(output_type, model_used.profile) - - output_type_ = output_type or self.output_type - - # We consider it a user error if a user tries to restrict the result type while having an output validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) - - output_toolset = self._output_toolset - if output_schema != self._output_schema or output_validators: - output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) - if output_toolset: - output_toolset.max_retries = self._max_result_retries - output_toolset.output_validators = output_validators - - # Build the graph - graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( - _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) - ) - - # Build the initial state - usage = usage or _usage.RunUsage() - state = _agent_graph.GraphAgentState( - message_history=message_history[:] if message_history else [], - usage=usage, - retries=0, - run_step=0, - ) - - if isinstance(model_used, InstrumentedModel): - instrumentation_settings = model_used.instrumentation_settings - tracer = model_used.instrumentation_settings.tracer - else: - instrumentation_settings = None - tracer = NoOpTracer() - - run_context = RunContext[AgentDepsT]( - deps=deps, - model=model_used, - usage=usage, - prompt=user_prompt, - messages=state.message_history, - tracer=tracer, - trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content, - run_step=state.run_step, - ) - - toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - # This will raise errors for any name conflicts - run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context) - - # Merge model settings in order of precedence: run > agent > model - merged_settings = merge_model_settings(model_used.settings, self.model_settings) - model_settings = merge_model_settings(merged_settings, model_settings) - usage_limits = usage_limits or _usage.UsageLimits() - agent_name = self.name or 'agent' - run_span = tracer.start_span( - 'agent run', - attributes={ - 'model_name': model_used.model_name if model_used 
else 'no-model', - 'agent_name': agent_name, - 'logfire.msg': f'{agent_name} run', - }, - ) - - async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: - parts = [ - self._instructions, - *[await func.run(run_context) for func in self._instructions_functions], - ] - - model_profile = model_used.profile - if isinstance(output_schema, _output.PromptedOutputSchema): - instructions = output_schema.instructions(model_profile.prompted_output_template) - parts.append(instructions) - - parts = [p for p in parts if p] - if not parts: - return None - return '\n\n'.join(parts).strip() - - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, - model_settings=model_settings, - usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - output_schema=output_schema, - output_validators=output_validators, - history_processors=self.history_processors, - tool_manager=run_toolset, - tracer=tracer, - get_instructions=get_instructions, - instrumentation_settings=instrumentation_settings, - ) - start_node = _agent_graph.UserPromptNode[AgentDepsT]( - user_prompt=user_prompt, - instructions=self._instructions, - instructions_functions=self._instructions_functions, - system_prompts=self._system_prompts, - system_prompt_functions=self._system_prompt_functions, - system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, - ) - - try: - async with graph.iter( - start_node, - state=state, - deps=graph_deps, - span=use_span(run_span) if run_span.is_recording() else None, - infer_name=False, - ) as graph_run: - agent_run = AgentRun(graph_run) - yield agent_run - if (final_result := agent_run.result) is not None and run_span.is_recording(): - if instrumentation_settings and instrumentation_settings.include_content: - run_span.set_attribute( - 'final_result', - ( - final_result.output - if isinstance(final_result.output, str) - else json.dumps(InstrumentedModel.serialize_any(final_result.output)) - ), - ) - finally: - try: - if instrumentation_settings and run_span.is_recording(): - run_span.set_attributes(self._run_span_end_attributes(state, instrumentation_settings)) - finally: - run_span.end() - - def _run_span_end_attributes( - self, state: _agent_graph.GraphAgentState, settings: InstrumentationSettings - ) -> dict[str, str | int]: - return { - 'all_messages_events': json.dumps( - [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)] - ), - 'logfire.json_schema': json.dumps( - { - 'type': 'object', - 'properties': { - 'all_messages_events': {'type': 'array'}, - 'final_result': {'type': 'object'}, - }, - } - ), - } - - @overload - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[OutputDataT]: ... 
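The settings merge above resolves precedence as run > agent > model. A short sketch of what that means in practice, using `temperature` as a representative `ModelSettings` key:

```python
from pydantic_ai import Agent

# Agent-level settings act as the default for every run of this agent...
agent = Agent('openai:gpt-4o', model_settings={'temperature': 0.0})

# ...and run-level settings override them for that run only.
result = agent.run_sync(
    'Pick a number between 1 and 10.',
    model_settings={'temperature': 1.0},
)
print(result.output)
```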
- - @overload - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... - - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... - - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - **_deprecated_kwargs: Never, - ) -> AgentRunResult[Any]: - """Synchronously run the agent with a user prompt. - - This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. - You therefore can't use this method inside async code or if there's an active event loop. - - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - - result_sync = agent.run_sync('What is the capital of Italy?') - print(result_sync.output) - #> Rome - ``` - - Args: - user_prompt: User input to start/continue the conversation. - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no - output validators since output validators would expect an argument that matches the agent's output type. - message_history: History of the conversation so far. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. - - Returns: - The result of the run. 
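As with `run`, the `message_history` argument lets `run_sync` continue an earlier exchange rather than start a fresh conversation; a minimal sketch:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

first = agent.run_sync('What is the capital of France?')

# Feed the previous messages back in to keep the conversation going.
second = agent.run_sync(
    'How many people live there?',
    message_history=first.all_messages(),
)
print(second.output)
```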
- """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - - if 'result_type' in _deprecated_kwargs: # pragma: no cover - if output_type is not str: - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') - - _utils.validate_empty_kwargs(_deprecated_kwargs) - - return get_event_loop().run_until_complete( - self.run( - user_prompt, - output_type=output_type, - message_history=message_history, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=False, - toolsets=toolsets, - ) - ) - - @overload - def run_stream( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... - - @overload - def run_stream( - self, - user_prompt: str | Sequence[_messages.UserContent], - *, - output_type: OutputSpec[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... - - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def run_stream( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... - - @asynccontextmanager - async def run_stream( # noqa C901 - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - **_deprecated_kwargs: Never, - ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: - """Run the agent with a user prompt in async mode, returning a streamed response. 
- - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - - async def main(): - async with agent.run_stream('What is the capital of the UK?') as response: - print(await response.get_output()) - #> London - ``` - - Args: - user_prompt: User input to start/continue the conversation. - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no - output validators since output validators would expect an argument that matches the agent's output type. - message_history: History of the conversation so far. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. - - Returns: - The result of the run. - """ - # TODO: We need to deprecate this now that we have the `iter` method. - # Before that, though, we should add an event for when we reach the final result of the stream. - if infer_name and self.name is None: - # f_back because `asynccontextmanager` adds one frame - if frame := inspect.currentframe(): # pragma: no branch - self._infer_name(frame.f_back) - - if 'result_type' in _deprecated_kwargs: # pragma: no cover - if output_type is not str: - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') - - _utils.validate_empty_kwargs(_deprecated_kwargs) - - yielded = False - async with self.iter( - user_prompt, - output_type=output_type, - message_history=message_history, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=False, - toolsets=toolsets, - ) as agent_run: - first_node = agent_run.next_node # start with the first node - assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node - node = first_node - while True: - if self.is_model_request_node(node): - graph_ctx = agent_run.ctx - async with node.stream(graph_ctx) as stream: - - async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None: - async for event in stream: - if isinstance(event, _messages.FinalResultEvent): - return FinalResult(s, event.tool_name, event.tool_call_id) - return None - - final_result = await stream_to_final(stream) - if final_result is not None: - if yielded: - raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover - yielded = True - - messages = graph_ctx.state.message_history.copy() - - async def on_complete() -> None: - """Called when the stream has completed. - - The model response will have been added to messages by now - by `StreamedRunResult._marked_completed`. 
- """ - last_message = messages[-1] - assert isinstance(last_message, _messages.ModelResponse) - tool_calls = [ - part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) - ] - - parts: list[_messages.ModelRequestPart] = [] - async for _event in _agent_graph.process_function_tools( - graph_ctx.deps.tool_manager, - tool_calls, - final_result, - graph_ctx, - parts, - ): - pass - if parts: - messages.append(_messages.ModelRequest(parts)) - - yield StreamedRunResult( - messages, - graph_ctx.deps.new_message_index, - stream, - on_complete, - ) - break - next_node = await agent_run.next(node) - if not isinstance(next_node, _agent_graph.AgentNode): - raise exceptions.AgentRunError( # pragma: no cover - 'Should have produced a StreamedRunResult before getting here' - ) - node = cast(_agent_graph.AgentNode[Any, Any], next_node) - - if not yielded: - raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover - - @contextmanager - def override( - self, - *, - deps: AgentDepsT | _utils.Unset = _utils.UNSET, - model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, - ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, or toolsets. - - This is particularly useful when testing. - You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). - - Args: - deps: The dependencies to use instead of the dependencies passed to the agent run. - model: The model to use instead of the model passed to the agent run. - toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. - """ - if _utils.is_set(deps): - deps_token = self._override_deps.set(_utils.Some(deps)) - else: - deps_token = None - - if _utils.is_set(model): - model_token = self._override_model.set(_utils.Some(models.infer_model(model))) - else: - model_token = None - - if _utils.is_set(toolsets): - toolsets_token = self._override_toolsets.set(_utils.Some(toolsets)) - else: - toolsets_token = None - - try: - yield - finally: - if deps_token is not None: - self._override_deps.reset(deps_token) - if model_token is not None: - self._override_model.reset(model_token) - if toolsets_token is not None: - self._override_toolsets.reset(toolsets_token) - - @overload - def instructions( - self, func: Callable[[RunContext[AgentDepsT]], str], / - ) -> Callable[[RunContext[AgentDepsT]], str]: ... - - @overload - def instructions( - self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], / - ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ... - - @overload - def instructions(self, func: Callable[[], str], /) -> Callable[[], str]: ... - - @overload - def instructions(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ... - - @overload - def instructions( - self, / - ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ... - - def instructions( - self, - func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None, - /, - ) -> ( - Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]] - | _system_prompt.SystemPromptFunc[AgentDepsT] - ): - """Decorator to register an instructions function. - - Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument. - Can decorate a sync or async functions. 
- - The decorator can be used bare (`agent.instructions`). - - Overloads for every possible signature of `instructions` are included so the decorator doesn't obscure - the type of the function. - - Example: - ```python - from pydantic_ai import Agent, RunContext - - agent = Agent('test', deps_type=str) - - @agent.instructions - def simple_instructions() -> str: - return 'foobar' - - @agent.instructions - async def async_instructions(ctx: RunContext[str]) -> str: - return f'{ctx.deps} is the best' - ``` - """ - if func is None: - - def decorator( - func_: _system_prompt.SystemPromptFunc[AgentDepsT], - ) -> _system_prompt.SystemPromptFunc[AgentDepsT]: - self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_)) - return func_ - - return decorator - else: - self._instructions_functions.append(_system_prompt.SystemPromptRunner(func)) - return func - - @overload - def system_prompt( - self, func: Callable[[RunContext[AgentDepsT]], str], / - ) -> Callable[[RunContext[AgentDepsT]], str]: ... - - @overload - def system_prompt( - self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], / - ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ... - - @overload - def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ... - - @overload - def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ... - - @overload - def system_prompt( - self, /, *, dynamic: bool = False - ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ... - - def system_prompt( - self, - func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None, - /, - *, - dynamic: bool = False, - ) -> ( - Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]] - | _system_prompt.SystemPromptFunc[AgentDepsT] - ): - """Decorator to register a system prompt function. - - Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument. - Can decorate a sync or async functions. - - The decorator can be used either bare (`agent.system_prompt`) or as a function call - (`agent.system_prompt(...)`), see the examples below. - - Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure - the type of the function, see `tests/typed_agent.py` for tests. 
- - Args: - func: The function to decorate - dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided, - see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref] - - Example: - ```python - from pydantic_ai import Agent, RunContext - - agent = Agent('test', deps_type=str) - - @agent.system_prompt - def simple_system_prompt() -> str: - return 'foobar' - - @agent.system_prompt(dynamic=True) - async def async_system_prompt(ctx: RunContext[str]) -> str: - return f'{ctx.deps} is the best' - ``` - """ - if func is None: - - def decorator( - func_: _system_prompt.SystemPromptFunc[AgentDepsT], - ) -> _system_prompt.SystemPromptFunc[AgentDepsT]: - runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic) - self._system_prompt_functions.append(runner) - if dynamic: # pragma: lax no cover - self._system_prompt_dynamic_functions[func_.__qualname__] = runner - return func_ - - return decorator - else: - assert not dynamic, "dynamic can't be True in this case" - self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic)) - return func - - @overload - def output_validator( - self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], / - ) -> Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT]: ... - - @overload - def output_validator( - self, func: Callable[[RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT]], / - ) -> Callable[[RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT]]: ... - - @overload - def output_validator( - self, func: Callable[[OutputDataT], OutputDataT], / - ) -> Callable[[OutputDataT], OutputDataT]: ... - - @overload - def output_validator( - self, func: Callable[[OutputDataT], Awaitable[OutputDataT]], / - ) -> Callable[[OutputDataT], Awaitable[OutputDataT]]: ... - - def output_validator( - self, func: _output.OutputValidatorFunc[AgentDepsT, OutputDataT], / - ) -> _output.OutputValidatorFunc[AgentDepsT, OutputDataT]: - """Decorator to register an output validator function. - - Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. - Can decorate a sync or async functions. - - Overloads for every possible signature of `output_validator` are included so the decorator doesn't obscure - the type of the function, see `tests/typed_agent.py` for tests. - - Example: - ```python - from pydantic_ai import Agent, ModelRetry, RunContext - - agent = Agent('test', deps_type=str) - - @agent.output_validator - def output_validator_simple(data: str) -> str: - if 'wrong' in data: - raise ModelRetry('wrong response') - return data - - @agent.output_validator - async def output_validator_deps(ctx: RunContext[str], data: str) -> str: - if ctx.deps in data: - raise ModelRetry('wrong response') - return data - - result = agent.run_sync('foobar', deps='spam') - print(result.output) - #> success (no tool calls) - ``` - """ - self._output_validators.append(_output.OutputValidator[AgentDepsT, Any](func)) - return func - - @deprecated('`result_validator` is deprecated, use `output_validator` instead.') - def result_validator(self, func: Any, /) -> Any: - warnings.warn( - '`result_validator` is deprecated, use `output_validator` instead.', DeprecationWarning, stacklevel=2 - ) - return self.output_validator(func) # type: ignore - - @overload - def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ... 
- - @overload - def tool( - self, - /, - *, - name: str | None = None, - retries: int | None = None, - prepare: ToolPrepareFunc[AgentDepsT] | None = None, - docstring_format: DocstringFormat = 'auto', - require_parameter_descriptions: bool = False, - schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, - strict: bool | None = None, - ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ... - - def tool( - self, - func: ToolFuncContext[AgentDepsT, ToolParams] | None = None, - /, - *, - name: str | None = None, - retries: int | None = None, - prepare: ToolPrepareFunc[AgentDepsT] | None = None, - docstring_format: DocstringFormat = 'auto', - require_parameter_descriptions: bool = False, - schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, - strict: bool | None = None, - ) -> Any: - """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. - - Can decorate a sync or async functions. - - The docstring is inspected to extract both the tool description and description of each parameter, - [learn more](../tools.md#function-tools-and-schema). - - We can't add overloads for every possible signature of tool, since the return type is a recursive union - so the signature of functions decorated with `@agent.tool` is obscured. - - Example: - ```python - from pydantic_ai import Agent, RunContext - - agent = Agent('test', deps_type=int) - - @agent.tool - def foobar(ctx: RunContext[int], x: int) -> int: - return ctx.deps + x - - @agent.tool(retries=2) - async def spam(ctx: RunContext[str], y: float) -> float: - return ctx.deps + y - - result = agent.run_sync('foobar', deps=1) - print(result.output) - #> {"foobar":1,"spam":1.0} - ``` - - Args: - func: The tool function to register. - name: The name of the tool, defaults to the function name. - retries: The number of retries to allow for this tool, defaults to the agent's default retries, - which defaults to 1. - prepare: custom method to prepare the tool definition for each step, return `None` to omit this - tool from a given step. This is useful if you want to customise a tool at call time, - or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. - docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. - Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. - require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. - schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. - strict: Whether to enforce JSON schema compliance (only affects OpenAI). - See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. - """ - - def tool_decorator( - func_: ToolFuncContext[AgentDepsT, ToolParams], - ) -> ToolFuncContext[AgentDepsT, ToolParams]: - # noinspection PyTypeChecker - self._function_toolset.add_function( - func_, - True, - name, - retries, - prepare, - docstring_format, - require_parameter_descriptions, - schema_generator, - strict, - ) - return func_ - - return tool_decorator if func is None else tool_decorator(func) - - @overload - def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ... 
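The `prepare` argument described above lets a callable adjust or drop the tool on a per-step basis; a small sketch, assuming the `(ctx, tool_def)` signature of `ToolPrepareFunc` and `ToolDefinition` from `pydantic_ai.tools`:

```python
from pydantic_ai import Agent, RunContext
from pydantic_ai.tools import ToolDefinition

agent = Agent('test', deps_type=bool)


async def only_when_enabled(
    ctx: RunContext[bool], tool_def: ToolDefinition
) -> ToolDefinition | None:
    # Returning None omits the tool from this step; returning the definition keeps it.
    return tool_def if ctx.deps else None


@agent.tool(prepare=only_when_enabled)
def lookup(ctx: RunContext[bool], key: str) -> str:
    return f'value for {key}'


result = agent.run_sync('look something up', deps=True)
```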
- - @overload - def tool_plain( - self, - /, - *, - name: str | None = None, - retries: int | None = None, - prepare: ToolPrepareFunc[AgentDepsT] | None = None, - docstring_format: DocstringFormat = 'auto', - require_parameter_descriptions: bool = False, - schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, - strict: bool | None = None, - ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ... - - def tool_plain( - self, - func: ToolFuncPlain[ToolParams] | None = None, - /, - *, - name: str | None = None, - retries: int | None = None, - prepare: ToolPrepareFunc[AgentDepsT] | None = None, - docstring_format: DocstringFormat = 'auto', - require_parameter_descriptions: bool = False, - schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, - strict: bool | None = None, - ) -> Any: - """Decorator to register a tool function which DOES NOT take `RunContext` as an argument. - - Can decorate a sync or async functions. - - The docstring is inspected to extract both the tool description and description of each parameter, - [learn more](../tools.md#function-tools-and-schema). - - We can't add overloads for every possible signature of tool, since the return type is a recursive union - so the signature of functions decorated with `@agent.tool` is obscured. - - Example: - ```python - from pydantic_ai import Agent, RunContext - - agent = Agent('test') - - @agent.tool - def foobar(ctx: RunContext[int]) -> int: - return 123 - - @agent.tool(retries=2) - async def spam(ctx: RunContext[str]) -> float: - return 3.14 - - result = agent.run_sync('foobar', deps=1) - print(result.output) - #> {"foobar":123,"spam":3.14} - ``` - - Args: - func: The tool function to register. - name: The name of the tool, defaults to the function name. - retries: The number of retries to allow for this tool, defaults to the agent's default retries, - which defaults to 1. - prepare: custom method to prepare the tool definition for each step, return `None` to omit this - tool from a given step. This is useful if you want to customise a tool at call time, - or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. - docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. - Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. - require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. - schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. - strict: Whether to enforce JSON schema compliance (only affects OpenAI). - See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. - """ - - def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: - # noinspection PyTypeChecker - self._function_toolset.add_function( - func_, - False, - name, - retries, - prepare, - docstring_format, - require_parameter_descriptions, - schema_generator, - strict, - ) - return func_ - - return tool_decorator if func is None else tool_decorator(func) - - def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model: - """Create a model configured for this agent. - - Args: - model: model to use for this run, required if `model` was not set when creating the agent. 
- - Returns: - The model used - """ - model_: models.Model - if some_model := self._override_model.get(): - # we don't want `override()` to cover up errors from the model not being defined, hence this check - if model is None and self.model is None: - raise exceptions.UserError( - '`model` must either be set on the agent or included when calling it. ' - '(Even when `override(model=...)` is customizing the model that will actually be called)' - ) - model_ = some_model.value - elif model is not None: - model_ = models.infer_model(model) - elif self.model is not None: - # noinspection PyTypeChecker - model_ = self.model = models.infer_model(self.model) - else: - raise exceptions.UserError('`model` must either be set on the agent or included when calling it.') - - instrument = self.instrument - if instrument is None: - instrument = self._instrument_default - - return instrument_model(model_, instrument) - - def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: - """Get deps for a run. - - If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call. - - We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope. - """ - if some_deps := self._override_deps.get(): - return some_deps.value - else: - return deps - - def _get_toolset( - self, - output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET, - additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractToolset[AgentDepsT]: - """Get the complete toolset. - - Args: - output_toolset: The output toolset to use instead of the one built at agent construction time. - additional_toolsets: Additional toolsets to add. - """ - if some_user_toolsets := self._override_toolsets.get(): - user_toolsets = some_user_toolsets.value - elif additional_toolsets is not None: - user_toolsets = [*self._user_toolsets, *additional_toolsets] - else: - user_toolsets = self._user_toolsets - - all_toolsets = [self._function_toolset, *user_toolsets] - - if self._prepare_tools: - all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)] - - output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset - if output_toolset is not None: - if self._prepare_output_tools: - output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) - all_toolsets = [output_toolset, *all_toolsets] - - return CombinedToolset(all_toolsets) - - def _infer_name(self, function_frame: FrameType | None) -> None: - """Infer the agent name from the call frame. - - Usage should be `self._infer_name(inspect.currentframe())`. 
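Because `_get_model` resolves the model at call time, an agent constructed without a model is still usable as long as one is supplied per run; otherwise the `UserError` above is raised. A sketch:

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import UserError

agent = Agent()  # no model set at construction time

try:
    agent.run_sync('hello')
except UserError:
    pass  # '`model` must either be set on the agent or included when calling it.'

result = agent.run_sync('hello', model='openai:gpt-4o')  # resolved per run instead
```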
- """ - assert self.name is None, 'Name already set' - if function_frame is not None: # pragma: no branch - if parent_frame := function_frame.f_back: # pragma: no branch - for name, item in parent_frame.f_locals.items(): - if item is self: - self.name = name - return - if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch - # if we couldn't find the agent in locals and globals are a different dict, try globals - for name, item in parent_frame.f_globals.items(): - if item is self: - self.name = name - return - - @property - @deprecated( - 'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None - ) - def last_run_messages(self) -> list[_messages.ModelMessage]: - raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') - - def _prepare_output_schema( - self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile - ) -> _output.OutputSchema[RunOutputDataT]: - if output_type is not None: - if self._output_validators: - raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators') - schema = _output.OutputSchema[RunOutputDataT].build( - output_type, - name=self._deprecated_result_tool_name, - description=self._deprecated_result_tool_description, - default_mode=model_profile.default_structured_output_mode, - ) - else: - schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode) - - schema.raise_if_unsupported(model_profile) - - return schema # pyright: ignore[reportReturnType] - - @staticmethod - def is_model_request_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: - """Check if the node is a `ModelRequestNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.ModelRequestNode) - - @staticmethod - def is_call_tools_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: - """Check if the node is a `CallToolsNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.CallToolsNode) - - @staticmethod - def is_user_prompt_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: - """Check if the node is a `UserPromptNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.UserPromptNode) - - @staticmethod - def is_end_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[End[result.FinalResult[S]]]: - """Check if the node is a `End`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, End) - - async def __aenter__(self) -> Self: - """Enter the agent context. - - This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used. - - This is a no-op if the agent has already been entered. 
- """ - async with self._enter_lock: - if self._entered_count == 0: - async with AsyncExitStack() as exit_stack: - toolset = self._get_toolset() - await exit_stack.enter_async_context(toolset) - - self._exit_stack = exit_stack.pop_all() - self._entered_count += 1 - return self - - async def __aexit__(self, *args: Any) -> bool | None: - async with self._enter_lock: - self._entered_count -= 1 - if self._entered_count == 0 and self._exit_stack is not None: - await self._exit_stack.aclose() - self._exit_stack = None - - def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None: - """Set the sampling model on all MCP servers registered with the agent. - - If no sampling model is provided, the agent's model will be used. - """ - try: - sampling_model = models.infer_model(model) if model else self._get_model(None) - except exceptions.UserError as e: - raise exceptions.UserError('No sampling model provided and no model set on the agent.') from e - - from .mcp import MCPServer - - def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None: - if isinstance(toolset, MCPServer): - toolset.sampling_model = sampling_model - - self._get_toolset().apply(_set_sampling_model) - - @asynccontextmanager - @deprecated( - '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.' - ) - async def run_mcp_servers( - self, model: models.Model | models.KnownModelName | str | None = None - ) -> AsyncIterator[None]: - """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. - - Deprecated: use [`async with agent`][pydantic_ai.agent.Agent.__aenter__] instead. - If you need to set a sampling model on all MCP servers, use [`agent.set_mcp_sampling_model()`][pydantic_ai.agent.Agent.set_mcp_sampling_model]. - - Returns: a context manager to start and shutdown the servers. - """ - try: - self.set_mcp_sampling_model(model) - except exceptions.UserError: - if model is not None: - raise - - async with self: - yield - - def to_ag_ui( - self, - *, - # Agent.iter parameters - output_type: OutputSpec[OutputDataT] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.RunUsage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - # Starlette - debug: bool = False, - routes: Sequence[BaseRoute] | None = None, - middleware: Sequence[Middleware] | None = None, - exception_handlers: Mapping[Any, ExceptionHandler] | None = None, - on_startup: Sequence[Callable[[], Any]] | None = None, - on_shutdown: Sequence[Callable[[], Any]] | None = None, - lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, - ) -> AGUIApp[AgentDepsT, OutputDataT]: - """Convert the agent to an AG-UI application. - - This allows you to use the agent with a compatible AG-UI frontend. - - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - app = agent.to_ag_ui() - ``` - - The `app` is an ASGI application that can be used with any ASGI server. - - To run the application, you can use the following command: - - ```bash - uvicorn app:app --host 0.0.0.0 --port 8000 - ``` - - See [AG-UI docs](../ag-ui.md) for more information. 
- - Args: - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has - no output validators since output validators would expect an argument that matches the agent's - output type. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. - - debug: Boolean indicating if debug tracebacks should be returned on errors. - routes: A list of routes to serve incoming HTTP and WebSocket requests. - middleware: A list of middleware to run for every request. A starlette application will always - automatically include two middleware classes. `ServerErrorMiddleware` is added as the very - outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. - `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled - exception cases occurring in the routing or endpoints. - exception_handlers: A mapping of either integer status codes, or exception class types onto - callables which handle the exceptions. Exception handler callables should be of the form - `handler(request, exc) -> response` and may be either standard functions, or async functions. - on_startup: A list of callables to run on application startup. Startup handler callables do not - take any arguments, and may be either standard functions, or async functions. - on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do - not take any arguments, and may be either standard functions, or async functions. - lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. - This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or - the other, not both. - - Returns: - An ASGI application for running Pydantic AI agents with AG-UI protocol support. - """ - from .ag_ui import AGUIApp - - return AGUIApp( - agent=self, - # Agent.iter parameters - output_type=output_type, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=infer_name, - toolsets=toolsets, - # Starlette - debug=debug, - routes=routes, - middleware=middleware, - exception_handlers=exception_handlers, - on_startup=on_startup, - on_shutdown=on_shutdown, - lifespan=lifespan, - ) - - def to_a2a( - self, - *, - storage: Storage | None = None, - broker: Broker | None = None, - # Agent card - name: str | None = None, - url: str = 'http://localhost:8000', - version: str = '1.0.0', - description: str | None = None, - provider: AgentProvider | None = None, - skills: list[Skill] | None = None, - # Starlette - debug: bool = False, - routes: Sequence[Route] | None = None, - middleware: Sequence[Middleware] | None = None, - exception_handlers: dict[Any, ExceptionHandler] | None = None, - lifespan: Lifespan[FastA2A] | None = None, - ) -> FastA2A: - """Convert the agent to a FastA2A application. 
- - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - app = agent.to_a2a() - ``` - - The `app` is an ASGI application that can be used with any ASGI server. - - To run the application, you can use the following command: - - ```bash - uvicorn app:app --host 0.0.0.0 --port 8000 - ``` - """ - from ._a2a import agent_to_a2a - - return agent_to_a2a( - self, - storage=storage, - broker=broker, - name=name, - url=url, - version=version, - description=description, - provider=provider, - skills=skills, - debug=debug, - routes=routes, - middleware=middleware, - exception_handlers=exception_handlers, - lifespan=lifespan, - ) - - async def to_cli(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: - """Run the agent in a CLI chat interface. - - Args: - deps: The dependencies to pass to the agent. - prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. - - Example: - ```python {title="agent_to_cli.py" test="skip"} - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') - - async def main(): - await agent.to_cli() - ``` - """ - from rich.console import Console - - from pydantic_ai._cli import run_chat - - await run_chat(stream=True, agent=self, deps=deps, console=Console(), code_theme='monokai', prog_name=prog_name) - - def to_cli_sync(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: - """Run the agent in a CLI chat interface with the non-async interface. - - Args: - deps: The dependencies to pass to the agent. - prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. - - ```python {title="agent_to_cli_sync.py" test="skip"} - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') - agent.to_cli_sync() - agent.to_cli_sync(prog_name='assistant') - ``` - """ - return get_event_loop().run_until_complete(self.to_cli(deps=deps, prog_name=prog_name)) - - -@dataclasses.dataclass(repr=False) -class AgentRun(Generic[AgentDepsT, OutputDataT]): - """A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent]. - - You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`. - - Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an - [`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result] - becomes available. 
- - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - - async def main(): - nodes = [] - # Iterate through the run, recording each node along the way: - async with agent.iter('What is the capital of France?') as agent_run: - async for node in agent_run: - nodes.append(node) - print(nodes) - ''' - [ - UserPromptNode( - user_prompt='What is the capital of France?', - instructions=None, - instructions_functions=[], - system_prompts=(), - system_prompt_functions=[], - system_prompt_dynamic_functions={}, - ), - ModelRequestNode( - request=ModelRequest( - parts=[ - UserPromptPart( - content='What is the capital of France?', - timestamp=datetime.datetime(...), - ) - ] - ) - ), - CallToolsNode( - model_response=ModelResponse( - parts=[TextPart(content='Paris')], - usage=Usage( - requests=1, input_tokens=56, output_tokens=1, total_tokens=57 - ), - model_name='gpt-4o', - timestamp=datetime.datetime(...), - ) - ), - End(data=FinalResult(output='Paris')), - ] - ''' - print(agent_run.result.output) - #> Paris - ``` - - You can also manually drive the iteration using the [`next`][pydantic_ai.agent.AgentRun.next] method for - more granular control. - """ - - _graph_run: GraphRun[ - _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[OutputDataT] - ] - - @overload - def _traceparent(self, *, required: Literal[False]) -> str | None: ... - @overload - def _traceparent(self) -> str: ... - def _traceparent(self, *, required: bool = True) -> str | None: - traceparent = self._graph_run._traceparent(required=False) # type: ignore[reportPrivateUsage] - if traceparent is None and required: # pragma: no cover - raise AttributeError('No span was created for this agent run') - return traceparent - - @property - def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: - """The current context of the agent run.""" - return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]( - self._graph_run.state, self._graph_run.deps - ) - - @property - def next_node( - self, - ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: - """The next node that will be run in the agent graph. - - This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. - """ - next_node = self._graph_run.next_node - if isinstance(next_node, End): - return next_node - if _agent_graph.is_agent_node(next_node): - return next_node - raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover - - @property - def result(self) -> AgentRunResult[OutputDataT] | None: - """The final result of the run if it has ended, otherwise `None`. - - Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated - with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult]. 
- """ - graph_run_result = self._graph_run.result - if graph_run_result is None: - return None - return AgentRunResult( - graph_run_result.output.output, - graph_run_result.output.tool_name, - graph_run_result.state, - self._graph_run.deps.new_message_index, - self._traceparent(required=False), - ) - - def __aiter__( - self, - ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]]: - """Provide async-iteration over the nodes in the agent run.""" - return self - - async def __anext__( - self, - ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: - """Advance to the next node automatically based on the last returned node.""" - next_node = await self._graph_run.__anext__() - if _agent_graph.is_agent_node(next_node): - return next_node - assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' - return next_node - - async def next( - self, - node: _agent_graph.AgentNode[AgentDepsT, OutputDataT], - ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: - """Manually drive the agent run by passing in the node you want to run next. - - This lets you inspect or mutate the node before continuing execution, or skip certain nodes - under dynamic conditions. The agent run should be stopped when you return an [`End`][pydantic_graph.nodes.End] - node. - - Example: - ```python - from pydantic_ai import Agent - from pydantic_graph import End - - agent = Agent('openai:gpt-4o') - - async def main(): - async with agent.iter('What is the capital of France?') as agent_run: - next_node = agent_run.next_node # start with the first node - nodes = [next_node] - while not isinstance(next_node, End): - next_node = await agent_run.next(next_node) - nodes.append(next_node) - # Once `next_node` is an End, we've finished: - print(nodes) - ''' - [ - UserPromptNode( - user_prompt='What is the capital of France?', - instructions=None, - instructions_functions=[], - system_prompts=(), - system_prompt_functions=[], - system_prompt_dynamic_functions={}, - ), - ModelRequestNode( - request=ModelRequest( - parts=[ - UserPromptPart( - content='What is the capital of France?', - timestamp=datetime.datetime(...), - ) - ] - ) - ), - CallToolsNode( - model_response=ModelResponse( - parts=[TextPart(content='Paris')], - usage=Usage( - requests=1, - input_tokens=56, - output_tokens=1, - total_tokens=57, - ), - model_name='gpt-4o', - timestamp=datetime.datetime(...), - ) - ), - End(data=FinalResult(output='Paris')), - ] - ''' - print('Final result:', agent_run.result.output) - #> Final result: Paris - ``` - - Args: - node: The node to run next in the graph. - - Returns: - The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if - the run has completed. - """ - # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it - # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. 
- next_node = await self._graph_run.next(node) - if _agent_graph.is_agent_node(next_node): - return next_node - assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' - return next_node - - def usage(self) -> _usage.RunUsage: - """Get usage statistics for the run so far, including token usage, model requests, and so on.""" - return self._graph_run.state.usage - - def __repr__(self) -> str: # pragma: no cover - result = self._graph_run.result - result_repr = '' if result is None else repr(result.output) - return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>' - - -@dataclasses.dataclass -class AgentRunResult(Generic[OutputDataT]): - """The final result of an agent run.""" - - output: OutputDataT - """The output data from the agent run.""" - - _output_tool_name: str | None = dataclasses.field(repr=False) - _state: _agent_graph.GraphAgentState = dataclasses.field(repr=False) - _new_message_index: int = dataclasses.field(repr=False) - _traceparent_value: str | None = dataclasses.field(repr=False) - - @overload - def _traceparent(self, *, required: Literal[False]) -> str | None: ... - @overload - def _traceparent(self) -> str: ... - def _traceparent(self, *, required: bool = True) -> str | None: - if self._traceparent_value is None and required: # pragma: no cover - raise AttributeError('No span was created for this agent run') - return self._traceparent_value - - @property - @deprecated('`result.data` is deprecated, use `result.output` instead.') - def data(self) -> OutputDataT: - return self.output - - def _set_output_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: - """Set return content for the output tool. - - Useful if you want to continue the conversation and want to set the response to the output tool call. - """ - if not self._output_tool_name: - raise ValueError('Cannot set output tool return content when the return type is `str`.') - - messages = self._state.message_history - last_message = messages[-1] - for idx, part in enumerate(last_message.parts): - if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name: - # Only do deepcopy when we have to modify - copied_messages = list(messages) - copied_last = deepcopy(last_message) - copied_last.parts[idx].content = return_content # type: ignore[misc] - copied_messages[-1] = copied_last - return copied_messages - - raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.') - - @overload - def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... - - @overload - @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') - def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... - - def all_messages( - self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None - ) -> list[_messages.ModelMessage]: - """Return the history of _messages. - - Args: - output_tool_return_content: The return content of the tool call to set in the last message. - This provides a convenient way to modify the content of the output tool call if you want to continue - the conversation and want to set the response to the output tool call. If `None`, the last message will - not be modified. - result_tool_return_content: Deprecated, use `output_tool_return_content` instead. - - Returns: - List of messages. 
- """ - content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) - if content is not None: - return self._set_output_tool_return(content) - else: - return self._state.message_history - - @overload - def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ... - - @overload - @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') - def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ... - - def all_messages_json( - self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None - ) -> bytes: - """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes. - - Args: - output_tool_return_content: The return content of the tool call to set in the last message. - This provides a convenient way to modify the content of the output tool call if you want to continue - the conversation and want to set the response to the output tool call. If `None`, the last message will - not be modified. - result_tool_return_content: Deprecated, use `output_tool_return_content` instead. - - Returns: - JSON bytes representing the messages. - """ - content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) - return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages(output_tool_return_content=content)) - - @overload - def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... - - @overload - @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') - def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... - - def new_messages( - self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None - ) -> list[_messages.ModelMessage]: - """Return new messages associated with this run. - - Messages from older runs are excluded. - - Args: - output_tool_return_content: The return content of the tool call to set in the last message. - This provides a convenient way to modify the content of the output tool call if you want to continue - the conversation and want to set the response to the output tool call. If `None`, the last message will - not be modified. - result_tool_return_content: Deprecated, use `output_tool_return_content` instead. - - Returns: - List of new messages. - """ - content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) - return self.all_messages(output_tool_return_content=content)[self._new_message_index :] - - @overload - def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ... - - @overload - @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') - def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ... - - def new_messages_json( - self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None - ) -> bytes: - """Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes. - - Args: - output_tool_return_content: The return content of the tool call to set in the last message. 
- This provides a convenient way to modify the content of the output tool call if you want to continue - the conversation and want to set the response to the output tool call. If `None`, the last message will - not be modified. - result_tool_return_content: Deprecated, use `output_tool_return_content` instead. - - Returns: - JSON bytes representing the new messages. - """ - content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) - return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages(output_tool_return_content=content)) - - def usage(self) -> _usage.RunUsage: - """Return the usage of the whole run.""" - return self._state.usage diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 5f56bf4db1..a0cfdc9e6b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -425,7 +425,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -441,7 +441,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -457,7 +457,7 @@ async def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: @@ -689,7 +689,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: run_span.end() def _run_span_end_attributes( - self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings + self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings ): return { **usage.opentelemetry_attributes(), From fcdf9df9605d80378b7e71619a0a80dd72deb8f2 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 17:25:33 +0200 Subject: [PATCH 03/71] fix --- pydantic_ai_slim/pydantic_ai/agent/wrapper.py | 8 ++++---- pydantic_ai_slim/pydantic_ai/direct.py | 4 ++-- pydantic_ai_slim/pydantic_ai/messages.py | 6 +++--- pydantic_ai_slim/pydantic_ai/run.py | 8 ++++---- tests/models/test_gemini_vertex.py | 6 +++--- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index d796d9dc91..4b7b1c1455 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -72,7 +72,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> 
AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -88,7 +88,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -104,7 +104,7 @@ async def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: @@ -155,7 +155,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( + usage=RunUsage( requests=1, request_tokens=56, response_tokens=7, total_tokens=63 ), model_name='gpt-4o', diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index ff5a947e39..1bd3c20222 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -16,7 +16,7 @@ from datetime import datetime from types import TracebackType -from pydantic_ai.usage import RunUsage +from pydantic_ai.usage import RequestUsage from pydantic_graph._utils import get_event_loop as _get_event_loop from . import agent, messages, models, settings @@ -366,7 +366,7 @@ def get(self) -> messages.ModelResponse: """Build a ModelResponse from the data received from the stream so far.""" return self._ensure_stream_ready().get() - def usage(self) -> RunUsage: + def usage(self) -> RequestUsage: """Get the usage of the response so far.""" return self._ensure_stream_ready().usage() diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 7c248ae02e..66758b44c2 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -849,6 +849,9 @@ class ModelResponse: provider_name: str | None = None """The name of the LLM provider that generated the response.""" + kind: Literal['response'] = 'response' + """Message type identifier, this is available on all parts as a discriminator.""" + provider_details: dict[str, Any] | None = field(default=None) """Additional provider-specific details in a serializable format. @@ -859,9 +862,6 @@ class ModelResponse: provider_request_id: str | None = None """request ID as specified by the model provider. 
This can be used to track the specific request to the model.""" - kind: Literal['response'] = 'response' - """Message type identifier, this is available on all parts as a discriminator.""" - def price(self) -> genai_types.PriceCalculation: """Calculate the price of the usage, this doesn't use `auto_update` so won't make any network requests.""" assert self.model_name, 'Model name is required to calculate price' diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index f5d5d0ed83..dad446f611 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -66,7 +66,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( + usage=RunUsage( requests=1, request_tokens=56, response_tokens=7, total_tokens=63 ), model_name='gpt-4o', @@ -203,7 +203,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( + usage=RunUsage( requests=1, request_tokens=56, response_tokens=7, @@ -235,7 +235,7 @@ async def main(): assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' return next_node - def usage(self) -> _usage.Usage: + def usage(self) -> _usage.RunUsage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" return self._graph_run.state.usage @@ -352,6 +352,6 @@ def new_messages_json(self, *, output_tool_return_content: str | None = None) -> self.new_messages(output_tool_return_content=output_tool_return_content) ) - def usage(self) -> _usage.Usage: + def usage(self) -> _usage.RunUsage: """Return the usage of the whole run.""" return self._state.usage diff --git a/tests/models/test_gemini_vertex.py b/tests/models/test_gemini_vertex.py index a499df18c5..5760faef7e 100644 --- a/tests/models/test_gemini_vertex.py +++ b/tests/models/test_gemini_vertex.py @@ -20,7 +20,7 @@ VideoUrl, ) from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings -from pydantic_ai.usage import RunUsage +from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -145,7 +145,7 @@ async def test_url_input( ), ModelResponse( parts=[TextPart(content=Is(expected_output))], - usage=IsInstance(RunUsage), + usage=IsInstance(RequestUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), provider_details={'finish_reason': 'STOP'}, @@ -182,7 +182,7 @@ async def test_url_input_force_download(allow_model_requests: None) -> None: # ), ModelResponse( parts=[TextPart(content=Is(output))], - usage=IsInstance(RunUsage), + usage=IsInstance(RequestUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), provider_details={'finish_reason': 'STOP'}, From c34ebdb03b7003eb9380d199a38ba5cf26564307 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 17:26:20 +0200 Subject: [PATCH 04/71] fix --- pydantic_ai_slim/pydantic_ai/usage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 2b69cb6022..1f1e014bd1 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -178,7 +178,7 @@ def incr(self, incr_usage: RunUsage | RequestUsage) -> None: self.requests += incr_usage.requests return _incr_usage_tokens(self, incr_usage) - def __add__(self, other: RunUsage) -> RunUsage: + def __add__(self, other: RunUsage | RequestUsage) -> RunUsage: 
"""Add two RunUsages together. This is provided so it's trivial to sum usage information from multiple runs. From 1f273aea757a69eececb9b75d3e996fcafd76b5a Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 17:30:50 +0200 Subject: [PATCH 05/71] fix --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index e7b10c268c..19b7e666a4 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -342,7 +342,7 @@ async def _make_request( model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) - ctx.state.usage.incr(_usage.RunUsage()) + ctx.state.usage.requests += 1 return self._finish_handling(ctx, model_response) From 0e7415c7681a9f51c5b0149c46f60e3e79368258 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 17:37:12 +0200 Subject: [PATCH 06/71] fix --- docs/agents.md | 23 +++++-------------- docs/direct.md | 4 ++-- docs/message-history.md | 12 +++++----- docs/models/openai.md | 4 ++-- docs/multi-agent-applications.md | 12 +++++----- docs/output.md | 2 +- docs/testing.md | 12 +++------- docs/tools.md | 8 +++---- .../pydantic_ai/agent/__init__.py | 4 +--- .../pydantic_ai/agent/abstract.py | 4 +--- pydantic_ai_slim/pydantic_ai/agent/wrapper.py | 4 +--- pydantic_ai_slim/pydantic_ai/direct.py | 4 ++-- pydantic_ai_slim/pydantic_ai/run.py | 11 ++------- tests/models/test_anthropic.py | 4 ++++ tests/models/test_bedrock.py | 4 ++-- tests/models/test_cohere.py | 6 +++-- tests/models/test_gemini.py | 6 ++--- tests/models/test_google.py | 2 ++ tests/models/test_groq.py | 4 ++-- tests/models/test_huggingface.py | 2 +- tests/models/test_model_test.py | 2 +- tests/models/test_openai.py | 10 +++++--- tests/test_agent.py | 6 ++--- tests/test_usage_limits.py | 10 ++++---- 24 files changed, 70 insertions(+), 90 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 8a7e245a1e..e866de1705 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -302,9 +302,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, input_tokens=56, output_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -367,12 +365,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, - input_tokens=56, - output_tokens=7, - total_tokens=63, - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -570,7 +563,7 @@ result_sync = agent.run_sync( print(result_sync.output) #> Rome print(result_sync.usage()) -#> Usage(requests=1, input_tokens=62, output_tokens=1, total_tokens=63) +#> RunUsage(requests=1, input_tokens=62, output_tokens=1) try: result_sync = agent.run_sync( @@ -579,7 +572,7 @@ try: ) except UsageLimitExceeded as e: print(e) - #> Exceeded the response_tokens_limit of 10 (output_tokens=32) + #> Exceeded the output_tokens_limit of 10 (output_tokens=32) ``` Restricting the number of requests can be useful in preventing infinite loops or excessive tool calling: @@ -1018,9 +1011,7 @@ with capture_run_messages() as 
messages: # (2)! tool_call_id='pyd_ai_tool_call_id', ) ], - usage=Usage( - requests=1, input_tokens=62, output_tokens=4, total_tokens=66 - ), + usage=RequestUsage(input_tokens=62, output_tokens=4), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -1042,9 +1033,7 @@ with capture_run_messages() as messages: # (2)! tool_call_id='pyd_ai_tool_call_id', ) ], - usage=Usage( - requests=1, input_tokens=72, output_tokens=8, total_tokens=80 - ), + usage=RequestUsage(input_tokens=72, output_tokens=8), model_name='gpt-4o', timestamp=datetime.datetime(...), ), diff --git a/docs/direct.md b/docs/direct.md index f62bec470f..c65e4817bf 100644 --- a/docs/direct.md +++ b/docs/direct.md @@ -28,7 +28,7 @@ model_response = model_request_sync( print(model_response.parts[0].content) #> The capital of France is Paris. print(model_response.usage) -#> Usage(requests=1, input_tokens=56, output_tokens=7, total_tokens=63) +#> RequestUsage(input_tokens=56, output_tokens=7) ``` _(This example is complete, it can be run "as is")_ @@ -83,7 +83,7 @@ async def main(): tool_call_id='pyd_ai_2e0e396768a14fe482df90a29a78dc7b', ) ], - usage=Usage(requests=1, input_tokens=55, output_tokens=7, total_tokens=62), + usage=RequestUsage(input_tokens=55, output_tokens=7), model_name='gpt-4.1-nano', timestamp=datetime.datetime(...), ) diff --git a/docs/message-history.md b/docs/message-history.md index ee452e52e7..2014612645 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -58,7 +58,7 @@ print(result.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(requests=1, input_tokens=60, output_tokens=12, total_tokens=72), + usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -126,7 +126,7 @@ async def main(): content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(input_tokens=50, output_tokens=12, total_tokens=62), + usage=RequestUsage(input_tokens=50, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -180,7 +180,7 @@ print(result2.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(requests=1, input_tokens=60, output_tokens=12, total_tokens=72), + usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -198,7 +198,7 @@ print(result2.all_messages()) content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' ) ], - usage=Usage(requests=1, input_tokens=61, output_tokens=26, total_tokens=87), + usage=RequestUsage(input_tokens=61, output_tokens=26), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -299,7 +299,7 @@ print(result2.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(requests=1, input_tokens=60, output_tokens=12, total_tokens=72), + usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -317,7 +317,7 @@ print(result2.all_messages()) content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' 
) ], - usage=Usage(requests=1, input_tokens=61, output_tokens=26, total_tokens=87), + usage=RequestUsage(input_tokens=61, output_tokens=26), model_name='gemini-1.5-pro', timestamp=datetime.datetime(...), ), diff --git a/docs/models/openai.md b/docs/models/openai.md index e2a7d84142..af12c2e256 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -272,7 +272,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, input_tokens=57, output_tokens=8, total_tokens=65) +#> RunUsage(requests=1, input_tokens=57, output_tokens=8) ``` #### Example using a remote server @@ -301,7 +301,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, input_tokens=57, output_tokens=8, total_tokens=65) +#> RunUsage(requests=1, input_tokens=57, output_tokens=8) ``` 1. The name of the model running on the remote server diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 348cb8da0c..62a66164ce 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -53,7 +53,7 @@ result = joke_selection_agent.run_sync( print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) -#> Usage(requests=3, input_tokens=204, output_tokens=24, total_tokens=228) +#> RunUsage(requests=3, input_tokens=204, output_tokens=24) ``` 1. The "parent" or controlling agent. @@ -144,7 +144,7 @@ async def main(): print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) # (6)! - #> Usage(requests=4, input_tokens=309, output_tokens=32, total_tokens=341) + #> RunUsage(requests=4, input_tokens=309, output_tokens=32) ``` 1. Define a dataclass to hold the client and API key dependencies. @@ -188,7 +188,7 @@ from rich.prompt import Prompt from pydantic_ai import Agent, RunContext from pydantic_ai.messages import ModelMessage -from pydantic_ai.usage import Usage, UsageLimits +from pydantic_ai.usage import RunUsage, UsageLimits class FlightDetails(BaseModel): @@ -221,7 +221,7 @@ async def flight_search( usage_limits = UsageLimits(request_limit=15) # (3)! -async def find_flight(usage: Usage) -> Union[FlightDetails, None]: # (4)! +async def find_flight(usage: RunUsage) -> Union[FlightDetails, None]: # (4)! message_history: Union[list[ModelMessage], None] = None for _ in range(3): prompt = Prompt.ask( @@ -259,7 +259,7 @@ seat_preference_agent = Agent[None, Union[SeatPreference, Failed]]( # (5)! ) -async def find_seat(usage: Usage) -> SeatPreference: # (6)! +async def find_seat(usage: RunUsage) -> SeatPreference: # (6)! message_history: Union[list[ModelMessage], None] = None while True: answer = Prompt.ask('What seat would you like?') @@ -278,7 +278,7 @@ async def find_seat(usage: Usage) -> SeatPreference: # (6)! async def main(): # (7)! 
- usage: Usage = Usage() + usage: RunUsage = RunUsage() opt_flight_details = await find_flight(usage) if opt_flight_details is not None: diff --git a/docs/output.md b/docs/output.md index bf573c43ec..80dcda2453 100644 --- a/docs/output.md +++ b/docs/output.md @@ -24,7 +24,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, input_tokens=57, output_tokens=8, total_tokens=65) +#> RunUsage(requests=1, input_tokens=57, output_tokens=8) ``` _(This example is complete, it can be run "as is")_ diff --git a/docs/testing.md b/docs/testing.md index 68330671b8..8d7cd8313d 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -97,7 +97,7 @@ from pydantic_ai.messages import ( UserPromptPart, ModelRequest, ) -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from fake_database import DatabaseConn from weather_app import run_weather_forecast, weather_agent @@ -141,12 +141,9 @@ async def test_forecast(): tool_call_id=IsStr(), ) ], - usage=Usage( - requests=1, + usage=RequestUsage( input_tokens=71, output_tokens=7, - total_tokens=78, - details=None, ), model_name='test', timestamp=IsNow(tz=timezone.utc), @@ -167,12 +164,9 @@ async def test_forecast(): content='{"weather_forecast":"Sunny with a chance of rain"}', ) ], - usage=Usage( - requests=1, + usage=RequestUsage( input_tokens=77, output_tokens=16, - total_tokens=93, - details=None, ), model_name='test', timestamp=IsNow(tz=timezone.utc), diff --git a/docs/tools.md b/docs/tools.md index 506c426905..1cd36ee4ef 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -95,7 +95,7 @@ print(dice_result.all_messages()) tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id' ) ], - usage=Usage(requests=1, input_tokens=90, output_tokens=2, total_tokens=92), + usage=RequestUsage(input_tokens=90, output_tokens=2), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), @@ -115,7 +115,7 @@ print(dice_result.all_messages()) tool_name='get_player_name', args={}, tool_call_id='pyd_ai_tool_call_id' ) ], - usage=Usage(requests=1, input_tokens=91, output_tokens=4, total_tokens=95), + usage=RequestUsage(input_tokens=91, output_tokens=4), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), @@ -135,9 +135,7 @@ print(dice_result.all_messages()) content="Congratulations Anne, you guessed correctly! You're a winner!" 
) ], - usage=Usage( - requests=1, input_tokens=92, output_tokens=12, total_tokens=104 - ), + usage=RequestUsage(input_tokens=92, output_tokens=12), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index a0cfdc9e6b..a2c7d5450e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -508,9 +508,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index 341582285e..be357b068d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -612,9 +612,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=RunUsage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index 4b7b1c1455..b95eb0d798 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -155,9 +155,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=RunUsage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 1bd3c20222..5f315fe144 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -57,7 +57,7 @@ async def main(): ''' ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage(requests=1, input_tokens=56, output_tokens=1, total_tokens=57), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) @@ -110,7 +110,7 @@ def model_request_sync( ''' ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage(requests=1, input_tokens=56, output_tokens=1, total_tokens=57), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index dad446f611..e5908ce535 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -66,9 +66,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=RunUsage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -203,12 +201,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital 
of France is Paris.')], - usage=RunUsage( - requests=1, - request_tokens=56, - response_tokens=7, - total_tokens=63, - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 91443db479..21597a5047 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -175,6 +175,7 @@ async def test_sync_request_text_response(allow_model_requests: None): assert result.output == 'world' assert result.usage() == snapshot( RunUsage( + requests=1, input_tokens=5, output_tokens=10, details={'input_tokens': 5, 'output_tokens': 10}, @@ -187,6 +188,7 @@ async def test_sync_request_text_response(allow_model_requests: None): assert result.output == 'world' assert result.usage() == snapshot( RunUsage( + requests=1, input_tokens=5, output_tokens=10, details={'input_tokens': 5, 'output_tokens': 10}, @@ -232,6 +234,7 @@ async def test_async_request_prompt_caching(allow_model_requests: None): assert result.output == 'world' assert result.usage() == snapshot( RunUsage( + requests=1, input_tokens=13, output_tokens=5, details={ @@ -257,6 +260,7 @@ async def test_async_request_text_response(allow_model_requests: None): assert result.output == 'world' assert result.usage() == snapshot( RunUsage( + requests=1, input_tokens=3, output_tokens=5, details={'input_tokens': 3, 'output_tokens': 5}, diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index ef6d8e2f1a..e3148587d2 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -58,7 +58,7 @@ async def test_bedrock_model(allow_model_requests: None, bedrock_provider: Bedro assert result.output == snapshot( "Hello! How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help." 
) - assert result.usage() == snapshot(RunUsage(input_tokens=7, output_tokens=30)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=7, output_tokens=30)) assert result.all_messages() == snapshot( [ ModelRequest( @@ -111,7 +111,7 @@ async def temperature(city: str, date: datetime.date) -> str: result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response) assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'}) - assert result.usage() == snapshot(RunUsage(input_tokens=1236, output_tokens=298)) + assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=1236, output_tokens=298)) assert result.all_messages() == snapshot( [ ModelRequest( diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 4238afcdb4..3437a3d0f0 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -104,14 +104,14 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(RunUsage()) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(RunUsage()) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -149,6 +149,7 @@ async def test_request_simple_usage(allow_model_requests: None): assert result.output == 'world' assert result.usage() == snapshot( RunUsage( + requests=1, input_tokens=1, output_tokens=1, details={ @@ -321,6 +322,7 @@ async def get_location(loc_name: str) -> str: ) assert result.usage() == snapshot( RunUsage( + requests=3, input_tokens=5, output_tokens=3, details={'input_tokens': 4, 'output_tokens': 2}, diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ae2570ab75..f0c3a337f5 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -627,7 +627,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ), ] ) - assert result.usage() == snapshot(RunUsage(input_tokens=1, output_tokens=2)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) result = await agent.run('Hello', message_history=result.new_messages()) assert result.output == 'Hello world' @@ -781,7 +781,7 @@ async def get_location(loc_name: str) -> str: ), ] ) - assert result.usage() == snapshot(RunUsage(input_tokens=3, output_tokens=6)) + assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6)) async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): @@ -1651,7 +1651,7 @@ async def test_response_with_thought_part(get_gemini_client: GetGeminiClient): result = await agent.run('Test with thought') assert result.output == 'Hello from thought test' - assert result.usage() == snapshot(RunUsage(input_tokens=1, output_tokens=2)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) @pytest.mark.vcr() diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 7140a866e2..83bf47f1b0 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -73,6 +73,7 @@ 
async def test_google_model(allow_model_requests: None, google_provider: GoogleP assert result.output == snapshot('Hello there! How can I help you today?\n') assert result.usage() == snapshot( RunUsage( + requests=1, input_tokens=7, output_tokens=11, details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11}, @@ -131,6 +132,7 @@ async def temperature(city: str, date: datetime.date) -> str: assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'}) assert result.usage() == snapshot( RunUsage( + requests=2, input_tokens=224, output_tokens=35, details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35}, diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 2d1b5dce4a..aa6b76b875 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -142,14 +142,14 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(RunUsage()) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(RunUsage()) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 379521cb39..1c83434000 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -189,7 +189,7 @@ async def test_request_simple_usage(allow_model_requests: None, huggingface_api_ result.output == "Hello! It's great to meet you. How can I assist you today? Whether you have any questions, need some advice, or just want to chat, feel free to let me know!" 
) - assert result.usage() == snapshot(RunUsage(input_tokens=30, output_tokens=40)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=30, output_tokens=40)) async def test_request_structured_response( diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index bfcdcf320d..4daf179f05 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -339,4 +339,4 @@ def test_different_content_input(content: AudioUrl | VideoUrl | ImageUrl | Binar agent = Agent() result = agent.run_sync(['x', content], model=TestModel(custom_output_text='custom')) assert result.output == snapshot('custom') - assert result.usage() == snapshot(RunUsage(input_tokens=51, output_tokens=1)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=1)) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index cb9e9b69ff..cb51be0dd1 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -174,14 +174,14 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(RunUsage()) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(RunUsage()) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -233,7 +233,10 @@ async def test_request_simple_usage(allow_model_requests: None): assert result.output == 'world' assert result.usage() == snapshot( RunUsage( - input_tokens=2, details={'completion_tokens': 1, 'prompt_tokens': 2, 'total_tokens': 3}, output_tokens=1 + requests=1, + input_tokens=2, + details={'completion_tokens': 1, 'prompt_tokens': 2, 'total_tokens': 3}, + output_tokens=1, ) ) @@ -416,6 +419,7 @@ async def get_location(loc_name: str) -> str: ) assert result.usage() == snapshot( RunUsage( + requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3, diff --git a/tests/test_agent.py b/tests/test_agent.py index c111cc9f17..f009f8ff48 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1771,7 +1771,7 @@ async def ret_a(x: str) -> str: assert result2._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] assert result2.output == snapshot('{"ret_a":"a-apple"}') assert result2._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] - assert result2.usage() == snapshot(RunUsage(input_tokens=55, output_tokens=13, details=None)) + assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13, details=None)) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ @@ -1828,7 +1828,7 @@ async def ret_a(x: str) -> str: assert result3._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] assert result3.output == snapshot('{"ret_a":"a-apple"}') assert result3._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] - assert result3.usage() == snapshot(RunUsage(input_tokens=55, output_tokens=13, details=None)) + assert result3.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13, details=None)) def 
test_run_with_history_new_structured(): @@ -1955,7 +1955,7 @@ async def ret_a(x: str) -> str: assert result2.output == snapshot(Response(a=0)) assert result2._new_message_index == snapshot(5) # pyright: ignore[reportPrivateUsage] assert result2._output_tool_name == snapshot('final_result') # pyright: ignore[reportPrivateUsage] - assert result2.usage() == snapshot(RunUsage(input_tokens=59, output_tokens=13, details=None)) + assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=59, output_tokens=13, details=None)) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 5f04bedf7a..58808750b9 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -143,13 +143,13 @@ async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int: delegate_result = await delegate_agent.run(sentence) delegate_usage = delegate_result.usage() run_1_usages.append(delegate_usage) - assert delegate_usage == snapshot(RunUsage(input_tokens=51, output_tokens=4)) + assert delegate_usage == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=4)) return delegate_result.output result1 = await controller_agent1.run('foobar') assert result1.output == snapshot('{"delegate_to_other_agent1":0}') run_1_usages.append(result1.usage()) - assert result1.usage() == snapshot(RunUsage(input_tokens=103, output_tokens=13)) + assert result1.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=13)) controller_agent2 = Agent(TestModel()) @@ -157,12 +157,12 @@ async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int: async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int: delegate_result = await delegate_agent.run(sentence, usage=ctx.usage) delegate_usage = delegate_result.usage() - assert delegate_usage == snapshot(RunUsage(input_tokens=102, output_tokens=9)) + assert delegate_usage == snapshot(RunUsage(requests=2, input_tokens=102, output_tokens=9)) return delegate_result.output result2 = await controller_agent2.run('foobar') assert result2.output == snapshot('{"delegate_to_other_agent2":0}') - assert result2.usage() == snapshot(RunUsage(input_tokens=154, output_tokens=17)) + assert result2.usage() == snapshot(RunUsage(requests=3, input_tokens=154, output_tokens=17)) # confirm the usage from result2 is the sum of the usage from result1 assert result2.usage() == functools.reduce(operator.add, run_1_usages) @@ -189,4 +189,4 @@ def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int: result = await controller_agent.run('foobar') assert result.output == snapshot('{"delegate_to_other_agent":0}') - assert result.usage() == snapshot(RunUsage(requests=5, input_tokens=105, output_tokens=16)) + assert result.usage() == snapshot(RunUsage(requests=7, input_tokens=105, output_tokens=16)) From 7a6d333f3c2ebf2f5a6e0840d78c2ac70ef4496e Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 18:03:19 +0200 Subject: [PATCH 07/71] fix --- tests/evals/test_dataset.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py index 8983ec7d62..0e96329489 100644 --- a/tests/evals/test_dataset.py +++ b/tests/evals/test_dataset.py @@ -1177,13 +1177,13 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput: for span in spans: span['attributes'].pop('code.filepath', None) 
span['attributes'].pop('code.function', None) + span['attributes'].pop('code.lineno', None) assert [(span['name'], span['attributes']) for span in spans] == snapshot( [ ( 'evaluate {name}', { - 'code.lineno': 123, 'name': 'mock_async_task', 'logfire.msg_template': 'evaluate {name}', 'logfire.msg': 'evaluate mock_async_task', @@ -1196,7 +1196,6 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput: ( 'case: {case_name}', { - 'code.lineno': 123, 'task_name': 'mock_async_task', 'case_name': 'case1', 'inputs': '{"query":"What is 2+2?"}', @@ -1218,7 +1217,6 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput: ( 'execute {task}', { - 'code.lineno': 123, 'task': 'mock_async_task', 'logfire.msg_template': 'execute {task}', 'logfire.msg': 'execute mock_async_task', @@ -1229,7 +1227,6 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput: ( 'case: {case_name}', { - 'code.lineno': 123, 'task_name': 'mock_async_task', 'case_name': 'case2', 'inputs': '{"query":"What is the capital of France?"}', @@ -1251,7 +1248,6 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput: ( 'execute {task}', { - 'code.lineno': 123, 'task': 'mock_async_task', 'logfire.msg_template': 'execute {task}', 'logfire.msg': 'execute mock_async_task', From 3f38c9a963bc8deb7bfca20b4af39083bc44f18f Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 18:09:39 +0200 Subject: [PATCH 08/71] fix --- docs/agents.md | 2 +- docs/output.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index e866de1705..03a0f91f97 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -384,7 +384,7 @@ _(This example is complete, it can be run "as is" — you'll need to add `asynci #### Accessing usage and final output -You can retrieve usage statistics (tokens, requests, etc.) at any time from the [`AgentRun`][pydantic_ai.agent.AgentRun] object via `agent_run.usage()`. This method returns a [`Usage`][pydantic_ai.usage.Usage] object containing the usage data. +You can retrieve usage statistics (tokens, requests, etc.) at any time from the [`AgentRun`][pydantic_ai.agent.AgentRun] object via `agent_run.usage()`. This method returns a [`RunUsage`][pydantic_ai.usage.RunUsage] object containing the usage data. Once the run finishes, `agent_run.result` becomes a [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] object containing the final output (and related metadata). diff --git a/docs/output.md b/docs/output.md index 80dcda2453..28a46019ae 100644 --- a/docs/output.md +++ b/docs/output.md @@ -1,6 +1,6 @@ "Output" refers to the final value returned from [running an agent](agents.md#running-agents). This can be either plain text, [structured data](#structured-output), or the result of a [function](#output-functions) called with arguments provided by the model. -The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so that you can access other data, like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results). +The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so that you can access other data, like [usage][pydantic_ai.usage.RunUsage] of the run and [message history](message-history.md#accessing-messages-from-results). 
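As a quick illustration of the usage classes referenced here, the following is a minimal sketch (with made-up token counts) of how run-level usage can be aggregated, assuming the `RunUsage.__add__` behaviour shown earlier in this series, whose signature accepts both `RunUsage` and `RequestUsage`:

```python
import functools
import operator

from pydantic_ai.usage import RequestUsage, RunUsage

# Illustrative per-run totals, using the renamed token fields
# (input_tokens/output_tokens instead of request_tokens/response_tokens).
run_usages = [
    RunUsage(requests=1, input_tokens=57, output_tokens=8),
    RunUsage(requests=2, input_tokens=103, output_tokens=13),
]

# RunUsage defines __add__, so usage from several runs can be reduced to one total,
# with requests and token counts accumulating across runs.
total = functools.reduce(operator.add, run_usages)

# The widened signature (RunUsage | RequestUsage) also allows folding in the
# per-response usage attached to a ModelResponse.
total = total + RequestUsage(input_tokens=56, output_tokens=7)
```

The same pattern is what the updated `tests/test_usage_limits.py` relies on when it checks that a controller agent's usage equals the sum of its delegate runs' usage.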
Both `AgentRunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. From 5754d89aa5caa989be1d897fafe5a5d8bb5e8b90 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 19:08:23 +0200 Subject: [PATCH 09/71] Simplify diff --- pydantic_ai_slim/pydantic_ai/messages.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 66758b44c2..915fa956f7 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -846,12 +846,12 @@ class ModelResponse: If the model provides a timestamp in the response (as OpenAI does) that will be used. """ - provider_name: str | None = None - """The name of the LLM provider that generated the response.""" - kind: Literal['response'] = 'response' """Message type identifier, this is available on all parts as a discriminator.""" + provider_name: str | None = None + """The name of the LLM provider that generated the response.""" + provider_details: dict[str, Any] | None = field(default=None) """Additional provider-specific details in a serializable format. From 4c78e4f464eedb4fea36983d798ff32e08f2a99f Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 19:15:00 +0200 Subject: [PATCH 10/71] Simplify diff --- tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4a2819afdf..ccdd81eadc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,6 @@ __all__ = 'IsDatetime', 'IsFloat', 'IsNow', 'IsStr', 'IsInt', 'IsInstance', 'TestEnv', 'ClientWithHandler', 'try_import' - # Configure VCR logger to WARNING as it is too verbose by default # specifically, it logs every request and response including binary # content in Cassette.append, which is causing log downloads from From 851a582dbefd43aba501a4e09fe2151fe9fb1f39 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 19:41:44 +0200 Subject: [PATCH 11/71] Simplify diff --- pydantic_ai_slim/pydantic_ai/result.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 3e6e8edef0..39aa5395f6 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -22,7 +22,7 @@ ToolOutputSchema, ) from ._run_context import AgentDepsT, RunContext -from .messages import AgentStreamEvent, FinalResultEvent +from .messages import AgentStreamEvent from .output import ( OutputDataT, ToolOutput, @@ -52,7 +52,6 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _tool_manager: ToolManager[AgentDepsT] _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) - _final_result_event: FinalResultEvent | None = field(default=None, init=False) _initial_run_ctx_usage: RunUsage = field(init=False) def __post_init__(self): From d2cf1ed4f91d888a819bf751f07d942166ab5359 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 19:45:53 +0200 Subject: [PATCH 12/71] Simplify diff --- tests/models/test_groq.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index aa6b76b875..fe987c490d 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -246,6 +246,11 @@ async def test_request_tool_call(allow_model_requests: None): ) ], ), + usage=CompletionUsage( + completion_tokens=1, + 
prompt_tokens=2, + total_tokens=3, + ), ), completion_message( ChatCompletionMessage( @@ -259,6 +264,11 @@ async def test_request_tool_call(allow_model_requests: None): ) ], ), + usage=CompletionUsage( + completion_tokens=2, + prompt_tokens=3, + total_tokens=6, + ), ), completion_message(ChatCompletionMessage(content='final response', role='assistant')), ] @@ -291,6 +301,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], + usage=RequestUsage(input_tokens=2, output_tokens=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), provider_request_id='123', @@ -313,6 +324,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], + usage=RequestUsage(input_tokens=3, output_tokens=2), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), provider_request_id='123', From 10e91dfbda6a071c926dca614cf03a91074bdf22 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 19:47:47 +0200 Subject: [PATCH 13/71] Simplify diff --- tests/models/test_groq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index fe987c490d..f8b123afeb 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -926,7 +926,7 @@ async def test_groq_model_thinking_part(allow_model_requests: None, groq_api_key usage=RequestUsage(input_tokens=21, output_tokens=1414), model_name='deepseek-r1-distill-llama-70b', timestamp=IsDatetime(), - provider_request_id=IsStr(), + provider_request_id='chatcmpl-9748c1af-1065-410a-969a-d7fb48039fbb', ), ModelRequest( parts=[ @@ -942,7 +942,7 @@ async def test_groq_model_thinking_part(allow_model_requests: None, groq_api_key usage=RequestUsage(input_tokens=524, output_tokens=1590), model_name='deepseek-r1-distill-llama-70b', timestamp=IsDatetime(), - provider_request_id=IsStr(), + provider_request_id='chatcmpl-994aa228-883a-498c-8b20-9655d770b697', ), ] ) From e6b54987a401d808f795d474b16fce0f5edef220 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 19:57:53 +0200 Subject: [PATCH 14/71] Simplify diff --- tests/test_streaming.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 01179ece24..f0cb178bb8 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -10,7 +10,6 @@ from typing import Any, Union import pytest -from dirty_equals import IsInstance from inline_snapshot import snapshot from pydantic import BaseModel @@ -42,7 +41,7 @@ from pydantic_ai.usage import RequestUsage from pydantic_graph import End -from .conftest import IsNow, IsStr +from .conftest import IsInt, IsNow, IsStr pytestmark = pytest.mark.anyio @@ -898,7 +897,7 @@ def output_validator_simple(data: str) -> str: assert messages == [ ModelResponse( parts=[TextPart(content=text, part_kind='text')], - usage=IsInstance(RequestUsage), # type: ignore + usage=RequestUsage(input_tokens=IsInt(), output_tokens=IsInt()), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', From c54352df7c743267b0c51c39a0e2627b9bdf82dc Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 20:04:59 +0200 Subject: [PATCH 15/71] Simplify diff --- pydantic_ai_slim/pydantic_ai/models/openai.py | 6 +- tests/models/test_deepseek.py | 3 - tests/models/test_openai.py | 102 ++---------------- 3 files changed, 11 insertions(+), 100 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py 
b/pydantic_ai_slim/pydantic_ai/models/openai.py index 54a5fd1172..394bed83a1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1258,7 +1258,11 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R ) else: details = { - key: value for key, value in response_usage.model_dump(exclude_none=True).items() if isinstance(value, int) + key: value + for key, value in response_usage.model_dump( + exclude_none=True, exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'} + ).items() + if isinstance(value, int) } u = usage.RequestUsage( input_tokens=response_usage.prompt_tokens, diff --git a/tests/models/test_deepseek.py b/tests/models/test_deepseek.py index 0da81fca06..90dee7623f 100644 --- a/tests/models/test_deepseek.py +++ b/tests/models/test_deepseek.py @@ -49,9 +49,6 @@ async def test_deepseek_model_thinking_part(allow_model_requests: None, deepseek cache_read_tokens=0, output_tokens=789, details={ - 'completion_tokens': 789, - 'prompt_tokens': 12, - 'total_tokens': 801, 'prompt_cache_hit_tokens': 0, 'prompt_cache_miss_tokens': 12, 'reasoning_tokens': 415, diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index cb51be0dd1..085f7e2b45 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -235,7 +235,6 @@ async def test_request_simple_usage(allow_model_requests: None): RunUsage( requests=1, input_tokens=2, - details={'completion_tokens': 1, 'prompt_tokens': 2, 'total_tokens': 3}, output_tokens=1, ) ) @@ -365,7 +364,7 @@ async def get_location(loc_name: str) -> str: input_tokens=2, cache_read_tokens=1, output_tokens=1, - details={'completion_tokens': 1, 'prompt_tokens': 2, 'total_tokens': 3}, + details={}, ), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), @@ -393,7 +392,7 @@ async def get_location(loc_name: str) -> str: input_tokens=3, cache_read_tokens=2, output_tokens=2, - details={'completion_tokens': 2, 'prompt_tokens': 3, 'total_tokens': 6}, + details={}, ), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), @@ -417,15 +416,7 @@ async def get_location(loc_name: str) -> str: ), ] ) - assert result.usage() == snapshot( - RunUsage( - requests=3, - cache_read_tokens=3, - input_tokens=5, - output_tokens=3, - details={'completion_tokens': 3, 'prompt_tokens': 5, 'total_tokens': 9}, - ) - ) + assert result.usage() == snapshot(RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3)) FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] @@ -458,14 +449,7 @@ async def test_stream_text(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot( - RunUsage( - requests=1, - input_tokens=6, - output_tokens=3, - details={'completion_tokens': 3, 'prompt_tokens': 6, 'total_tokens': 9}, - ) - ) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3)) async def test_stream_text_finish_reason(allow_model_requests: None): @@ -537,14 +521,7 @@ async def test_stream_structured(allow_model_requests: None): ] ) assert result.is_complete - assert result.usage() == snapshot( - RunUsage( - requests=1, - input_tokens=20, - output_tokens=10, - details={'completion_tokens': 10, 'prompt_tokens': 20, 'total_tokens': 30}, - ) - ) + assert result.usage() == 
snapshot(RunUsage(requests=1, input_tokens=20, output_tokens=10)) # double check usage matches stream count assert result.usage().output_tokens == len(stream) @@ -693,14 +670,7 @@ async def test_no_delta(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot( - RunUsage( - requests=1, - input_tokens=6, - output_tokens=3, - details={'completion_tokens': 3, 'prompt_tokens': 6, 'total_tokens': 9}, - ) - ) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3)) @pytest.mark.parametrize('system_prompt_role', ['system', 'developer', 'user', None]) @@ -857,9 +827,6 @@ async def get_image() -> ImageUrl: output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 46, - 'total_tokens': 57, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -897,9 +864,6 @@ async def get_image() -> ImageUrl: output_tokens=8, input_audio_tokens=0, details={ - 'completion_tokens': 8, - 'prompt_tokens': 503, - 'total_tokens': 511, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -944,9 +908,6 @@ async def get_image() -> BinaryContent: output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 46, - 'total_tokens': 57, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -982,9 +943,6 @@ async def get_image() -> BinaryContent: output_tokens=9, input_audio_tokens=0, details={ - 'completion_tokens': 9, - 'prompt_tokens': 1185, - 'total_tokens': 1194, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1874,9 +1832,6 @@ async def test_openai_instructions(allow_model_requests: None, openai_api_key: s output_tokens=8, input_audio_tokens=0, details={ - 'completion_tokens': 8, - 'prompt_tokens': 24, - 'total_tokens': 32, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1923,9 +1878,6 @@ async def get_temperature(city: str) -> float: output_tokens=15, input_audio_tokens=0, details={ - 'completion_tokens': 15, - 'prompt_tokens': 50, - 'total_tokens': 65, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1952,9 +1904,6 @@ async def get_temperature(city: str) -> float: output_tokens=15, input_audio_tokens=0, details={ - 'completion_tokens': 15, - 'prompt_tokens': 75, - 'total_tokens': 90, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2113,9 +2062,6 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api output_tokens=2437, input_audio_tokens=0, details={ - 'completion_tokens': 2437, - 'prompt_tokens': 822, - 'total_tokens': 3259, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 1792, @@ -2425,9 +2371,6 @@ async def get_user_country() -> str: output_tokens=12, input_audio_tokens=0, details={ - 'completion_tokens': 12, - 'prompt_tokens': 68, - 'total_tokens': 80, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2462,9 +2405,6 @@ async def get_user_country() -> str: output_tokens=36, input_audio_tokens=0, details={ - 'completion_tokens': 36, - 'prompt_tokens': 89, - 'total_tokens': 125, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2525,9 +2465,6 @@ async def get_user_country() -> str: output_tokens=11, input_audio_tokens=0, details={ - 
'completion_tokens': 11, - 'prompt_tokens': 42, - 'total_tokens': 53, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2556,9 +2493,6 @@ async def get_user_country() -> str: output_tokens=10, input_audio_tokens=0, details={ - 'completion_tokens': 10, - 'prompt_tokens': 63, - 'total_tokens': 73, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2612,9 +2546,6 @@ async def get_user_country() -> str: output_tokens=12, input_audio_tokens=0, details={ - 'completion_tokens': 12, - 'prompt_tokens': 71, - 'total_tokens': 83, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2643,9 +2574,6 @@ async def get_user_country() -> str: output_tokens=15, input_audio_tokens=0, details={ - 'completion_tokens': 15, - 'prompt_tokens': 92, - 'total_tokens': 107, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2701,9 +2629,6 @@ async def get_user_country() -> str: output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 160, - 'total_tokens': 171, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2736,9 +2661,6 @@ async def get_user_country() -> str: output_tokens=25, input_audio_tokens=0, details={ - 'completion_tokens': 25, - 'prompt_tokens': 181, - 'total_tokens': 206, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2797,9 +2719,6 @@ async def get_user_country() -> str: output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 109, - 'total_tokens': 120, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2835,9 +2754,6 @@ async def get_user_country() -> str: output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 130, - 'total_tokens': 141, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2900,9 +2816,6 @@ async def get_user_country() -> str: output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 273, - 'total_tokens': 284, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -2942,9 +2855,6 @@ async def get_user_country() -> str: output_tokens=21, input_audio_tokens=0, details={ - 'completion_tokens': 21, - 'prompt_tokens': 294, - 'total_tokens': 315, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, From 4c0c5d8d53b8512a75b84bb00062767bfe4499c6 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 20:27:10 +0200 Subject: [PATCH 16/71] Remove logfire.configure calls --- tests/evals/test_otel.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/evals/test_otel.py b/tests/evals/test_otel.py index 10fec9f7bb..08e035dc32 100644 --- a/tests/evals/test_otel.py +++ b/tests/evals/test_otel.py @@ -336,8 +336,6 @@ async def test_span_node_repr(span_tree: SpanTree): async def test_span_tree_ancestors_methods(): """Test the ancestor traversal methods in SpanNode.""" - # Configure logfire - logfire.configure() # Create spans with a deep structure for testing ancestor methods with context_subtree() as tree: @@ -398,8 +396,6 @@ async def test_span_tree_ancestors_methods(): async def test_span_tree_descendants_methods(): """Test the descendant traversal methods in SpanNode.""" - # Configure logfire - logfire.configure() # Create spans with a deep structure for testing descendant methods with context_subtree() as tree: @@ -488,8 +484,6 @@ async def 
test_span_tree_descendants_methods(): async def test_log_levels_and_exceptions(): """Test recording different log levels and exceptions in spans.""" - # Configure logfire - logfire.configure() with context_subtree() as tree: # Test different log levels @@ -891,7 +885,6 @@ async def test_context_subtree_not_configured(mocker: MockerFixture): """Test that context_subtree correctly records spans in independent async contexts.""" from opentelemetry.trace import ProxyTracerProvider - # from opentelemetry.sdk.trace import TracerProvider mocker.patch( 'pydantic_evals.otel._context_in_memory_span_exporter.get_tracer_provider', return_value=ProxyTracerProvider() ) From 15404b64807a795cb42ca4432d826d622bb283b8 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 20:45:57 +0200 Subject: [PATCH 17/71] fix --- tests/test_mcp.py | 69 ----------------------------------------------- 1 file changed, 69 deletions(-) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 200daafa4c..4b71b06c31 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -207,9 +207,6 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) output_tokens=19, input_audio_tokens=0, details={ - 'completion_tokens': 19, - 'prompt_tokens': 195, - 'total_tokens': 214, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -238,9 +235,6 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) output_tokens=13, input_audio_tokens=0, details={ - 'completion_tokens': 13, - 'prompt_tokens': 227, - 'total_tokens': 240, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -347,9 +341,6 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): output_tokens=18, input_audio_tokens=0, details={ - 'completion_tokens': 18, - 'prompt_tokens': 194, - 'total_tokens': 212, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -382,9 +373,6 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): output_tokens=19, input_audio_tokens=0, details={ - 'completion_tokens': 19, - 'prompt_tokens': 234, - 'total_tokens': 253, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -428,9 +416,6 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A output_tokens=12, input_audio_tokens=0, details={ - 'completion_tokens': 12, - 'prompt_tokens': 200, - 'total_tokens': 212, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -459,9 +444,6 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A output_tokens=12, input_audio_tokens=0, details={ - 'completion_tokens': 12, - 'prompt_tokens': 224, - 'total_tokens': 236, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -505,9 +487,6 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age output_tokens=12, input_audio_tokens=0, details={ - 'completion_tokens': 12, - 'prompt_tokens': 305, - 'total_tokens': 317, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -536,9 +515,6 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 332, - 'total_tokens': 343, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -584,9 +560,6 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: 
output_tokens=12, input_audio_tokens=0, details={ - 'completion_tokens': 12, - 'prompt_tokens': 191, - 'total_tokens': 203, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -620,9 +593,6 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: output_tokens=19, input_audio_tokens=0, details={ - 'completion_tokens': 19, - 'prompt_tokens': 1332, - 'total_tokens': 1351, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -670,9 +640,6 @@ async def test_tool_returning_image_resource_link( output_tokens=12, input_audio_tokens=0, details={ - 'completion_tokens': 12, - 'prompt_tokens': 305, - 'total_tokens': 317, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -706,9 +673,6 @@ async def test_tool_returning_image_resource_link( output_tokens=29, input_audio_tokens=0, details={ - 'completion_tokens': 29, - 'prompt_tokens': 1452, - 'total_tokens': 1481, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -858,9 +822,6 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 190, - 'total_tokens': 201, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -896,9 +857,6 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im output_tokens=15, input_audio_tokens=0, details={ - 'completion_tokens': 15, - 'prompt_tokens': 1329, - 'total_tokens': 1344, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -936,9 +894,6 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 195, - 'total_tokens': 206, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -967,9 +922,6 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 222, - 'total_tokens': 233, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1015,9 +967,6 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): output_tokens=15, input_audio_tokens=0, details={ - 'completion_tokens': 15, - 'prompt_tokens': 203, - 'total_tokens': 218, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1052,9 +1001,6 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): output_tokens=15, input_audio_tokens=0, details={ - 'completion_tokens': 15, - 'prompt_tokens': 250, - 'total_tokens': 265, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1087,9 +1033,6 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): output_tokens=22, input_audio_tokens=0, details={ - 'completion_tokens': 22, - 'prompt_tokens': 277, - 'total_tokens': 299, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1127,9 +1070,6 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): output_tokens=11, input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 193, - 'total_tokens': 204, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1158,9 +1098,6 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): output_tokens=11, 
input_audio_tokens=0, details={ - 'completion_tokens': 11, - 'prompt_tokens': 212, - 'total_tokens': 223, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1206,9 +1143,6 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: output_tokens=12, input_audio_tokens=0, details={ - 'completion_tokens': 12, - 'prompt_tokens': 195, - 'total_tokens': 207, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, @@ -1253,9 +1187,6 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: output_tokens=24, input_audio_tokens=0, details={ - 'completion_tokens': 24, - 'prompt_tokens': 1355, - 'total_tokens': 1379, 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, From 516ea4957869ba1d196a5aa337d563a40f8e9c24 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 20:49:07 +0200 Subject: [PATCH 18/71] Disable instrumentation in tests --- tests/conftest.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index ccdd81eadc..31fb9ce2dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,7 @@ from vcr import VCR, request as vcr_request import pydantic_ai.models +from pydantic_ai import Agent from pydantic_ai.messages import BinaryContent from pydantic_ai.models import Model, cached_async_http_client @@ -230,6 +231,11 @@ def event_loop() -> Iterator[None]: new_loop.close() +@pytest.fixture(autouse=True) +def no_instrumentation_by_default(): + Agent.instrument_all(False) + + def raise_if_exception(e: Any) -> None: if isinstance(e, Exception): raise e From 2bbeadc6b3cfa909d5bacb1d38c5d49d827c6b68 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 20:58:57 +0200 Subject: [PATCH 19/71] docstrings --- pydantic_ai_slim/pydantic_ai/usage.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 1f1e014bd1..fea2c18e7e 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -14,16 +14,11 @@ class UsageBase: input_tokens: int | None = None - """Total number of text input/prompt tokens.""" - output_tokens: int | None = None - """Total number of text output/completion tokens.""" - details: dict[str, int] | None = None - """Any extra details returned by the model.""" def opentelemetry_attributes(self) -> dict[str, int]: - """Get the token limits as OpenTelemetry attributes.""" + """Get the token usage values as OpenTelemetry attributes.""" result: dict[str, int] = {} if self.input_tokens: result['gen_ai.usage.input_tokens'] = self.input_tokens From 1a1b4aeb5d124ded14a9daa95c64d1089bb08883 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 21:07:15 +0200 Subject: [PATCH 20/71] shutdown logfire after each test --- tests/conftest.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 31fb9ce2dc..c09d32f0cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable import httpx +import logfire import pytest from _pytest.assertion.rewrite import AssertionRewritingHook from pytest_mock import MockerFixture @@ -231,9 +232,13 @@ def event_loop() -> Iterator[None]: new_loop.close() +logfire.DEFAULT_LOGFIRE_INSTANCE.config.ignore_no_config = True + + @pytest.fixture(autouse=True) def no_instrumentation_by_default(): Agent.instrument_all(False) + 
logfire.shutdown(flush=False) def raise_if_exception(e: Any) -> None: From 79f04d8ebca9ce3b50128bf2e4aef8a698402ff6 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 21:09:35 +0200 Subject: [PATCH 21/71] shutdown logfire after each test --- tests/conftest.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c09d32f0cd..9c0b6a5abc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Any, Callable import httpx -import logfire import pytest from _pytest.assertion.rewrite import AssertionRewritingHook from pytest_mock import MockerFixture @@ -232,13 +231,22 @@ def event_loop() -> Iterator[None]: new_loop.close() -logfire.DEFAULT_LOGFIRE_INSTANCE.config.ignore_no_config = True - - @pytest.fixture(autouse=True) def no_instrumentation_by_default(): Agent.instrument_all(False) - logfire.shutdown(flush=False) + + +try: + import logfire + + logfire.DEFAULT_LOGFIRE_INSTANCE.config.ignore_no_config = True + + @pytest.fixture(autouse=True) + def fresh_logfire(): + logfire.shutdown(flush=False) + +except ImportError: + pass def raise_if_exception(e: Any) -> None: From 0874e88d033d2f39ba8eedef2c014e544ee0068c Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 21:38:58 +0200 Subject: [PATCH 22/71] debugging --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c4d948ce24..87162bdf6d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -155,7 +155,7 @@ jobs: - run: mkdir .coverage # run tests with just `pydantic-ai-slim` dependencies - - run: uv run --package pydantic-ai-slim coverage run -m pytest -n auto --dist=loadgroup + - run: uv run --package pydantic-ai-slim coverage run -m pytest -n auto --dist=loadgroup -vv env: COVERAGE_FILE: .coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim From 02fe793d75510ee60accb310c0d9fcbc27203b3b Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 21:56:22 +0200 Subject: [PATCH 23/71] debugging --- .github/workflows/ci.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 87162bdf6d..f8f0adb26d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -155,7 +155,9 @@ jobs: - run: mkdir .coverage # run tests with just `pydantic-ai-slim` dependencies - - run: uv run --package pydantic-ai-slim coverage run -m pytest -n auto --dist=loadgroup -vv + - run: | + uv pip install pytest-profiling + uv run --package pydantic-ai-slim pytest --profile env: COVERAGE_FILE: .coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim From 78416fa2b35d27addfebe9b0dfffe0980fc219bc Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 21:58:21 +0200 Subject: [PATCH 24/71] debugging --- .github/workflows/ci.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f8f0adb26d..3efa72bd29 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -156,8 +156,7 @@ jobs: # run tests with just `pydantic-ai-slim` dependencies - run: | - uv pip install pytest-profiling - uv run --package pydantic-ai-slim pytest --profile + uv run --with pytest-profiling --package pydantic-ai-slim pytest --profile env: COVERAGE_FILE: .coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim From 
76ccb517ba222f10ffde5143b64da929a6f347c7 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 8 Aug 2025 22:05:09 +0200 Subject: [PATCH 25/71] debugging --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3efa72bd29..4bef23988a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -146,7 +146,7 @@ jobs: - uses: astral-sh/setup-uv@v5 with: - enable-cache: true + enable-cache: false - uses: denoland/setup-deno@v2 with: From 547c126e76d3e7b5165211b305d80a09c2a4474c Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Mon, 11 Aug 2025 16:33:36 +0200 Subject: [PATCH 26/71] revert ci.yml --- .github/workflows/ci.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4bef23988a..c4d948ce24 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -146,7 +146,7 @@ jobs: - uses: astral-sh/setup-uv@v5 with: - enable-cache: false + enable-cache: true - uses: denoland/setup-deno@v2 with: @@ -155,8 +155,7 @@ jobs: - run: mkdir .coverage # run tests with just `pydantic-ai-slim` dependencies - - run: | - uv run --with pytest-profiling --package pydantic-ai-slim pytest --profile + - run: uv run --package pydantic-ai-slim coverage run -m pytest -n auto --dist=loadgroup env: COVERAGE_FILE: .coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim From e2452150761b9149e4ec9c241da303503675839c Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Mon, 11 Aug 2025 16:34:47 +0200 Subject: [PATCH 27/71] debugging --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8d20a6780b..dadf4a0955 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -118,7 +118,7 @@ jobs: --extra anthropic --extra mistral --extra cohere - pytest tests/test_live.py -v + pytest --durations=0 tests/test_live.py -v --durations=100 env: PYDANTIC_AI_LIVE_TEST_DANGEROUS: "CHARGE-ME!" 
@@ -163,7 +163,7 @@ jobs: - run: mkdir .coverage - run: uv sync --group dev - - run: uv run ${{ steps.install-command.outputs.install-command }} coverage run -m pytest -n auto --dist=loadgroup + - run: uv run ${{ steps.install-command.outputs.install-command }} coverage run -m pytest --durations=0 -n auto --dist=loadgroup env: COVERAGE_FILE: .coverage/.coverage.${{ matrix.python-version }}-${{ matrix.install }} @@ -204,7 +204,7 @@ jobs: - run: unset UV_FROZEN - - run: uv run --all-extras --resolution lowest-direct coverage run -m pytest -n auto --dist=loadgroup + - run: uv run --all-extras --resolution lowest-direct coverage run -m pytest --durations=0 -n auto --dist=loadgroup env: COVERAGE_FILE: .coverage/.coverage.${{matrix.python-version}}-lowest-versions @@ -284,7 +284,7 @@ jobs: - run: make lint-js - - run: uv run --package mcp-run-python pytest mcp-run-python -v --durations=100 + - run: uv run --package mcp-run-python pytest --durations=0 mcp-run-python -v --durations=100 - run: deno task dev warmup working-directory: mcp-run-python From 05020aa9d3d006eb8ba8028e994daf3fc7d22bb9 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Mon, 11 Aug 2025 16:38:52 +0200 Subject: [PATCH 28/71] fix --- tests/test_history_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_history_processor.py b/tests/test_history_processor.py index 3f1cb49abc..1aa138935e 100644 --- a/tests/test_history_processor.py +++ b/tests/test_history_processor.py @@ -101,7 +101,7 @@ def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage] ModelRequest(parts=[SystemPromptPart(content='Processed answer', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='Provider response')], - usage=Usage(requests=1, request_tokens=54, response_tokens=2, total_tokens=56), + usage=RequestUsage(input_tokens=54, output_tokens=2), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), ), @@ -135,7 +135,7 @@ def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage] ModelRequest(parts=[SystemPromptPart(content='Processed answer', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='hello')], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), ), From 6dfe67638930662d2ed81eb92b51974fd6e995e7 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Tue, 12 Aug 2025 17:01:46 +0200 Subject: [PATCH 29/71] update genai-prices --- pydantic_ai_slim/pydantic_ai/messages.py | 6 +++--- pydantic_ai_slim/pyproject.toml | 2 +- uv.lock | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 915fa956f7..59dfaa7917 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -10,7 +10,7 @@ import pydantic import pydantic_core -from genai_prices import calc_price_sync, types as genai_types +from genai_prices import calc_price, types as genai_types from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage] from typing_extensions import TypeAlias, deprecated @@ -863,9 +863,9 @@ class ModelResponse: """request ID as specified by the model provider. 
This can be used to track the specific request to the model.""" def price(self) -> genai_types.PriceCalculation: - """Calculate the price of the usage, this doesn't use `auto_update` so won't make any network requests.""" + """Calculate the price of the usage.""" assert self.model_name, 'Model name is required to calculate price' - return calc_price_sync( + return calc_price( self.usage, self.model_name, provider_id=self.provider_name, diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index c056419d03..904967bf59 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -62,7 +62,7 @@ dependencies = [ "exceptiongroup; python_version < '3.11'", "opentelemetry-api>=1.28.0", "typing-inspection>=0.4.0", - "genai-prices>=0.0.3", + "genai-prices>=0.0.22", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/uv.lock b/uv.lock index cf935e9c71..22dcadc8b5 100644 --- a/uv.lock +++ b/uv.lock @@ -1289,16 +1289,16 @@ http = [ [[package]] name = "genai-prices" -version = "0.0.3" +version = "0.0.22" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "eval-type-backport", marker = "python_full_version < '3.11'" }, { name = "httpx" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4f/06/398237fceebab444e8f91c974cf88e138e3a9b43d98835df3d2997751ef7/genai_prices-0.0.3.tar.gz", hash = "sha256:9a6a11f64d51e825223613f40dd4ed3312d4b5b3ddec8c030eb814e52ea0f54d", size = 40002, upload-time = "2025-07-13T03:29:42.32Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/c5/0aa155ac23a17eb6de36f0611d8595fc49861bdb0a5f302133b7b4f68f5b/genai_prices-0.0.22.tar.gz", hash = "sha256:5e743424d40176ea04de7b74d1ad3a41801390439f4404d4593b82218f2c0c04", size = 44125, upload-time = "2025-08-12T12:04:52.265Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/ee/7266e6b3b19c9544af4069b94301df2c484fd65e8c1bad4334c9a9b9e1c8/genai_prices-0.0.3-py3-none-any.whl", hash = "sha256:dcae66dec82ccc609e2d4365c3b8b6f192115a39f378d9e5fd5728f21edad80d", size = 41959, upload-time = "2025-07-13T03:29:41.157Z" }, + { url = "https://files.pythonhosted.org/packages/8d/1c/313541ea19144a7e5b0ac32dec3dd026de719c5b15147b043d52006dbef7/genai_prices-0.0.22-py3-none-any.whl", hash = "sha256:1ae496bdf517047bc489421c1eff653872e5d456d5eb86d113dfb49d2977a041", size = 46445, upload-time = "2025-08-12T12:04:50.884Z" }, ] [[package]] @@ -3521,7 +3521,7 @@ requires-dist = [ { name = "eval-type-backport", specifier = ">=0.2.0" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, - { name = "genai-prices", specifier = ">=0.0.3" }, + { name = "genai-prices", specifier = ">=0.0.22" }, { name = "google-auth", marker = "extra == 'vertexai'", specifier = ">=2.36.0" }, { name = "google-genai", marker = "extra == 'google'", specifier = ">=1.28.0" }, { name = "griffe", specifier = ">=1.3.2" }, From 47c61a3e15fafa9e4bea819542d7fae53dfbf6b7 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Tue, 12 Aug 2025 17:08:23 +0200 Subject: [PATCH 30/71] tests --- tests/models/test_anthropic.py | 4 ++++ tests/test_usage_limits.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 21597a5047..c37a9d971d 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -5,6 +5,7 @@ from collections.abc import Sequence from 
dataclasses import dataclass, field from datetime import timezone +from decimal import Decimal from functools import cached_property from typing import Any, Callable, TypeVar, Union, cast @@ -245,6 +246,9 @@ async def test_async_request_prompt_caching(allow_model_requests: None): }, ) ) + last_message = result.all_messages()[-1] + assert isinstance(last_message, ModelResponse) + assert last_message.price().total_price == snapshot(Decimal('0.0000304')) async def test_async_request_text_response(allow_model_requests: None): diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 58808750b9..d8f77c44d6 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -190,3 +190,9 @@ def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int: result = await controller_agent.run('foobar') assert result.output == snapshot('{"delegate_to_other_agent":0}') assert result.usage() == snapshot(RunUsage(requests=7, input_tokens=105, output_tokens=16)) + + +def test_usage_basics(): + usage = RequestUsage() + assert usage.output_audio_tokens is None + assert usage.requests == 1 From 75ab2c4a368654ddb96af7510f4f976d9a0f8991 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Tue, 12 Aug 2025 17:11:14 +0200 Subject: [PATCH 31/71] tests --- tests/test_usage_limits.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index d8f77c44d6..2a68207736 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -192,7 +192,30 @@ def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int: assert result.usage() == snapshot(RunUsage(requests=7, input_tokens=105, output_tokens=16)) -def test_usage_basics(): +def test_request_usage_basics(): usage = RequestUsage() assert usage.output_audio_tokens is None assert usage.requests == 1 + + +def test_add_usages(): + usage = RunUsage( + requests=2, + input_tokens=10, + output_tokens=20, + cache_read_tokens=30, + cache_write_tokens=40, + input_audio_tokens=50, + cache_audio_read_tokens=60, + ) + assert usage + usage == snapshot( + RunUsage( + requests=4, + input_tokens=20, + output_tokens=40, + cache_write_tokens=80, + cache_read_tokens=60, + input_audio_tokens=100, + cache_audio_read_tokens=120, + ) + ) From 2902e8bc3ba4eec552077c2feee133313c73d2cf Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Tue, 12 Aug 2025 17:12:54 +0200 Subject: [PATCH 32/71] tests --- tests/test_usage_limits.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 2a68207736..e21899febf 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -207,6 +207,10 @@ def test_add_usages(): cache_write_tokens=40, input_audio_tokens=50, cache_audio_read_tokens=60, + details={ + 'custom1': 10, + 'custom2': 20, + }, ) assert usage + usage == snapshot( RunUsage( @@ -217,5 +221,8 @@ def test_add_usages(): cache_read_tokens=60, input_audio_tokens=100, cache_audio_read_tokens=120, + details={'custom1': 20, 'custom2': 40}, ) ) + assert usage + RunUsage() == usage + assert RunUsage() + RunUsage() == RunUsage() From 6d0538ddc147658356b797fe5449c10780c8aaec Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Tue, 12 Aug 2025 17:21:13 +0200 Subject: [PATCH 33/71] pragma --- .../pydantic_evals/otel/_context_in_memory_span_exporter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py 
b/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py index cd83f30a99..05d4d6bff9 100644 --- a/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py +++ b/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py @@ -109,7 +109,7 @@ def get_finished_spans(self, context_id: str | None = None) -> tuple[ReadableSpa def export(self, spans: typing.Sequence[ReadableSpan]) -> SpanExportResult: """Stores a list of spans in memory.""" - if self._stopped: # pragma: no cover + if self._stopped: return SpanExportResult.FAILURE with self._lock: context_id = _EXPORTER_CONTEXT_ID.get() From 2c730e4fd36c799d12717f670db85b6150befd60 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 13:33:25 +0200 Subject: [PATCH 34/71] update example --- pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index fd0883295e..3c49831e71 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -599,9 +599,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) From 1940caa66c923ece430cba72f18c6a6660e54de6 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 13:40:47 +0200 Subject: [PATCH 35/71] Output audio tokens --- pydantic_ai_slim/pydantic_ai/models/openai.py | 1 + pydantic_ai_slim/pydantic_ai/usage.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index ac9878888d..b54d34ccd4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1272,6 +1272,7 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R ) if response_usage.completion_tokens_details is not None: details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True)) + u.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens if response_usage.prompt_tokens_details is not None: u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index fea2c18e7e..996adbe314 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -65,14 +65,12 @@ class RequestUsage(UsageBase): """Number of audio input tokens.""" cache_audio_read_tokens: int | None = None """Number of audio tokens read from the cache.""" + output_audio_tokens: int | None = None + """Number of audio output tokens.""" details: dict[str, int] | None = None """Any extra details returned by the model.""" - @property - def output_audio_tokens(self): - return None - @property def requests(self): return 1 From d05de25111cb80354f9a26d0b44578778ee0ab67 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 14:03:32 +0200 Subject: [PATCH 36/71] anthropic --- pydantic_ai_slim/pydantic_ai/models/anthropic.py | 10 +++++----- 
tests/models/test_anthropic.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index cae893bddc..a8d39c7b63 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -553,14 +553,14 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get` # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens # This approach maintains request_tokens as the count of all input tokens, with cached counts as details - request_tokens = ( - details.get('input_tokens', 0) - + details.get('cache_creation_input_tokens', 0) - + details.get('cache_read_input_tokens', 0) - ) + cache_write_tokens = details.get('cache_creation_input_tokens', 0) + cache_read_tokens = details.get('cache_read_input_tokens', 0) + request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens return usage.RequestUsage( input_tokens=request_tokens or None, + cache_read_tokens=cache_read_tokens or None, + cache_write_tokens=cache_write_tokens or None, output_tokens=response_usage.output_tokens, details=details or None, ) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index c37a9d971d..3a4e1c1608 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -237,6 +237,8 @@ async def test_async_request_prompt_caching(allow_model_requests: None): RunUsage( requests=1, input_tokens=13, + cache_write_tokens=4, + cache_read_tokens=6, output_tokens=5, details={ 'input_tokens': 3, @@ -248,7 +250,7 @@ async def test_async_request_prompt_caching(allow_model_requests: None): ) last_message = result.all_messages()[-1] assert isinstance(last_message, ModelResponse) - assert last_message.price().total_price == snapshot(Decimal('0.0000304')) + assert last_message.price().total_price == snapshot(Decimal('0.00003488')) async def test_async_request_text_response(allow_model_requests: None): @@ -1198,6 +1200,8 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: snapshot( RequestUsage( input_tokens=6, + cache_write_tokens=2, + cache_read_tokens=3, output_tokens=1, details={ 'cache_creation_input_tokens': 2, From ee25916b1733c210a1cb57a8017e3ae443cab265 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 14:19:39 +0200 Subject: [PATCH 37/71] gemini audio --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 6aa0a31a9a..c3d2e8dbe8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -872,16 +872,31 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage: if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'): details['tool_use_prompt_tokens'] = tool_use_prompt_token_count # pragma: no cover + input_audio_tokens = None + output_audio_tokens = None + cache_audio_read_tokens = None for key, metadata_details in metadata.items(): if key.endswith('_details') and metadata_details: metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) suffix = key.removesuffix('_details') for detail in metadata_details: - 
details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0) + modality = detail['modality'] + details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0) + if value and modality == 'AUDIO': + if key == 'prompt_tokens_details': + input_audio_tokens = value + elif key == 'candidates_tokens_details': + output_audio_tokens = value + elif key == 'cache_tokens_details': + cache_audio_read_tokens = value return usage.RequestUsage( input_tokens=metadata.get('prompt_token_count', 0), output_tokens=metadata.get('candidates_token_count', 0), + cache_read_tokens=cached_content_token_count, + input_audio_tokens=input_audio_tokens, + output_audio_tokens=output_audio_tokens, + cache_audio_read_tokens=cache_audio_read_tokens, details=details, ) From 7e9367e77f825f21913b2544beb474199ed81d84 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 15:29:47 +0200 Subject: [PATCH 38/71] test gemini audio --- tests/models/test_gemini.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index f0c3a337f5..5b7c1329cb 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -46,6 +46,7 @@ _GeminiFunctionCall, _GeminiFunctionCallingConfig, _GeminiFunctionCallPart, + _GeminiModalityTokenCount, _GeminiResponse, _GeminiSafetyRating, _GeminiTextPart, @@ -53,6 +54,7 @@ _GeminiToolConfig, _GeminiTools, _GeminiUsageMetaData, + _metadata_as_usage, ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.providers.google_gla import GoogleGLAProvider @@ -2150,3 +2152,35 @@ class CountryLanguage(BaseModel): ), ] ) + + +def test_map_usage(): + response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')]))) + assert 'usage_metadata' in response + response['usage_metadata']['cached_content_token_count'] = 9100 + response['usage_metadata']['prompt_tokens_details'] = [ + _GeminiModalityTokenCount(modality='AUDIO', token_count=9200) + ] + response['usage_metadata']['cache_tokens_details'] = [ + _GeminiModalityTokenCount(modality='AUDIO', token_count=9300), + ] + response['usage_metadata']['candidates_tokens_details'] = [ + _GeminiModalityTokenCount(modality='AUDIO', token_count=9400) + ] + + assert _metadata_as_usage(response) == snapshot( + RequestUsage( + input_tokens=1, + cache_read_tokens=9100, + output_tokens=2, + input_audio_tokens=9200, + cache_audio_read_tokens=9300, + output_audio_tokens=9400, + details={ + 'cached_content_tokens': 9100, + 'audio_prompt_tokens': 9200, + 'audio_cache_tokens': 9300, + 'audio_candidates_tokens': 9400, + }, + ) + ) From 975489966cc6b6c6869398453d50bd2e78bd0604 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 15:43:37 +0200 Subject: [PATCH 39/71] google audio --- .../pydantic_ai/models/_google_common.py | 85 +++++++++++++++++++ pydantic_ai_slim/pydantic_ai/models/gemini.py | 78 +---------------- pydantic_ai_slim/pydantic_ai/models/google.py | 26 +----- tests/models/test_gemini.py | 8 +- 4 files changed, 95 insertions(+), 102 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/models/_google_common.py diff --git a/pydantic_ai_slim/pydantic_ai/models/_google_common.py b/pydantic_ai_slim/pydantic_ai/models/_google_common.py new file mode 100644 index 0000000000..ec0bd5f4b3 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/models/_google_common.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from typing import Annotated, 
Literal, cast + +import pydantic +from typing_extensions import NotRequired, TypedDict + +from pydantic_ai import usage + + +class _GeminiModalityTokenCount(TypedDict): + """See .""" + + modality: Annotated[ + Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality') + ] + token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)] + + +class GeminiUsageMetaData(TypedDict, total=False): + """See . + + The docs suggest all fields are required, but some are actually not required, so we assume they are all optional. + """ + + prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')] + candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]] + total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')] + cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]] + thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]] + tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]] + prompt_tokens_details: NotRequired[ + Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')] + ] + cache_tokens_details: NotRequired[ + Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')] + ] + candidates_tokens_details: NotRequired[ + Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')] + ] + tool_use_prompt_tokens_details: NotRequired[ + Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')] + ] + + +def metadata_as_request_usage(metadata: GeminiUsageMetaData | None) -> usage.RequestUsage: + if metadata is None: + return usage.RequestUsage() # pragma: no cover + details: dict[str, int] = {} + if cached_content_token_count := metadata.get('cached_content_token_count'): + details['cached_content_tokens'] = cached_content_token_count # pragma: no cover + + if thoughts_token_count := metadata.get('thoughts_token_count'): + details['thoughts_tokens'] = thoughts_token_count + + if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'): + details['tool_use_prompt_tokens'] = tool_use_prompt_token_count # pragma: no cover + + input_audio_tokens = None + output_audio_tokens = None + cache_audio_read_tokens = None + for key, metadata_details in metadata.items(): + if key.endswith('_details') and metadata_details: + metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) + suffix = key.removesuffix('_details') + for detail in metadata_details: + modality = detail['modality'] + details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0) + if value and modality == 'AUDIO': + if key == 'prompt_tokens_details': + input_audio_tokens = value + elif key == 'candidates_tokens_details': + output_audio_tokens = value + elif key == 'cache_tokens_details': + cache_audio_read_tokens = value + + return usage.RequestUsage( + input_tokens=metadata.get('prompt_token_count', 0), + output_tokens=metadata.get('candidates_token_count', 0), + cache_read_tokens=cached_content_token_count, + input_audio_tokens=input_audio_tokens, + output_audio_tokens=output_audio_tokens, + cache_audio_read_tokens=cache_audio_read_tokens, + details=details, + ) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index c3d2e8dbe8..4d3c6743f0 
100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -48,6 +48,7 @@ download_item, get_user_agent, ) +from ._google_common import GeminiUsageMetaData, metadata_as_request_usage LatestGeminiModelNames = Literal[ 'gemini-2.0-flash', @@ -803,7 +804,7 @@ class _GeminiResponse(TypedDict): candidates: list[_GeminiCandidates] # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response - usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]] + usage_metadata: NotRequired[Annotated[GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]] prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]] model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]] vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]] @@ -823,82 +824,9 @@ class _GeminiCandidates(TypedDict): safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]] -class _GeminiModalityTokenCount(TypedDict): - """See .""" - - modality: Annotated[ - Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality') - ] - token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)] - - -class _GeminiUsageMetaData(TypedDict, total=False): - """See . - - The docs suggest all fields are required, but some are actually not required, so we assume they are all optional. - """ - - prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')] - candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]] - total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')] - cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]] - thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]] - tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]] - prompt_tokens_details: NotRequired[ - Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')] - ] - cache_tokens_details: NotRequired[ - Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')] - ] - candidates_tokens_details: NotRequired[ - Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')] - ] - tool_use_prompt_tokens_details: NotRequired[ - Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')] - ] - - def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage: metadata = response.get('usage_metadata') - if metadata is None: - return usage.RequestUsage() # pragma: no cover - details: dict[str, int] = {} - if cached_content_token_count := metadata.get('cached_content_token_count'): - details['cached_content_tokens'] = cached_content_token_count # pragma: no cover - - if thoughts_token_count := metadata.get('thoughts_token_count'): - details['thoughts_tokens'] = thoughts_token_count - - if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'): - details['tool_use_prompt_tokens'] = tool_use_prompt_token_count # pragma: no cover - - input_audio_tokens = None - output_audio_tokens = None - cache_audio_read_tokens = None - for key, metadata_details in metadata.items(): - if 
key.endswith('_details') and metadata_details: - metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) - suffix = key.removesuffix('_details') - for detail in metadata_details: - modality = detail['modality'] - details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0) - if value and modality == 'AUDIO': - if key == 'prompt_tokens_details': - input_audio_tokens = value - elif key == 'candidates_tokens_details': - output_audio_tokens = value - elif key == 'cache_tokens_details': - cache_audio_read_tokens = value - - return usage.RequestUsage( - input_tokens=metadata.get('prompt_token_count', 0), - output_tokens=metadata.get('candidates_token_count', 0), - cache_read_tokens=cached_content_token_count, - input_audio_tokens=input_audio_tokens, - output_audio_tokens=output_audio_tokens, - cache_audio_read_tokens=cache_audio_read_tokens, - details=details, - ) + return metadata_as_request_usage(metadata) class _GeminiSafetyRating(TypedDict): diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 632c3f1f8b..8271682e1c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -46,6 +46,7 @@ download_item, get_user_agent, ) +from ._google_common import GeminiUsageMetaData, metadata_as_request_usage try: from google.genai import Client @@ -594,26 +595,5 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage: metadata = response.usage_metadata if metadata is None: return usage.RequestUsage() # pragma: no cover - metadata = metadata.model_dump(exclude_defaults=True) - - details: dict[str, int] = {} - if cached_content_token_count := metadata.get('cached_content_token_count'): - details['cached_content_tokens'] = cached_content_token_count # pragma: no cover - - if thoughts_token_count := metadata.get('thoughts_token_count'): - details['thoughts_tokens'] = thoughts_token_count - - if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'): - details['tool_use_prompt_tokens'] = tool_use_prompt_token_count - - for key, metadata_details in metadata.items(): - if key.endswith('_details') and metadata_details: - suffix = key.removesuffix('_details') - for detail in metadata_details: - details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0) - - return usage.RequestUsage( - input_tokens=metadata.get('prompt_token_count', 0), - output_tokens=metadata.get('candidates_token_count', 0), - details=details, - ) + metadata = cast(GeminiUsageMetaData, metadata.model_dump(exclude_defaults=True)) + return metadata_as_request_usage(metadata) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 5b7c1329cb..df50e5debb 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -34,9 +34,11 @@ VideoUrl, ) from pydantic_ai.models import ModelRequestParameters +from pydantic_ai.models._google_common import _GeminiModalityTokenCount from pydantic_ai.models.gemini import ( GeminiModel, GeminiModelSettings, + GeminiUsageMetaData, _content_model_response, _gemini_response_ta, _gemini_streamed_response_ta, @@ -46,14 +48,12 @@ _GeminiFunctionCall, _GeminiFunctionCallingConfig, _GeminiFunctionCallPart, - _GeminiModalityTokenCount, _GeminiResponse, _GeminiSafetyRating, _GeminiTextPart, _GeminiThoughtPart, _GeminiToolConfig, _GeminiTools, - _GeminiUsageMetaData, _metadata_as_usage, ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput @@ 
-605,8 +605,8 @@ def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | No return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage(), model_version='gemini-1.5-flash-123') -def example_usage() -> _GeminiUsageMetaData: - return _GeminiUsageMetaData(prompt_token_count=1, candidates_token_count=2, total_token_count=3) +def example_usage() -> GeminiUsageMetaData: + return GeminiUsageMetaData(prompt_token_count=1, candidates_token_count=2, total_token_count=3) async def test_text_success(get_gemini_client: GetGeminiClient): From 0b64f961c8e501f6be0ff50ce80a69aeb0778b80 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 15:47:39 +0200 Subject: [PATCH 40/71] update tests --- tests/models/test_openai.py | 20 ++++++++++++++++++++ tests/test_agent.py | 2 ++ tests/test_mcp.py | 4 ++++ 3 files changed, 26 insertions(+) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 3059c6d3e8..38acd40be8 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -824,6 +824,7 @@ async def get_image() -> ImageUrl: cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -861,6 +862,7 @@ async def get_image() -> ImageUrl: cache_read_tokens=0, output_tokens=8, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -904,6 +906,7 @@ async def get_image() -> BinaryContent: cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -939,6 +942,7 @@ async def get_image() -> BinaryContent: cache_read_tokens=0, output_tokens=9, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1824,6 +1828,7 @@ async def test_openai_instructions(allow_model_requests: None, openai_api_key: s cache_read_tokens=0, output_tokens=8, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1870,6 +1875,7 @@ async def get_temperature(city: str) -> float: cache_read_tokens=0, output_tokens=15, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1896,6 +1902,7 @@ async def get_temperature(city: str) -> float: cache_read_tokens=0, output_tokens=15, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2052,6 +2059,7 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api cache_read_tokens=0, output_tokens=2437, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2353,6 +2361,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=12, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2387,6 +2396,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=36, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2446,6 +2456,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2474,6 +2485,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=10, input_audio_tokens=0, + 
output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2526,6 +2538,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=12, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2554,6 +2567,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=15, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2608,6 +2622,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2640,6 +2655,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=25, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2697,6 +2713,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2732,6 +2749,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2793,6 +2811,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2832,6 +2851,7 @@ async def get_user_country() -> str: cache_read_tokens=0, output_tokens=21, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, diff --git a/tests/test_agent.py b/tests/test_agent.py index 68d03ca235..dae13c4755 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2953,6 +2953,7 @@ def test_binary_content_serializable(): 'output_tokens': 4, 'input_audio_tokens': None, 'cache_audio_read_tokens': None, + 'output_audio_tokens': None, 'details': None, }, 'model_name': 'test', @@ -3007,6 +3008,7 @@ def test_image_url_serializable(): 'output_tokens': 4, 'input_audio_tokens': None, 'cache_audio_read_tokens': None, + 'output_audio_tokens': None, 'details': None, }, 'model_name': 'test', diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 4b71b06c31..236df15388 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -486,6 +486,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age cache_read_tokens=0, output_tokens=12, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -514,6 +515,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -893,6 +895,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -921,6 +924,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, From f5b429799b4143f92de5326eafe94d1cbeaa4d75 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 16:14:37 
+0200 Subject: [PATCH 41/71] add pip for pycharm --- pyproject.toml | 1 + uv.lock | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3006f8225d..df300fdf32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ dev = [ "boto3-stubs[bedrock-runtime]", "strict-no-cover @ git+https://github.com/pydantic/strict-no-cover.git@7fc59da2c4dff919db2095a0f0e47101b657131d", "pytest-xdist>=3.6.1", + "pip > 0", ] lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"] docs = [ diff --git a/uv.lock b/uv.lock index 4701d7901a..3d20f04ebb 100644 --- a/uv.lock +++ b/uv.lock @@ -3039,6 +3039,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/37/ae/2dbfc38cc4fd14aceea14bc440d5151b21f64c4c3ba3f6f4191610b7ee5d/pillow-10.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3", size = 2554652, upload-time = "2024-07-01T09:48:38.789Z" }, ] +[[package]] +name = "pip" +version = "25.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/16/650289cd3f43d5a2fadfd98c68bd1e1e7f2550a1a5326768cddfbcedb2c5/pip-25.2.tar.gz", hash = "sha256:578283f006390f85bb6282dffb876454593d637f5d1be494b5202ce4877e71f2", size = 1840021, upload-time = "2025-07-30T21:50:15.401Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/3f/945ef7ab14dc4f9d7f40288d2df998d1837ee0888ec3659c813487572faa/pip-25.2-py3-none-any.whl", hash = "sha256:6d67a2b4e7f14d8b31b8b52648866fa717f45a1eb70e83002f4331d07e953717", size = 1752557, upload-time = "2025-07-30T21:50:13.323Z" }, +] + [[package]] name = "platformdirs" version = "4.3.6" @@ -3331,6 +3340,7 @@ dev = [ { name = "dirty-equals" }, { name = "duckduckgo-search" }, { name = "inline-snapshot" }, + { name = "pip" }, { name = "pytest" }, { name = "pytest-examples" }, { name = "pytest-mock" }, @@ -3379,6 +3389,7 @@ dev = [ { name = "dirty-equals", specifier = ">=0.9.0" }, { name = "duckduckgo-search", specifier = ">=7.0.0" }, { name = "inline-snapshot", specifier = ">=0.19.3" }, + { name = "pip", specifier = ">0" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-examples", specifier = ">=0.0.18" }, { name = "pytest-mock", specifier = ">=3.14.0" }, @@ -3556,8 +3567,8 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, - { name = "tenacity", marker = "extra == 'retries'", specifier = ">=8.2.3" }, { name = "temporalio", marker = "extra == 'temporal'", specifier = ">=1.15.0" }, + { name = "tenacity", marker = "extra == 'retries'", specifier = ">=8.2.3" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "retries", "tavily", "temporal", "vertexai"] From df73d6f0e3cdc802dfbd880d536cb4cc04f81b98 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 16:17:59 +0200 Subject: [PATCH 42/71] durations --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index de5fe6c8fe..3e59383924 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -118,7 +118,7 @@ jobs: --extra anthropic --extra 
mistral --extra cohere - pytest --durations=0 tests/test_live.py -v + pytest tests/test_live.py -v --durations=100 env: PYDANTIC_AI_LIVE_TEST_DANGEROUS: "CHARGE-ME!" @@ -163,7 +163,7 @@ jobs: - run: mkdir .coverage - run: uv sync --only-dev - - run: uv run ${{ matrix.install.command }} coverage run -m pytest -n auto --dist=loadgroup + - run: uv run ${{ matrix.install.command }} coverage run -m pytest --durations=0 -n auto --dist=loadgroup env: COVERAGE_FILE: .coverage/.coverage.${{ matrix.python-version }}-${{ matrix.install.name }} @@ -284,7 +284,7 @@ jobs: - run: make lint-js - - run: uv run --package mcp-run-python pytest --durations=0 mcp-run-python -v --durations=100 + - run: uv run --package mcp-run-python pytest mcp-run-python -v --durations=100 - run: deno task dev warmup working-directory: mcp-run-python From e09e5f7f00eaf5dab06ab3dd401a9f9ebb6f5ab6 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 16:23:34 +0200 Subject: [PATCH 43/71] debugging --- pydantic_ai_slim/pydantic_ai/messages.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 59dfaa7917..62a88ecafd 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -10,7 +10,6 @@ import pydantic import pydantic_core -from genai_prices import calc_price, types as genai_types from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage] from typing_extensions import TypeAlias, deprecated @@ -862,8 +861,12 @@ class ModelResponse: provider_request_id: str | None = None """request ID as specified by the model provider. This can be used to track the specific request to the model.""" - def price(self) -> genai_types.PriceCalculation: + def price(self): """Calculate the price of the usage.""" + from genai_prices import calc_price + + 1 / 0 + assert self.model_name, 'Model name is required to calculate price' return calc_price( self.usage, From 67d1bcfadec9f1f02c555416ada44cd95d30a898 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 16:42:54 +0200 Subject: [PATCH 44/71] Revert "debugging" This reverts commit e09e5f7f00eaf5dab06ab3dd401a9f9ebb6f5ab6. --- pydantic_ai_slim/pydantic_ai/messages.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 62a88ecafd..59dfaa7917 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -10,6 +10,7 @@ import pydantic import pydantic_core +from genai_prices import calc_price, types as genai_types from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage] from typing_extensions import TypeAlias, deprecated @@ -861,12 +862,8 @@ class ModelResponse: provider_request_id: str | None = None """request ID as specified by the model provider. 
This can be used to track the specific request to the model.""" - def price(self): + def price(self) -> genai_types.PriceCalculation: """Calculate the price of the usage.""" - from genai_prices import calc_price - - 1 / 0 - assert self.model_name, 'Model name is required to calculate price' return calc_price( self.usage, From e711281cf47b9b5443f6e8f8062b9a8d71d5151d Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 18:37:59 +0200 Subject: [PATCH 45/71] revert changes moved to other PR --- .../otel/_context_in_memory_span_exporter.py | 2 +- tests/conftest.py | 19 ------------------- tests/evals/test_otel.py | 7 +++++++ 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py b/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py index 05d4d6bff9..cd83f30a99 100644 --- a/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py +++ b/pydantic_evals/pydantic_evals/otel/_context_in_memory_span_exporter.py @@ -109,7 +109,7 @@ def get_finished_spans(self, context_id: str | None = None) -> tuple[ReadableSpa def export(self, spans: typing.Sequence[ReadableSpan]) -> SpanExportResult: """Stores a list of spans in memory.""" - if self._stopped: + if self._stopped: # pragma: no cover return SpanExportResult.FAILURE with self._lock: context_id = _EXPORTER_CONTEXT_ID.get() diff --git a/tests/conftest.py b/tests/conftest.py index 26a6ad039d..f37d35a946 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,6 @@ from vcr import VCR, request as vcr_request import pydantic_ai.models -from pydantic_ai import Agent from pydantic_ai.messages import BinaryContent from pydantic_ai.models import Model, cached_async_http_client @@ -231,24 +230,6 @@ def event_loop() -> Iterator[None]: new_loop.close() -@pytest.fixture(autouse=True) -def no_instrumentation_by_default(): - Agent.instrument_all(False) - - -try: - import logfire - - logfire.DEFAULT_LOGFIRE_INSTANCE.config.ignore_no_config = True - - @pytest.fixture(autouse=True) - def fresh_logfire(): - logfire.shutdown(flush=False) - -except ImportError: - pass - - def raise_if_exception(e: Any) -> None: if isinstance(e, Exception): raise e diff --git a/tests/evals/test_otel.py b/tests/evals/test_otel.py index 08e035dc32..10fec9f7bb 100644 --- a/tests/evals/test_otel.py +++ b/tests/evals/test_otel.py @@ -336,6 +336,8 @@ async def test_span_node_repr(span_tree: SpanTree): async def test_span_tree_ancestors_methods(): """Test the ancestor traversal methods in SpanNode.""" + # Configure logfire + logfire.configure() # Create spans with a deep structure for testing ancestor methods with context_subtree() as tree: @@ -396,6 +398,8 @@ async def test_span_tree_ancestors_methods(): async def test_span_tree_descendants_methods(): """Test the descendant traversal methods in SpanNode.""" + # Configure logfire + logfire.configure() # Create spans with a deep structure for testing descendant methods with context_subtree() as tree: @@ -484,6 +488,8 @@ async def test_span_tree_descendants_methods(): async def test_log_levels_and_exceptions(): """Test recording different log levels and exceptions in spans.""" + # Configure logfire + logfire.configure() with context_subtree() as tree: # Test different log levels @@ -885,6 +891,7 @@ async def test_context_subtree_not_configured(mocker: MockerFixture): """Test that context_subtree correctly records spans in independent async contexts.""" from opentelemetry.trace import ProxyTracerProvider + # from 
opentelemetry.sdk.trace import TracerProvider mocker.patch( 'pydantic_evals.otel._context_in_memory_span_exporter.get_tracer_provider', return_value=ProxyTracerProvider() ) From d8814221ec12dc93d3df16af199b77359417a77b Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 18:41:55 +0200 Subject: [PATCH 46/71] fix --- tests/test_mcp.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 236df15388..30e8e6b076 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -206,6 +206,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) cache_read_tokens=0, output_tokens=19, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -234,6 +235,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) cache_read_tokens=0, output_tokens=13, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -340,6 +342,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=18, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -372,6 +375,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=19, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -415,6 +419,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A cache_read_tokens=0, output_tokens=12, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -443,6 +448,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A cache_read_tokens=0, output_tokens=12, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -561,6 +567,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: cache_read_tokens=0, output_tokens=12, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -594,6 +601,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: cache_read_tokens=0, output_tokens=19, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -641,6 +649,7 @@ async def test_tool_returning_image_resource_link( cache_read_tokens=0, output_tokens=12, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -674,6 +683,7 @@ async def test_tool_returning_image_resource_link( cache_read_tokens=0, output_tokens=29, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -727,6 +737,7 @@ async def test_tool_returning_audio_resource( usage=RequestUsage( input_tokens=575, output_tokens=15, + input_audio_tokens=144, details={'text_prompt_tokens': 431, 'audio_prompt_tokens': 144}, ), model_name='models/gemini-2.5-pro-preview-05-06', @@ -785,6 +796,7 @@ async def test_tool_returning_audio_resource_link( usage=RequestUsage( input_tokens=784, output_tokens=5, + input_audio_tokens=144, details={'text_prompt_tokens': 640, 'audio_prompt_tokens': 144}, ), model_name='models/gemini-2.5-pro', @@ -823,6 +835,7 @@ async def 
test_tool_returning_image(allow_model_requests: None, agent: Agent, im cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -858,6 +871,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im cache_read_tokens=0, output_tokens=15, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -970,6 +984,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=15, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1004,6 +1019,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=15, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1036,6 +1052,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=22, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1073,6 +1090,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1101,6 +1119,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): cache_read_tokens=0, output_tokens=11, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1146,6 +1165,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: cache_read_tokens=0, output_tokens=12, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1190,6 +1210,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: cache_read_tokens=0, output_tokens=24, input_audio_tokens=0, + output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, From 175f952d35917b6fae41d40d579c55215e1c93a5 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 18:54:54 +0200 Subject: [PATCH 47/71] pragma --- pydantic_ai_slim/pydantic_ai/models/_google_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/_google_common.py b/pydantic_ai_slim/pydantic_ai/models/_google_common.py index ec0bd5f4b3..c224da5212 100644 --- a/pydantic_ai_slim/pydantic_ai/models/_google_common.py +++ b/pydantic_ai_slim/pydantic_ai/models/_google_common.py @@ -71,7 +71,7 @@ def metadata_as_request_usage(metadata: GeminiUsageMetaData | None) -> usage.Req input_audio_tokens = value elif key == 'candidates_tokens_details': output_audio_tokens = value - elif key == 'cache_tokens_details': + elif key == 'cache_tokens_details': # pragma: no branch cache_audio_read_tokens = value return usage.RequestUsage( From 645d97d62ab5de7c0bf5c97bb0f6d1458855ad7e Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:04:03 +0200 Subject: [PATCH 48/71] Restore total_tokens, default everything to 0 --- .../pydantic_ai/models/_google_common.py | 12 ++-- .../pydantic_ai/models/anthropic.py | 8 +-- pydantic_ai_slim/pydantic_ai/models/cohere.py | 4 +- .../pydantic_ai/models/huggingface.py | 1 - .../pydantic_ai/models/mistral.py | 1 - 
pydantic_ai_slim/pydantic_ai/models/openai.py | 6 +- pydantic_ai_slim/pydantic_ai/usage.py | 59 ++++++++----------- tests/models/test_mistral.py | 2 +- tests/test_agent.py | 6 +- tests/test_live.py | 6 +- tests/test_streaming.py | 2 +- tests/test_usage_limits.py | 2 +- 12 files changed, 50 insertions(+), 59 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/_google_common.py b/pydantic_ai_slim/pydantic_ai/models/_google_common.py index c224da5212..cbe5546648 100644 --- a/pydantic_ai_slim/pydantic_ai/models/_google_common.py +++ b/pydantic_ai_slim/pydantic_ai/models/_google_common.py @@ -47,18 +47,18 @@ def metadata_as_request_usage(metadata: GeminiUsageMetaData | None) -> usage.Req if metadata is None: return usage.RequestUsage() # pragma: no cover details: dict[str, int] = {} - if cached_content_token_count := metadata.get('cached_content_token_count'): + if cached_content_token_count := metadata.get('cached_content_token_count', 0): details['cached_content_tokens'] = cached_content_token_count # pragma: no cover - if thoughts_token_count := metadata.get('thoughts_token_count'): + if thoughts_token_count := metadata.get('thoughts_token_count', 0): details['thoughts_tokens'] = thoughts_token_count - if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'): + if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0): details['tool_use_prompt_tokens'] = tool_use_prompt_token_count # pragma: no cover - input_audio_tokens = None - output_audio_tokens = None - cache_audio_read_tokens = None + input_audio_tokens = 0 + output_audio_tokens = 0 + cache_audio_read_tokens = 0 for key, metadata_details in metadata.items(): if key.endswith('_details') and metadata_details: metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index a8d39c7b63..60e06aefad 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -558,11 +558,11 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens return usage.RequestUsage( - input_tokens=request_tokens or None, - cache_read_tokens=cache_read_tokens or None, - cache_write_tokens=cache_write_tokens or None, + input_tokens=request_tokens, + cache_read_tokens=cache_read_tokens, + cache_write_tokens=cache_write_tokens, output_tokens=response_usage.output_tokens, - details=details or None, + details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index ba4355b92f..c9ed3a8566 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -316,8 +316,8 @@ def _map_usage(response: V2ChatResponse) -> usage.RequestUsage: if u.billed_units.classifications: # pragma: no cover details['classifications'] = int(u.billed_units.classifications) - request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else None - response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else None + request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else 0 + response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else 0 return usage.RequestUsage( input_tokens=request_tokens, output_tokens=response_tokens, diff 
--git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 104c345afd..a1e9b31a14 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -481,5 +481,4 @@ def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> u return usage.RequestUsage( input_tokens=response_usage.prompt_tokens, output_tokens=response_usage.completion_tokens, - details=None, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 8063918596..7b5558af11 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -708,7 +708,6 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) return RequestUsage( input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens, - details=None, ) else: return RequestUsage() # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index b54d34ccd4..490d6150b0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1272,8 +1272,8 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R ) if response_usage.completion_tokens_details is not None: details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True)) - u.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens + u.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens or 0 if response_usage.prompt_tokens_details is not None: - u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens - u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens + u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0 + u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0 return u diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 996adbe314..160bd73ecc 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -13,9 +13,9 @@ class UsageBase: - input_tokens: int | None = None - output_tokens: int | None = None - details: dict[str, int] | None = None + input_tokens: int + output_tokens: int + details: dict[str, int] def opentelemetry_attributes(self) -> dict[str, int]: """Get the token usage values as OpenTelemetry attributes.""" @@ -50,25 +50,25 @@ class RequestUsage(UsageBase): Prices for LLM requests are calculated using [genai-prices](https://github.com/pydantic/genai-prices). 
""" - input_tokens: int | None = None - """Number of text input/prompt tokens.""" + input_tokens: int = 0 + """Number of input/prompt tokens.""" - cache_write_tokens: int | None = None + cache_write_tokens: int = 0 """Number of tokens written to the cache.""" - cache_read_tokens: int | None = None + cache_read_tokens: int = 0 """Number of tokens read from the cache.""" - output_tokens: int | None = None - """Number of text output/completion tokens.""" + output_tokens: int = 0 + """Number of output/completion tokens.""" - input_audio_tokens: int | None = None + input_audio_tokens: int = 0 """Number of audio input tokens.""" - cache_audio_read_tokens: int | None = None + cache_audio_read_tokens: int = 0 """Number of audio tokens read from the cache.""" - output_audio_tokens: int | None = None + output_audio_tokens: int = 0 """Number of audio output tokens.""" - details: dict[str, int] | None = None + details: dict[str, int] = dataclasses.field(default_factory=dict) """Any extra details returned by the model.""" @property @@ -120,46 +120,39 @@ class RunUsage(UsageBase): requests: int = 0 """Number of requests made to the LLM API.""" - input_tokens: int | None = None + input_tokens: int = 0 """Total number of text input/prompt tokens.""" - cache_write_tokens: int | None = None + cache_write_tokens: int = 0 """Total number of tokens written to the cache.""" - cache_read_tokens: int | None = None + cache_read_tokens: int = 0 """Total number of tokens read from the cache.""" - input_audio_tokens: int | None = None + input_audio_tokens: int = 0 """Total number of audio input tokens.""" - cache_audio_read_tokens: int | None = None + cache_audio_read_tokens: int = 0 """Total number of audio tokens read from the cache.""" - output_tokens: int | None = None + output_tokens: int = 0 """Total number of text output/completion tokens.""" - details: dict[str, int] | None = None + details: dict[str, int] = dataclasses.field(default_factory=dict) """Any extra details returned by the model.""" - def input_output_tokens(self) -> int | None: - """Sum of `input_tokens + output_tokens`.""" - if self.input_tokens is None and self.output_tokens is None: - return None - else: - return (self.input_tokens or 0) + (self.output_tokens or 0) - @property @deprecated('`request_tokens` is deprecated, use `input_tokens` instead') - def request_tokens(self) -> int | None: + def request_tokens(self) -> int: return self.input_tokens @property @deprecated('`response_tokens` is deprecated, use `output_tokens` instead') - def response_tokens(self) -> int | None: + def response_tokens(self) -> int: return self.output_tokens @property - @deprecated('`total_tokens` is deprecated, sum the specific fields you need or use `input_output_tokens` instead') - def total_tokens(self) -> int | None: - return sum(v for k, v in dataclasses.asdict(self).values() if k.endswith('_tokens') and v is not None) + def total_tokens(self) -> int: + """Sum of `input_tokens + output_tokens`.""" + return self.input_tokens + self.output_tokens def incr(self, incr_usage: RunUsage | RequestUsage) -> None: """Increment the usage in place. 
@@ -309,7 +302,7 @@ def check_tokens(self, usage: RunUsage) -> None: f'Exceeded the output_tokens_limit of {self.output_tokens_limit} ({output_tokens=})' ) - total_tokens = usage.input_output_tokens() or 0 + total_tokens = usage.total_tokens if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})') diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 512a60f1a7..fd1c9fd744 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1063,7 +1063,7 @@ async def get_location(loc_name: str) -> str: assert result.output == 'final response' assert result.usage().input_tokens == 6 assert result.usage().output_tokens == 4 - assert result.usage().input_output_tokens() == 10 + assert result.usage().total_tokens == 10 assert result.all_messages() == snapshot( [ ModelRequest( diff --git a/tests/test_agent.py b/tests/test_agent.py index dae13c4755..0b3bbfb54d 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1770,7 +1770,7 @@ async def ret_a(x: str) -> str: assert result2._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] assert result2.output == snapshot('{"ret_a":"a-apple"}') assert result2._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] - assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13, details=None)) + assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13)) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ @@ -1827,7 +1827,7 @@ async def ret_a(x: str) -> str: assert result3._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] assert result3.output == snapshot('{"ret_a":"a-apple"}') assert result3._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] - assert result3.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13, details=None)) + assert result3.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13)) def test_run_with_history_new_structured(): @@ -1954,7 +1954,7 @@ async def ret_a(x: str) -> str: assert result2.output == snapshot(Response(a=0)) assert result2._new_message_index == snapshot(5) # pyright: ignore[reportPrivateUsage] assert result2._output_tool_name == snapshot('final_result') # pyright: ignore[reportPrivateUsage] - assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=59, output_tokens=13, details=None)) + assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=59, output_tokens=13)) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ diff --git a/tests/test_live.py b/tests/test_live.py index b1bb49808d..6c3eba8072 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -118,7 +118,7 @@ async def test_text(http_client: httpx.AsyncClient, tmp_path: Path, get_model: G assert 'paris' in result.output.lower() print('Text usage:', result.usage()) usage = result.usage() - total_tokens = usage.input_output_tokens() + total_tokens = usage.total_tokens assert total_tokens is not None and total_tokens > 0 @@ -135,7 +135,7 @@ async def test_stream(http_client: httpx.AsyncClient, tmp_path: Path, get_model: print('Stream usage:', result.usage()) usage = result.usage() if get_model.__name__ != 
'ollama': - total_tokens = usage.input_output_tokens() + total_tokens = usage.total_tokens assert total_tokens is not None and total_tokens > 0 @@ -164,5 +164,5 @@ async def test_structured(http_client: httpx.AsyncClient, tmp_path: Path, get_mo assert result.output.city.lower() == 'london' print('Structured usage:', result.usage()) usage = result.usage() - total_tokens = usage.input_output_tokens() + total_tokens = usage.total_tokens assert total_tokens is not None and total_tokens > 0 diff --git a/tests/test_streaming.py b/tests/test_streaming.py index f0cb178bb8..624dc8399e 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -860,7 +860,7 @@ def output_validator_simple(data: str) -> str: messages.append(chunk) stream_usage = deepcopy(stream.usage()) assert run.next_node == End(data=FinalResult(output='The bat sat on the mat.', tool_name=None, tool_call_id=None)) - assert run.usage() == stream_usage == RunUsage(requests=1, input_tokens=51, output_tokens=7, details=None) + assert run.usage() == stream_usage == RunUsage(requests=1, input_tokens=51, output_tokens=7) assert messages == [ '', diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 97234769a4..a8feb3af37 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -201,7 +201,7 @@ def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int: def test_request_usage_basics(): usage = RequestUsage() - assert usage.output_audio_tokens is None + assert usage.output_audio_tokens == 0 assert usage.requests == 1 From aaf69505d9a73f569336d680ba324eb0d16905fa Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:06:07 +0200 Subject: [PATCH 49/71] Remove 0s from snapshots --- tests/models/test_anthropic.py | 4 +- tests/models/test_cohere.py | 8 +-- tests/models/test_deepseek.py | 1 - tests/models/test_gemini.py | 8 +-- tests/models/test_mistral.py | 8 +-- tests/models/test_openai.py | 80 ++------------------------- tests/models/test_openai_responses.py | 72 ++++++------------------ tests/test_mcp.py | 69 ----------------------- 8 files changed, 30 insertions(+), 220 deletions(-) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 3a4e1c1608..c970ca126b 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1496,9 +1496,7 @@ async def test_anthropic_server_tool_pass_history_to_another_provider( ModelRequest(parts=[UserPromptPart(content='What day is tomorrow?', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='Tomorrow will be **Tuesday, May 27, 2025**.')], - usage=RequestUsage( - input_tokens=410, cache_read_tokens=0, output_tokens=17, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=410, output_tokens=17, details={'reasoning_tokens': 0}), model_name='gpt-4.1-2025-04-14', timestamp=IsDatetime(), provider_request_id='resp_6834631faf2481918638284f62855ddf040b4e5d7e74f261', diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 3437a3d0f0..6dfed3b9a1 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -433,9 +433,7 @@ async def test_cohere_model_thinking_part(allow_model_requests: None, co_api_key IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=1909, details={'reasoning_tokens': 1472} - ), + usage=RequestUsage(input_tokens=13, output_tokens=1909, details={'reasoning_tokens': 1472}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), 
provider_request_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', @@ -459,9 +457,7 @@ async def test_cohere_model_thinking_part(allow_model_requests: None, co_api_key IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=1909, details={'reasoning_tokens': 1472} - ), + usage=RequestUsage(input_tokens=13, output_tokens=1909, details={'reasoning_tokens': 1472}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', diff --git a/tests/models/test_deepseek.py b/tests/models/test_deepseek.py index 90dee7623f..b88b588a5e 100644 --- a/tests/models/test_deepseek.py +++ b/tests/models/test_deepseek.py @@ -46,7 +46,6 @@ async def test_deepseek_model_thinking_part(allow_model_requests: None, deepseek parts=[ThinkingPart(content=IsStr()), TextPart(content=IsStr())], usage=RequestUsage( input_tokens=12, - cache_read_tokens=0, output_tokens=789, details={ 'prompt_cache_hit_tokens': 0, diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index df50e5debb..cc097da4a9 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1484,9 +1484,7 @@ async def test_gemini_model_thinking_part(allow_model_requests: None, gemini_api """ ), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=2028, details={'reasoning_tokens': 1664} - ), + usage=RequestUsage(input_tokens=13, output_tokens=2028, details={'reasoning_tokens': 1664}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', @@ -1512,9 +1510,7 @@ async def test_gemini_model_thinking_part(allow_model_requests: None, gemini_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=2028, details={'reasoning_tokens': 1664} - ), + usage=RequestUsage(input_tokens=13, output_tokens=2028, details={'reasoning_tokens': 1664}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index fd1c9fd744..582086f768 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -2052,9 +2052,7 @@ async def test_mistral_model_thinking_part(allow_model_requests: None, openai_ap ThinkingPart(content=IsStr(), id='rs_68079ad7f0588191af64f067e7314d840493b22e4095129c'), TextPart(content=IsStr()), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=1789, details={'reasoning_tokens': 1344} - ), + usage=RequestUsage(input_tokens=13, output_tokens=1789, details={'reasoning_tokens': 1344}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', @@ -2112,9 +2110,7 @@ async def test_mistral_model_thinking_part(allow_model_requests: None, openai_ap """ ), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=1789, details={'reasoning_tokens': 1344} - ), + usage=RequestUsage(input_tokens=13, output_tokens=1789, details={'reasoning_tokens': 1344}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 38acd40be8..273b53700e 100644 --- a/tests/models/test_openai.py +++ 
b/tests/models/test_openai.py @@ -821,10 +821,7 @@ async def get_image() -> ImageUrl: parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='call_4hrT4QP9jfojtK69vGiFCFjG')], usage=RequestUsage( input_tokens=46, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -859,10 +856,7 @@ async def get_image() -> ImageUrl: parts=[TextPart(content='The image shows a potato.')], usage=RequestUsage( input_tokens=503, - cache_read_tokens=0, output_tokens=8, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -903,10 +897,7 @@ async def get_image() -> BinaryContent: parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='call_Btn0GIzGr4ugNlLmkQghQUMY')], usage=RequestUsage( input_tokens=46, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -939,10 +930,7 @@ async def get_image() -> BinaryContent: parts=[TextPart(content='The image shows a kiwi fruit.')], usage=RequestUsage( input_tokens=1185, - cache_read_tokens=0, output_tokens=9, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1825,10 +1813,7 @@ async def test_openai_instructions(allow_model_requests: None, openai_api_key: s parts=[TextPart(content='The capital of France is Paris.')], usage=RequestUsage( input_tokens=24, - cache_read_tokens=0, output_tokens=8, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1872,10 +1857,7 @@ async def get_temperature(city: str) -> float: parts=[ToolCallPart(tool_name='get_temperature', args='{"city":"Tokyo"}', tool_call_id=IsStr())], usage=RequestUsage( input_tokens=50, - cache_read_tokens=0, output_tokens=15, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1899,10 +1881,7 @@ async def get_temperature(city: str) -> float: parts=[TextPart(content='The temperature in Tokyo is currently 20.0 degrees Celsius.')], usage=RequestUsage( input_tokens=75, - cache_read_tokens=0, output_tokens=15, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1935,9 +1914,7 @@ async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034841ab2881918a8c210e3d988b9208c845d2be9bcdd8'), IsInstance(TextPart), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=2050, details={'reasoning_tokens': 1664} - ), + usage=RequestUsage(input_tokens=13, output_tokens=2050, details={'reasoning_tokens': 1664}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', @@ -1960,9 +1937,7 @@ async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034841ab2881918a8c210e3d988b9208c845d2be9bcdd8'), IsInstance(TextPart), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=2050, details={'reasoning_tokens': 1664} - ), + usage=RequestUsage(input_tokens=13, output_tokens=2050, details={'reasoning_tokens': 1664}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', @@ -1982,9 +1957,7 @@ 
async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034858dc588191bc3a6801c23e728f08c845d2be9bcdd8'), IsInstance(TextPart), ], - usage=RequestUsage( - input_tokens=424, cache_read_tokens=0, output_tokens=2033, details={'reasoning_tokens': 1408} - ), + usage=RequestUsage(input_tokens=424, output_tokens=2033, details={'reasoning_tokens': 1408}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_6803484f19a88191b9ea975d7cfbbe8408c845d2be9bcdd8', @@ -2011,9 +1984,7 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=1900, details={'reasoning_tokens': 1536} - ), + usage=RequestUsage(input_tokens=13, output_tokens=1900, details={'reasoning_tokens': 1536}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', @@ -2037,9 +2008,7 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=RequestUsage( - input_tokens=13, cache_read_tokens=0, output_tokens=1900, details={'reasoning_tokens': 1536} - ), + usage=RequestUsage(input_tokens=13, output_tokens=1900, details={'reasoning_tokens': 1536}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), provider_request_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', @@ -2056,10 +2025,7 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api parts=[TextPart(content=IsStr())], usage=RequestUsage( input_tokens=822, - cache_read_tokens=0, output_tokens=2437, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2358,10 +2324,7 @@ async def get_user_country() -> str: parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], usage=RequestUsage( input_tokens=68, - cache_read_tokens=0, output_tokens=12, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2393,10 +2356,7 @@ async def get_user_country() -> str: ], usage=RequestUsage( input_tokens=89, - cache_read_tokens=0, output_tokens=36, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2453,10 +2413,7 @@ async def get_user_country() -> str: ], usage=RequestUsage( input_tokens=42, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2482,10 +2439,7 @@ async def get_user_country() -> str: parts=[TextPart(content='The largest city in Mexico is Mexico City.')], usage=RequestUsage( input_tokens=63, - cache_read_tokens=0, output_tokens=10, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2535,10 +2489,7 @@ async def get_user_country() -> str: ], usage=RequestUsage( input_tokens=71, - cache_read_tokens=0, output_tokens=12, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2564,10 +2515,7 @@ async def get_user_country() -> str: parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], usage=RequestUsage( input_tokens=92, - cache_read_tokens=0, output_tokens=15, - input_audio_tokens=0, - 
output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2619,10 +2567,7 @@ async def get_user_country() -> str: ], usage=RequestUsage( input_tokens=160, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2652,10 +2597,7 @@ async def get_user_country() -> str: ], usage=RequestUsage( input_tokens=181, - cache_read_tokens=0, output_tokens=25, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2710,10 +2652,7 @@ async def get_user_country() -> str: ], usage=RequestUsage( input_tokens=109, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2746,10 +2685,7 @@ async def get_user_country() -> str: parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], usage=RequestUsage( input_tokens=130, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2808,10 +2744,7 @@ async def get_user_country() -> str: ], usage=RequestUsage( input_tokens=273, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2848,10 +2781,7 @@ async def get_user_country() -> str: ], usage=RequestUsage( input_tokens=294, - cache_read_tokens=0, output_tokens=21, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index ca7be33733..afec56ff3d 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -197,9 +197,7 @@ async def get_location(loc_name: str) -> str: tool_call_id=IsStr(), ), ], - usage=RequestUsage( - input_tokens=0, cache_read_tokens=0, output_tokens=0, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=0, output_tokens=0, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_67e547c48c9481918c5c4394464ce0c60ae6111e84dd5c08', @@ -230,9 +228,7 @@ async def get_location(loc_name: str) -> str: """ ) ], - usage=RequestUsage( - input_tokens=335, cache_read_tokens=0, output_tokens=44, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=335, output_tokens=44, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_67e547c5a2f08191802a1f43620f348503a2086afed73b47', @@ -265,9 +261,7 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id=IsStr())], - usage=RequestUsage( - input_tokens=40, cache_read_tokens=0, output_tokens=11, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=40, output_tokens=11, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_681134d3aa3481919ca581a267db1e510fe7a5a4e2123dc3', @@ -291,9 +285,7 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The fruit in the image is a kiwi.')], - usage=RequestUsage( - input_tokens=1185, cache_read_tokens=0, output_tokens=11, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=1185, output_tokens=11, 
details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_681134d53c48819198ce7b89db78dffd02cbfeaababb040c', @@ -424,9 +416,7 @@ async def test_openai_responses_model_builtin_tools(allow_model_requests: None, """ ) ], - usage=RequestUsage( - input_tokens=320, cache_read_tokens=0, output_tokens=159, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=320, output_tokens=159, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_689b7c90010c8196ac0efd68b021490f07450cfc2d48b975', @@ -449,9 +439,7 @@ async def test_openai_responses_model_instructions(allow_model_requests: None, o ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=RequestUsage( - input_tokens=24, cache_read_tokens=0, output_tokens=8, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=24, output_tokens=8, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_67f3fdfd9fa08191a3d5825db81b8df6003bc73febb56d77', @@ -694,9 +682,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=RequestUsage( - input_tokens=62, cache_read_tokens=0, output_tokens=12, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=62, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68477f0b40a8819cb8d55594bc2c232a001fd29e2d5573f7', @@ -719,9 +705,7 @@ async def get_user_country() -> str: tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', ) ], - usage=RequestUsage( - input_tokens=85, cache_read_tokens=0, output_tokens=20, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=85, output_tokens=20, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68477f0bfda8819ea65458cd7cc389b801dc81d4bc91f560', @@ -770,9 +754,7 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr') ], - usage=RequestUsage( - input_tokens=36, cache_read_tokens=0, output_tokens=12, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=36, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68477f0d9494819ea4f123bba707c9ee0356a60c98816d6a', @@ -789,9 +771,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='The largest city in Mexico is Mexico City.')], - usage=RequestUsage( - input_tokens=59, cache_read_tokens=0, output_tokens=11, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=59, output_tokens=11, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68477f0e2b28819d9c828ef4ee526d6a03434b607c02582d', @@ -831,9 +811,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=RequestUsage( - input_tokens=66, cache_read_tokens=0, output_tokens=12, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=66, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), 
provider_request_id='resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329', @@ -850,9 +828,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=RequestUsage( - input_tokens=89, cache_read_tokens=0, output_tokens=16, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=89, output_tokens=16, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68477f0fde708192989000a62809c6e5020197534e39cc1f', @@ -894,9 +870,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=RequestUsage( - input_tokens=153, cache_read_tokens=0, output_tokens=12, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=153, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563', @@ -917,9 +891,7 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=RequestUsage( - input_tokens=176, cache_read_tokens=0, output_tokens=26, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=176, output_tokens=26, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68477f119830819da162aa6e10552035061ad97e2eef7871', @@ -964,9 +936,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=RequestUsage( - input_tokens=107, cache_read_tokens=0, output_tokens=12, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=107, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50', @@ -990,9 +960,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=RequestUsage( - input_tokens=130, cache_read_tokens=0, output_tokens=12, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=130, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115', @@ -1041,9 +1009,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=RequestUsage( - input_tokens=283, cache_read_tokens=0, output_tokens=12, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=283, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee', @@ -1071,9 +1037,7 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=RequestUsage( - input_tokens=306, cache_read_tokens=0, output_tokens=22, details={'reasoning_tokens': 0} - ), + usage=RequestUsage(input_tokens=306, output_tokens=22, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), 
provider_request_id='resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024', diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 30e8e6b076..d4146740b3 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -203,10 +203,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) ], usage=RequestUsage( input_tokens=195, - cache_read_tokens=0, output_tokens=19, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -232,10 +229,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) parts=[TextPart(content='0 degrees Celsius is equal to 32 degrees Fahrenheit.')], usage=RequestUsage( input_tokens=227, - cache_read_tokens=0, output_tokens=13, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -339,10 +333,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): ], usage=RequestUsage( input_tokens=194, - cache_read_tokens=0, output_tokens=18, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -372,10 +363,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): ], usage=RequestUsage( input_tokens=234, - cache_read_tokens=0, output_tokens=19, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -416,10 +404,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A ], usage=RequestUsage( input_tokens=200, - cache_read_tokens=0, output_tokens=12, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -445,10 +430,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A parts=[TextPart(content='The product name is "Pydantic AI".')], usage=RequestUsage( input_tokens=224, - cache_read_tokens=0, output_tokens=12, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -489,10 +471,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age ], usage=RequestUsage( input_tokens=305, - cache_read_tokens=0, output_tokens=12, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -518,10 +497,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age parts=[TextPart(content='The product name is "Pydantic AI".')], usage=RequestUsage( input_tokens=332, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -564,10 +540,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: ], usage=RequestUsage( input_tokens=191, - cache_read_tokens=0, output_tokens=12, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -598,10 +571,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: ], usage=RequestUsage( input_tokens=1332, - cache_read_tokens=0, output_tokens=19, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -646,10 +616,7 @@ async def test_tool_returning_image_resource_link( ], usage=RequestUsage( input_tokens=305, - cache_read_tokens=0, output_tokens=12, - input_audio_tokens=0, - 
output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -680,10 +647,7 @@ async def test_tool_returning_image_resource_link( ], usage=RequestUsage( input_tokens=1452, - cache_read_tokens=0, output_tokens=29, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -832,10 +796,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im ], usage=RequestUsage( input_tokens=190, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -868,10 +829,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im parts=[TextPart(content='Here is an image of a sliced kiwi on a white background.')], usage=RequestUsage( input_tokens=1329, - cache_read_tokens=0, output_tokens=15, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -906,10 +864,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): parts=[ToolCallPart(tool_name='get_dict', args='{}', tool_call_id='call_oqKviITBj8PwpQjGyUu4Zu5x')], usage=RequestUsage( input_tokens=195, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -935,10 +890,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): parts=[TextPart(content='{"foo":"bar","baz":123}')], usage=RequestUsage( input_tokens=222, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -981,10 +933,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): ], usage=RequestUsage( input_tokens=203, - cache_read_tokens=0, output_tokens=15, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1016,10 +965,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): ], usage=RequestUsage( input_tokens=250, - cache_read_tokens=0, output_tokens=15, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1049,10 +995,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): ], usage=RequestUsage( input_tokens=277, - cache_read_tokens=0, output_tokens=22, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1087,10 +1030,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): parts=[ToolCallPart(tool_name='get_none', args='{}', tool_call_id='call_mJTuQ2Cl5SaHPTJbIILEUhJC')], usage=RequestUsage( input_tokens=193, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1116,10 +1056,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): parts=[TextPart(content='Hello! 
How can I assist you today?')], usage=RequestUsage( input_tokens=212, - cache_read_tokens=0, output_tokens=11, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1162,10 +1099,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ], usage=RequestUsage( input_tokens=195, - cache_read_tokens=0, output_tokens=12, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -1207,10 +1141,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ], usage=RequestUsage( input_tokens=1355, - cache_read_tokens=0, output_tokens=24, - input_audio_tokens=0, - output_audio_tokens=0, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, From 6acce3fbbfa3245c8b92f889b95d60e3d59a9054 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:08:49 +0200 Subject: [PATCH 50/71] snapshots --- tests/models/test_mistral.py | 4 ++-- tests/test_agent.py | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 582086f768..c75fe176e0 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -447,7 +447,7 @@ class CityLocation(BaseModel): assert result.output == CityLocation(city='paris', country='france') assert result.usage().input_tokens == 1 assert result.usage().output_tokens == 1 - assert result.usage().details is None + assert result.usage().details == {} assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))]), @@ -501,7 +501,7 @@ async def test_request_output_type_with_arguments_str_response(allow_model_reque assert result.output == 42 assert result.usage().input_tokens == 1 assert result.usage().output_tokens == 1 - assert result.usage().details is None + assert result.usage().details == {} assert result.all_messages() == snapshot( [ ModelRequest( diff --git a/tests/test_agent.py b/tests/test_agent.py index 0b3bbfb54d..d290def035 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2948,13 +2948,13 @@ def test_binary_content_serializable(): 'parts': [{'content': 'success (no tool calls)', 'part_kind': 'text'}], 'usage': { 'input_tokens': 56, - 'cache_write_tokens': None, - 'cache_read_tokens': None, + 'cache_write_tokens': 0, + 'cache_read_tokens': 0, 'output_tokens': 4, - 'input_audio_tokens': None, - 'cache_audio_read_tokens': None, - 'output_audio_tokens': None, - 'details': None, + 'input_audio_tokens': 0, + 'cache_audio_read_tokens': 0, + 'output_audio_tokens': 0, + 'details': {}, }, 'model_name': 'test', 'provider_name': None, @@ -3003,13 +3003,13 @@ def test_image_url_serializable(): 'parts': [{'content': 'success (no tool calls)', 'part_kind': 'text'}], 'usage': { 'input_tokens': 51, - 'cache_write_tokens': None, - 'cache_read_tokens': None, + 'cache_write_tokens': 0, + 'cache_read_tokens': 0, 'output_tokens': 4, - 'input_audio_tokens': None, - 'cache_audio_read_tokens': None, - 'output_audio_tokens': None, - 'details': None, + 'input_audio_tokens': 0, + 'cache_audio_read_tokens': 0, + 'output_audio_tokens': 0, + 'details': {}, }, 'model_name': 'test', 'timestamp': IsStr(), From 382540806546e2013b499e2b337831a0f3ae6105 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:18:38 +0200 Subject: [PATCH 51/71] cleanup --- docs/agents.md | 2 +- docs/models/openai.md | 4 +- 
docs/multi-agent-applications.md | 4 +- docs/output.md | 2 +- pydantic_ai_slim/pydantic_ai/usage.py | 98 +++++++++++---------------- 5 files changed, 47 insertions(+), 63 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 03a0f91f97..2185faf7c2 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -563,7 +563,7 @@ result_sync = agent.run_sync( print(result_sync.output) #> Rome print(result_sync.usage()) -#> RunUsage(requests=1, input_tokens=62, output_tokens=1) +#> RunUsage(input_tokens=62, output_tokens=1, requests=1) try: result_sync = agent.run_sync( diff --git a/docs/models/openai.md b/docs/models/openai.md index af12c2e256..81f07aebd9 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -272,7 +272,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> RunUsage(requests=1, input_tokens=57, output_tokens=8) +#> RunUsage(input_tokens=57, output_tokens=8, requests=1) ``` #### Example using a remote server @@ -301,7 +301,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> RunUsage(requests=1, input_tokens=57, output_tokens=8) +#> RunUsage(input_tokens=57, output_tokens=8, requests=1) ``` 1. The name of the model running on the remote server diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 62a66164ce..9fda4c1d2b 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -53,7 +53,7 @@ result = joke_selection_agent.run_sync( print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) -#> RunUsage(requests=3, input_tokens=204, output_tokens=24) +#> RunUsage(input_tokens=204, output_tokens=24, requests=3) ``` 1. The "parent" or controlling agent. @@ -144,7 +144,7 @@ async def main(): print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) # (6)! - #> RunUsage(requests=4, input_tokens=309, output_tokens=32) + #> RunUsage(input_tokens=309, output_tokens=32, requests=4) ``` 1. Define a dataclass to hold the client and API key dependencies. 
diff --git a/docs/output.md b/docs/output.md index 28a46019ae..d0ba4ff06a 100644 --- a/docs/output.md +++ b/docs/output.md @@ -24,7 +24,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> RunUsage(requests=1, input_tokens=57, output_tokens=8) +#> RunUsage(input_tokens=57, output_tokens=8, requests=1) ``` _(This example is complete, it can be run "as is")_ diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 160bd73ecc..62cdbadac5 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -2,7 +2,7 @@ import dataclasses from copy import copy -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing_extensions import deprecated, overload @@ -12,10 +12,43 @@ __all__ = 'RequestUsage', 'RunUsage', 'UsageLimits' +@dataclass(repr=False) class UsageBase: - input_tokens: int - output_tokens: int - details: dict[str, int] + input_tokens: int = 0 + """Number of input/prompt tokens.""" + + cache_write_tokens: int = 0 + """Number of tokens written to the cache.""" + cache_read_tokens: int = 0 + """Number of tokens read from the cache.""" + + output_tokens: int = 0 + """Number of output/completion tokens.""" + + input_audio_tokens: int = 0 + """Number of audio input tokens.""" + cache_audio_read_tokens: int = 0 + """Number of audio tokens read from the cache.""" + output_audio_tokens: int = 0 + """Number of audio output tokens.""" + + details: dict[str, int] = dataclasses.field(default_factory=dict) + """Any extra details returned by the model.""" + + @property + @deprecated('`request_tokens` is deprecated, use `input_tokens` instead') + def request_tokens(self) -> int: + return self.input_tokens + + @property + @deprecated('`response_tokens` is deprecated, use `output_tokens` instead') + def response_tokens(self) -> int: + return self.output_tokens + + @property + def total_tokens(self) -> int: + """Sum of `input_tokens + output_tokens`.""" + return self.input_tokens + self.output_tokens def opentelemetry_attributes(self) -> dict[str, int]: """Get the token usage values as OpenTelemetry attributes.""" @@ -33,11 +66,13 @@ def opentelemetry_attributes(self) -> dict[str, int]: result[prefix + key] = value return result - __repr__ = _utils.dataclasses_no_defaults_repr + def __repr__(self): + kv_pairs = (f'{f.name}={value!r}' for f in fields(self) if (value := getattr(self, f.name))) + return f'{self.__class__.__qualname__}({", ".join(kv_pairs)})' def has_values(self) -> bool: """Whether any values are set and non-zero.""" - return any(dataclasses.asdict(self).values()) # type: ignore + return any(dataclasses.asdict(self).values()) @dataclass(repr=False) @@ -50,46 +85,10 @@ class RequestUsage(UsageBase): Prices for LLM requests are calculated using [genai-prices](https://github.com/pydantic/genai-prices). 
""" - input_tokens: int = 0 - """Number of input/prompt tokens.""" - - cache_write_tokens: int = 0 - """Number of tokens written to the cache.""" - cache_read_tokens: int = 0 - """Number of tokens read from the cache.""" - - output_tokens: int = 0 - """Number of output/completion tokens.""" - - input_audio_tokens: int = 0 - """Number of audio input tokens.""" - cache_audio_read_tokens: int = 0 - """Number of audio tokens read from the cache.""" - output_audio_tokens: int = 0 - """Number of audio output tokens.""" - - details: dict[str, int] = dataclasses.field(default_factory=dict) - """Any extra details returned by the model.""" - @property def requests(self): return 1 - @property - @deprecated('`request_tokens` is deprecated, use `input_tokens` instead') - def request_tokens(self) -> int | None: - return self.input_tokens - - @property - @deprecated('`response_tokens` is deprecated, use `output_tokens` instead') - def response_tokens(self) -> int | None: - return self.output_tokens - - @property - @deprecated('`total_tokens` is deprecated, sum the specific fields you need instead') - def total_tokens(self) -> int | None: - return sum(v for k, v in dataclasses.asdict(self).values() if k.endswith('_tokens') and v is not None) - def incr(self, incr_usage: RequestUsage) -> None: """Increment the usage in place. @@ -139,21 +138,6 @@ class RunUsage(UsageBase): details: dict[str, int] = dataclasses.field(default_factory=dict) """Any extra details returned by the model.""" - @property - @deprecated('`request_tokens` is deprecated, use `input_tokens` instead') - def request_tokens(self) -> int: - return self.input_tokens - - @property - @deprecated('`response_tokens` is deprecated, use `output_tokens` instead') - def response_tokens(self) -> int: - return self.output_tokens - - @property - def total_tokens(self) -> int: - """Sum of `input_tokens + output_tokens`.""" - return self.input_tokens + self.output_tokens - def incr(self, incr_usage: RunUsage | RequestUsage) -> None: """Increment the usage in place. 
From 8dd479caeba116cdfab2c39c076cd5260236f110 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:24:46 +0200 Subject: [PATCH 52/71] coverage --- pydantic_ai_slim/pydantic_ai/models/_google_common.py | 6 +++--- tests/models/test_gemini.py | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/_google_common.py b/pydantic_ai_slim/pydantic_ai/models/_google_common.py index cbe5546648..b2362acfa2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/_google_common.py +++ b/pydantic_ai_slim/pydantic_ai/models/_google_common.py @@ -45,16 +45,16 @@ class GeminiUsageMetaData(TypedDict, total=False): def metadata_as_request_usage(metadata: GeminiUsageMetaData | None) -> usage.RequestUsage: if metadata is None: - return usage.RequestUsage() # pragma: no cover + return usage.RequestUsage() details: dict[str, int] = {} if cached_content_token_count := metadata.get('cached_content_token_count', 0): - details['cached_content_tokens'] = cached_content_token_count # pragma: no cover + details['cached_content_tokens'] = cached_content_token_count if thoughts_token_count := metadata.get('thoughts_token_count', 0): details['thoughts_tokens'] = thoughts_token_count if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0): - details['tool_use_prompt_tokens'] = tool_use_prompt_token_count # pragma: no cover + details['tool_use_prompt_tokens'] = tool_use_prompt_token_count input_audio_tokens = 0 output_audio_tokens = 0 diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index cc097da4a9..53a7c73783 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -2180,3 +2180,11 @@ def test_map_usage(): }, ) ) + + +def test_map_empty_usage(): + response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')]))) + assert 'usage_metadata' in response + del response['usage_metadata'] + + assert _metadata_as_usage(response) == RequestUsage() From 62819c890c257cba89164e317fc85c6f83336323 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:48:58 +0200 Subject: [PATCH 53/71] cleanup --- tests/models/test_cohere.py | 1 - tests/models/test_gemini.py | 20 ++++++++++---------- tests/models/test_openai.py | 2 -- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 6dfed3b9a1..739397350e 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -277,7 +277,6 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=RequestUsage(details={}), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 53a7c73783..11d5c177df 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -622,7 +622,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='Hello world')], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), provider_details={'finish_reason': 'STOP'}, @@ -638,7 +638,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( 
parts=[TextPart(content='Hello world')], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), provider_details={'finish_reason': 'STOP'}, @@ -646,7 +646,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='Hello world')], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), provider_details={'finish_reason': 'STOP'}, @@ -670,7 +670,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2, 123]}, tool_call_id=IsStr())], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), provider_details={'finish_reason': 'STOP'}, @@ -733,7 +733,7 @@ async def get_location(loc_name: str) -> str: parts=[ ToolCallPart(tool_name='get_location', args={'loc_name': 'San Fransisco'}, tool_call_id=IsStr()) ], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), provider_details={'finish_reason': 'STOP'}, @@ -753,7 +753,7 @@ async def get_location(loc_name: str) -> str: ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()), ToolCallPart(tool_name='get_location', args={'loc_name': 'New York'}, tool_call_id=IsStr()), ], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), provider_details={'finish_reason': 'STOP'}, @@ -776,7 +776,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), provider_details={'finish_reason': 'STOP'}, @@ -941,7 +941,7 @@ async def bar(y: str) -> str: ToolCallPart(tool_name='foo', args={'x': 'a'}, tool_call_id=IsStr()), ToolCallPart(tool_name='bar', args={'y': 'b'}, tool_call_id=IsStr()), ], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash', timestamp=IsNow(tz=timezone.utc), ), @@ -957,7 +957,7 @@ async def bar(y: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2]}, tool_call_id=IsStr())], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash', timestamp=IsNow(tz=timezone.utc), ), @@ -1025,7 +1025,7 @@ def get_location(loc_name: str) -> str: tool_call_id=IsStr(), ), ], - usage=RequestUsage(input_tokens=1, output_tokens=2, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash', timestamp=IsDatetime(), ), diff 
--git a/tests/models/test_openai.py b/tests/models/test_openai.py index 273b53700e..3f44c5c9cb 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -364,7 +364,6 @@ async def get_location(loc_name: str) -> str: input_tokens=2, cache_read_tokens=1, output_tokens=1, - details={}, ), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), @@ -392,7 +391,6 @@ async def get_location(loc_name: str) -> str: input_tokens=3, cache_read_tokens=2, output_tokens=2, - details={}, ), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), From 5917a89febe20a047ecf4098d8c95c2ec4824096 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:50:20 +0200 Subject: [PATCH 54/71] cleanup --- tests/models/test_openai_responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index afec56ff3d..b17704d290 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -197,7 +197,7 @@ async def get_location(loc_name: str) -> str: tool_call_id=IsStr(), ), ], - usage=RequestUsage(input_tokens=0, output_tokens=0, details={'reasoning_tokens': 0}), + usage=RequestUsage(details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), provider_request_id='resp_67e547c48c9481918c5c4394464ce0c60ae6111e84dd5c08', From 43df6d0709c5f782a15ac2900b0babc0cf3a0c30 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:51:35 +0200 Subject: [PATCH 55/71] cleanup --- tests/test_live.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_live.py b/tests/test_live.py index 6c3eba8072..4770e8f82f 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -118,8 +118,7 @@ async def test_text(http_client: httpx.AsyncClient, tmp_path: Path, get_model: G assert 'paris' in result.output.lower() print('Text usage:', result.usage()) usage = result.usage() - total_tokens = usage.total_tokens - assert total_tokens is not None and total_tokens > 0 + assert usage.total_tokens is not None and usage.total_tokens > 0 stream_params = [p for p in params if p.id != 'cohere'] @@ -135,8 +134,7 @@ async def test_stream(http_client: httpx.AsyncClient, tmp_path: Path, get_model: print('Stream usage:', result.usage()) usage = result.usage() if get_model.__name__ != 'ollama': - total_tokens = usage.total_tokens - assert total_tokens is not None and total_tokens > 0 + assert usage.total_tokens is not None and usage.total_tokens > 0 class MyModel(BaseModel): @@ -164,5 +162,4 @@ async def test_structured(http_client: httpx.AsyncClient, tmp_path: Path, get_mo assert result.output.city.lower() == 'london' print('Structured usage:', result.usage()) usage = result.usage() - total_tokens = usage.total_tokens - assert total_tokens is not None and total_tokens > 0 + assert usage.total_tokens is not None and usage.total_tokens > 0 From f9464a41b13ef07ae1c6bd4ccd6fcd517985bf91 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 19:52:44 +0200 Subject: [PATCH 56/71] cleanup --- tests/test_streaming.py | 6 +++--- tests/test_usage_limits.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 624dc8399e..139937a030 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -64,7 +64,7 @@ async def ret_a(x: str) -> str: ModelRequest(parts=[UserPromptPart(content='Hello', 
timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=RequestUsage(input_tokens=51, output_tokens=0), + usage=RequestUsage(input_tokens=51), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -93,7 +93,7 @@ async def ret_a(x: str) -> str: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=RequestUsage(input_tokens=51, output_tokens=0), + usage=RequestUsage(input_tokens=51), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -782,7 +782,7 @@ def regular_tool(x: int) -> int: ), ModelResponse( parts=[ToolCallPart(tool_name='regular_tool', args={'x': 0}, tool_call_id=IsStr())], - usage=RequestUsage(input_tokens=57, output_tokens=0), + usage=RequestUsage(input_tokens=57), model_name='test', timestamp=IsNow(tz=timezone.utc), ), diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index a8feb3af37..e476695588 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -98,7 +98,7 @@ async def ret_a(x: str) -> str: tool_call_id=IsStr(), ) ], - usage=RequestUsage(input_tokens=51, output_tokens=0), + usage=RequestUsage(input_tokens=51), model_name='test', timestamp=IsNow(tz=timezone.utc), ), From a9c64a66a297d46f67268cdf1916a2010509e0fe Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 20:04:25 +0200 Subject: [PATCH 57/71] simplify _incr_usage_tokens --- pydantic_ai_slim/pydantic_ai/usage.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 62cdbadac5..b30f51c144 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -165,23 +165,15 @@ def _incr_usage_tokens(slf: RunUsage | RequestUsage, incr_usage: RunUsage | Requ slf: The usage to increment. incr_usage: The usage to increment by. 
""" - if incr_usage.input_tokens: - slf.input_tokens = (slf.input_tokens or 0) + incr_usage.input_tokens - if incr_usage.output_tokens: - slf.output_tokens = (slf.output_tokens or 0) + incr_usage.output_tokens - if incr_usage.cache_write_tokens: - slf.cache_write_tokens = (slf.cache_write_tokens or 0) + incr_usage.cache_write_tokens - if incr_usage.cache_read_tokens: - slf.cache_read_tokens = (slf.cache_read_tokens or 0) + incr_usage.cache_read_tokens - if incr_usage.input_audio_tokens: - slf.input_audio_tokens = (slf.input_audio_tokens or 0) + incr_usage.input_audio_tokens - if incr_usage.cache_audio_read_tokens: - slf.cache_audio_read_tokens = (slf.cache_audio_read_tokens or 0) + incr_usage.cache_audio_read_tokens - - if incr_usage.details: - slf.details = slf.details or {} - for key, value in incr_usage.details.items(): - slf.details[key] = slf.details.get(key, 0) + value + slf.input_tokens += incr_usage.input_tokens + slf.cache_write_tokens += incr_usage.cache_write_tokens + slf.cache_read_tokens += incr_usage.cache_read_tokens + slf.input_audio_tokens += incr_usage.input_audio_tokens + slf.cache_audio_read_tokens += incr_usage.cache_audio_read_tokens + slf.output_tokens += incr_usage.output_tokens + + for key, value in incr_usage.details.items(): + slf.details[key] = slf.details.get(key, 0) + value @dataclass From fa5e73ebad5bb5aee5d5e0737fb8fcf332080b70 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 20:07:53 +0200 Subject: [PATCH 58/71] Test openai audio tokens --- tests/models/test_openai.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 3f44c5c9cb..0b5c3b76d9 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -783,6 +783,21 @@ async def test_openai_audio_url_input(allow_model_requests: None, openai_api_key assert result.output == snapshot( 'Yes, the phenomenon of the sun rising in the east and setting in the west is due to the rotation of the Earth. The Earth rotates on its axis from west to east, making the sun appear to rise on the eastern horizon and set in the west. This is a daily occurrence and has been a fundamental aspect of human observation and timekeeping throughout history.' 
) + assert result.usage() == snapshot( + RunUsage( + input_tokens=81, + output_tokens=72, + input_audio_tokens=69, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'text_tokens': 72, + }, + requests=1, + ) + ) async def test_document_url_input(allow_model_requests: None, openai_api_key: str): @@ -962,6 +977,21 @@ async def test_audio_as_binary_content_input( result = await agent.run(['Whose name is mentioned in the audio?', audio_content]) assert result.output == snapshot('The name mentioned in the audio is Marcelo.') + assert result.usage() == snapshot( + RunUsage( + input_tokens=64, + output_tokens=9, + input_audio_tokens=44, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'text_tokens': 9, + }, + requests=1, + ) + ) async def test_document_as_binary_content_input( From 4d0f8be1d1530f3c2ea962c25c012689e826da92 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 20:12:24 +0200 Subject: [PATCH 59/71] cleanup --- pydantic_ai_slim/pydantic_ai/usage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index b30f51c144..1a63f420ad 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -268,11 +268,11 @@ def check_before_request(self, usage: RunUsage) -> None: def check_tokens(self, usage: RunUsage) -> None: """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits.""" - input_tokens = usage.input_tokens or 0 + input_tokens = usage.input_tokens if self.input_tokens_limit is not None and input_tokens > self.input_tokens_limit: raise UsageLimitExceeded(f'Exceeded the input_tokens_limit of {self.input_tokens_limit} ({input_tokens=})') - output_tokens = usage.output_tokens or 0 + output_tokens = usage.output_tokens if self.output_tokens_limit is not None and output_tokens > self.output_tokens_limit: raise UsageLimitExceeded( f'Exceeded the output_tokens_limit of {self.output_tokens_limit} ({output_tokens=})' From 3d4646219ff68a4fbfe9f592b6d8518bdd8fbd48 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Wed, 13 Aug 2025 20:38:09 +0200 Subject: [PATCH 60/71] docstring --- pydantic_ai_slim/pydantic_ai/usage.py | 4 +- .../test_known_model_names.yaml | 84 ------------------- 2 files changed, 1 insertion(+), 87 deletions(-) delete mode 100644 tests/models/cassettes/test_model_names/test_known_model_names.yaml diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 1a63f420ad..d3743919bb 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -80,9 +80,7 @@ class RequestUsage(UsageBase): """LLM usage associated with a single request. This is an implementation of `genai_prices.types.AbstractUsage` so it can be used to calculate the price of the - request. - - Prices for LLM requests are calculated using [genai-prices](https://github.com/pydantic/genai-prices). + request using [genai-prices](https://github.com/pydantic/genai-prices). 
""" @property diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml deleted file mode 100644 index e66c3747bd..0000000000 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ /dev/null @@ -1,84 +0,0 @@ -interactions: -- request: - body: '' - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - host: - - us.inference.heroku.com - method: GET - uri: https://us.inference.heroku.com/available-models - response: - headers: - content-length: - - '760' - content-security-policy: - - default-src 'none'; frame-ancestors 'none' - content-type: - - application/json - strict-transport-security: - - max-age=63072000 - parsed_body: - - model_id: claude-3-5-haiku - regions: - - us - type: - - text-to-text - - model_id: claude-3-5-sonnet-latest - regions: - - us - type: - - text-to-text - - model_id: claude-3-7-sonnet - regions: - - eu - - us - type: - - text-to-text - - model_id: claude-3-haiku - regions: - - eu - type: - - text-to-text - - model_id: claude-4-sonnet - regions: - - eu - - us - type: - - text-to-text - - model_id: cohere-embed-multilingual - regions: - - eu - - us - type: - - text-to-embedding - - model_id: gpt-oss-120b - regions: - - us - type: - - text-to-text - - model_id: nova-lite - regions: - - eu - - us - type: - - text-to-text - - model_id: nova-pro - regions: - - eu - - us - type: - - text-to-text - - model_id: stable-image-ultra - regions: - - us - type: - - text-to-image - status: - code: 200 - message: OK -version: 1 From 7c0599bdd60894919a2b90e36d113a5e787f0605 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Thu, 14 Aug 2025 18:30:26 +0200 Subject: [PATCH 61/71] Remove provider_name and price --- pydantic_ai_slim/pydantic_ai/messages.py | 14 -------------- tests/models/test_anthropic.py | 2 -- tests/test_agent.py | 2 -- 3 files changed, 18 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 59dfaa7917..28447187ef 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -10,7 +10,6 @@ import pydantic import pydantic_core -from genai_prices import calc_price, types as genai_types from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage] from typing_extensions import TypeAlias, deprecated @@ -849,9 +848,6 @@ class ModelResponse: kind: Literal['response'] = 'response' """Message type identifier, this is available on all parts as a discriminator.""" - provider_name: str | None = None - """The name of the LLM provider that generated the response.""" - provider_details: dict[str, Any] | None = field(default=None) """Additional provider-specific details in a serializable format. @@ -862,16 +858,6 @@ class ModelResponse: provider_request_id: str | None = None """request ID as specified by the model provider. 
This can be used to track the specific request to the model.""" - def price(self) -> genai_types.PriceCalculation: - """Calculate the price of the usage.""" - assert self.model_name, 'Model name is required to calculate price' - return calc_price( - self.usage, - self.model_name, - provider_id=self.provider_name, - genai_request_timestamp=self.timestamp, - ) - def otel_events(self, settings: InstrumentationSettings) -> list[Event]: """Return OpenTelemetry events for the response.""" result: list[Event] = [] diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index c970ca126b..b6aa73b628 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -5,7 +5,6 @@ from collections.abc import Sequence from dataclasses import dataclass, field from datetime import timezone -from decimal import Decimal from functools import cached_property from typing import Any, Callable, TypeVar, Union, cast @@ -250,7 +249,6 @@ async def test_async_request_prompt_caching(allow_model_requests: None): ) last_message = result.all_messages()[-1] assert isinstance(last_message, ModelResponse) - assert last_message.price().total_price == snapshot(Decimal('0.00003488')) async def test_async_request_text_response(allow_model_requests: None): diff --git a/tests/test_agent.py b/tests/test_agent.py index 921d83b6ba..94ff189e3f 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2957,7 +2957,6 @@ def test_binary_content_serializable(): 'details': {}, }, 'model_name': 'test', - 'provider_name': None, 'provider_details': None, 'provider_request_id': None, 'timestamp': IsStr(), @@ -3013,7 +3012,6 @@ def test_image_url_serializable(): }, 'model_name': 'test', 'timestamp': IsStr(), - 'provider_name': None, 'provider_details': None, 'provider_request_id': None, 'kind': 'response', From a3ead91f93b6724cda649b89daf7b23926af25a5 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Thu, 14 Aug 2025 18:45:52 +0200 Subject: [PATCH 62/71] deprecated properties --- pydantic_ai_slim/pydantic_ai/usage.py | 10 ++++++++++ tests/test_usage_limits.py | 13 +++++++++++++ 2 files changed, 23 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 9ba75ec969..1b8f11556e 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -203,6 +203,16 @@ class UsageLimits: to enforce `request_tokens_limit` ahead of time. 
This may incur additional overhead (from calling the model's `count_tokens` API before making the actual request) and is disabled by default.""" + @property + @deprecated('`request_tokens_limit` is deprecated, use `input_tokens_limit` instead') + def request_tokens_limit(self) -> int | None: + return self.input_tokens_limit + + @property + @deprecated('`response_tokens_limit` is deprecated, use `output_tokens_limit` instead') + def response_tokens_limit(self) -> int | None: + return self.output_tokens_limit + @overload def __init__( self, diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index e476695588..ef86989923 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -7,6 +7,7 @@ import pytest from genai_prices import Usage as GenaiPricesUsage, calc_price from inline_snapshot import snapshot +from inline_snapshot.extra import warns from pydantic_ai import Agent, RunContext, UsageLimitExceeded from pydantic_ai.messages import ( @@ -233,3 +234,15 @@ def test_add_usages(): ) assert usage + RunUsage() == usage assert RunUsage() + RunUsage() == RunUsage() + + +def test_deprecated_usage_limits(): + with warns( + snapshot(['DeprecationWarning: `request_tokens_limit` is deprecated, use `input_tokens_limit` instead']) + ): + assert UsageLimits(input_tokens_limit=100).request_tokens_limit == 100 # type: ignore + + with warns( + snapshot(['DeprecationWarning: `response_tokens_limit` is deprecated, use `output_tokens_limit` instead']) + ): + assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100 # type: ignore From ff32e9e4d38b87bd1614a83f7d7af2ee4d2f73b1 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 13:07:00 +0200 Subject: [PATCH 63/71] fix --- tests/test_agent.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index a99ace5ee9..20b40b13a7 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2046,7 +2046,7 @@ def test_tool() -> str: ), ModelResponse( parts=[TextPart(content='Final response')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='function:simple_response:', timestamp=IsDatetime(), ), @@ -4240,7 +4240,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon ), ModelResponse( parts=[ThinkingPart(content='Let me think about this...')], - usage=Usage(requests=1, request_tokens=57, response_tokens=6, total_tokens=63), + usage=RequestUsage(input_tokens=57, output_tokens=6), model_name='function:model_function:', timestamp=IsDatetime(), ), @@ -4255,7 +4255,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon ), ModelResponse( parts=[TextPart(content='Final answer')], - usage=Usage(requests=1, request_tokens=75, response_tokens=8, total_tokens=83), + usage=RequestUsage(input_tokens=75, output_tokens=8), model_name='function:model_function:', timestamp=IsDatetime(), ), @@ -4324,7 +4324,7 @@ def delete_file(ctx: RunContext[ApprovableToolsDeps], path: str) -> str: tool_name='delete_file', args={'path': 'never_delete.py'}, tool_call_id='never_delete' ), ], - usage=Usage(requests=1, request_tokens=57, response_tokens=12, total_tokens=69), + usage=RequestUsage(input_tokens=57, output_tokens=12), model_name='function:model_function:', timestamp=IsDatetime(), ), @@ -4373,7 +4373,7 @@ def delete_file(ctx: RunContext[ApprovableToolsDeps], path: str) -> str: tool_name='delete_file', 
args={'path': 'never_delete.py'}, tool_call_id='never_delete' ), ], - usage=Usage(requests=1, request_tokens=57, response_tokens=12, total_tokens=69), + usage=RequestUsage(input_tokens=57, output_tokens=12), model_name='function:model_function:', timestamp=IsDatetime(), ), @@ -4395,7 +4395,7 @@ def delete_file(ctx: RunContext[ApprovableToolsDeps], path: str) -> str: ), ModelResponse( parts=[TextPart(content='OK')], - usage=Usage(requests=1, request_tokens=76, response_tokens=13, total_tokens=89), + usage=RequestUsage(input_tokens=76, output_tokens=13), model_name='function:model_function:', timestamp=IsDatetime(), ), From d86952ffb76678bafc88cd3c052d19af928c3abc Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 13:54:10 +0200 Subject: [PATCH 64/71] fix __all__ --- pydantic_ai_slim/pydantic_ai/usage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 1b8f11556e..19cfb0acb7 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -9,7 +9,7 @@ from . import _utils from .exceptions import UsageLimitExceeded -__all__ = 'RequestUsage', 'RunUsage', 'UsageLimits' +__all__ = 'Usage', 'RequestUsage', 'RunUsage', 'UsageLimits' @dataclass(repr=False) From b1aff2a7c1f603664ef0079fcfde31c737460c33 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 13:54:36 +0200 Subject: [PATCH 65/71] fix __all__ --- pydantic_ai_slim/pydantic_ai/usage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 19cfb0acb7..4da1446d59 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -9,7 +9,7 @@ from . 
import _utils from .exceptions import UsageLimitExceeded -__all__ = 'Usage', 'RequestUsage', 'RunUsage', 'UsageLimits' +__all__ = 'RequestUsage', 'RunUsage', 'Usage', 'UsageLimits' @dataclass(repr=False) From b63c7b4f6d4c66af43be913501877a8dfb6958ac Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 13:55:00 +0200 Subject: [PATCH 66/71] docstring --- pydantic_ai_slim/pydantic_ai/usage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 4da1446d59..a16f4ee09f 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -177,7 +177,7 @@ def _incr_usage_tokens(slf: RunUsage | RequestUsage, incr_usage: RunUsage | Requ @dataclass @deprecated('`Usage` is deprecated, use `RunUsage` instead') class Usage(RunUsage): - pass + """Deprecated alias for `RunUsage`.""" @dataclass(repr=False) From 6a789b187f70ad061c28d875a1971811c193b9ad Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 14:08:27 +0200 Subject: [PATCH 67/71] separate google and gemini --- .../pydantic_ai/models/_google_common.py | 85 ------------------- pydantic_ai_slim/pydantic_ai/models/gemini.py | 76 ++++++++++++++++- pydantic_ai_slim/pydantic_ai/models/google.py | 47 +++++++++- tests/models/test_gemini.py | 2 +- 4 files changed, 119 insertions(+), 91 deletions(-) delete mode 100644 pydantic_ai_slim/pydantic_ai/models/_google_common.py diff --git a/pydantic_ai_slim/pydantic_ai/models/_google_common.py b/pydantic_ai_slim/pydantic_ai/models/_google_common.py deleted file mode 100644 index b2362acfa2..0000000000 --- a/pydantic_ai_slim/pydantic_ai/models/_google_common.py +++ /dev/null @@ -1,85 +0,0 @@ -from __future__ import annotations - -from typing import Annotated, Literal, cast - -import pydantic -from typing_extensions import NotRequired, TypedDict - -from pydantic_ai import usage - - -class _GeminiModalityTokenCount(TypedDict): - """See .""" - - modality: Annotated[ - Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality') - ] - token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)] - - -class GeminiUsageMetaData(TypedDict, total=False): - """See . - - The docs suggest all fields are required, but some are actually not required, so we assume they are all optional. 
- """ - - prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')] - candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]] - total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')] - cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]] - thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]] - tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]] - prompt_tokens_details: NotRequired[ - Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')] - ] - cache_tokens_details: NotRequired[ - Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')] - ] - candidates_tokens_details: NotRequired[ - Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')] - ] - tool_use_prompt_tokens_details: NotRequired[ - Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')] - ] - - -def metadata_as_request_usage(metadata: GeminiUsageMetaData | None) -> usage.RequestUsage: - if metadata is None: - return usage.RequestUsage() - details: dict[str, int] = {} - if cached_content_token_count := metadata.get('cached_content_token_count', 0): - details['cached_content_tokens'] = cached_content_token_count - - if thoughts_token_count := metadata.get('thoughts_token_count', 0): - details['thoughts_tokens'] = thoughts_token_count - - if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0): - details['tool_use_prompt_tokens'] = tool_use_prompt_token_count - - input_audio_tokens = 0 - output_audio_tokens = 0 - cache_audio_read_tokens = 0 - for key, metadata_details in metadata.items(): - if key.endswith('_details') and metadata_details: - metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) - suffix = key.removesuffix('_details') - for detail in metadata_details: - modality = detail['modality'] - details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0) - if value and modality == 'AUDIO': - if key == 'prompt_tokens_details': - input_audio_tokens = value - elif key == 'candidates_tokens_details': - output_audio_tokens = value - elif key == 'cache_tokens_details': # pragma: no branch - cache_audio_read_tokens = value - - return usage.RequestUsage( - input_tokens=metadata.get('prompt_token_count', 0), - output_tokens=metadata.get('candidates_token_count', 0), - cache_read_tokens=cached_content_token_count, - input_audio_tokens=input_audio_tokens, - output_audio_tokens=output_audio_tokens, - cache_audio_read_tokens=cache_audio_read_tokens, - details=details, - ) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4d3c6743f0..92feaa7d1c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -48,7 +48,6 @@ download_item, get_user_agent, ) -from ._google_common import GeminiUsageMetaData, metadata_as_request_usage LatestGeminiModelNames = Literal[ 'gemini-2.0-flash', @@ -824,9 +823,82 @@ class _GeminiCandidates(TypedDict): safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]] +class _GeminiModalityTokenCount(TypedDict): + """See .""" + + modality: Annotated[ + Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], 
pydantic.Field(alias='modality') + ] + token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)] + + +class GeminiUsageMetaData(TypedDict, total=False): + """See . + + The docs suggest all fields are required, but some are actually not required, so we assume they are all optional. + """ + + prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')] + candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]] + total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')] + cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]] + thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]] + tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]] + prompt_tokens_details: NotRequired[ + Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')] + ] + cache_tokens_details: NotRequired[ + Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')] + ] + candidates_tokens_details: NotRequired[ + Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')] + ] + tool_use_prompt_tokens_details: NotRequired[ + Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')] + ] + + def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage: metadata = response.get('usage_metadata') - return metadata_as_request_usage(metadata) + if metadata is None: + return usage.RequestUsage() + details: dict[str, int] = {} + if cached_content_token_count := metadata.get('cached_content_token_count', 0): + details['cached_content_tokens'] = cached_content_token_count + + if thoughts_token_count := metadata.get('thoughts_token_count', 0): + details['thoughts_tokens'] = thoughts_token_count + + if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0): + details['tool_use_prompt_tokens'] = tool_use_prompt_token_count + + input_audio_tokens = 0 + output_audio_tokens = 0 + cache_audio_read_tokens = 0 + for key, metadata_details in metadata.items(): + if key.endswith('_details') and metadata_details: + metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) + suffix = key.removesuffix('_details') + for detail in metadata_details: + modality = detail['modality'] + details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0) + if value and modality == 'AUDIO': + if key == 'prompt_tokens_details': + input_audio_tokens = value + elif key == 'candidates_tokens_details': + output_audio_tokens = value + elif key == 'cache_tokens_details': # pragma: no branch + cache_audio_read_tokens = value + + return usage.RequestUsage( + input_tokens=metadata.get('prompt_token_count', 0), + output_tokens=metadata.get('candidates_token_count', 0), + cache_read_tokens=cached_content_token_count, + input_audio_tokens=input_audio_tokens, + output_audio_tokens=output_audio_tokens, + cache_audio_read_tokens=cache_audio_read_tokens, + details=details, + ) class _GeminiSafetyRating(TypedDict): diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 0d3c021026..1c5152dc0b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -46,7 +46,6 @@ download_item, get_user_agent, ) -from ._google_common import GeminiUsageMetaData, 
metadata_as_request_usage try: from google.genai import Client @@ -654,5 +653,47 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage: metadata = response.usage_metadata if metadata is None: return usage.RequestUsage() # pragma: no cover - metadata = cast(GeminiUsageMetaData, metadata.model_dump(exclude_defaults=True)) - return metadata_as_request_usage(metadata) + details: dict[str, int] = {} + if cached_content_token_count := metadata.cached_content_token_count: + details['cached_content_tokens'] = cached_content_token_count + + if thoughts_token_count := metadata.thoughts_token_count: + details['thoughts_tokens'] = thoughts_token_count + + if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count: + details['tool_use_prompt_tokens'] = tool_use_prompt_token_count + + input_audio_tokens = 0 + output_audio_tokens = 0 + cache_audio_read_tokens = 0 + for prefix, metadata_details in [ + ('prompt', metadata.prompt_tokens_details), + ('cache', metadata.cache_tokens_details), + ('candidates', metadata.candidates_tokens_details), + ('tool_use_prompt', metadata.tool_use_prompt_tokens_details), + ]: + assert getattr(metadata, f'{prefix}_tokens_details') is metadata_details + if not metadata_details: + continue + for detail in metadata_details: + if not detail.modality or not detail.token_count: + continue + details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count + if detail.modality != 'AUDIO': + continue + if metadata_details is metadata.prompt_tokens_details: + input_audio_tokens = detail.token_count + elif metadata_details is metadata.candidates_tokens_details: + output_audio_tokens = detail.token_count + elif metadata_details is metadata.cache_tokens_details: # pragma: no branch + cache_audio_read_tokens = detail.token_count + + return usage.RequestUsage( + input_tokens=metadata.prompt_token_count or 0, + output_tokens=metadata.candidates_token_count or 0, + cache_read_tokens=cached_content_token_count or 0, + input_audio_tokens=input_audio_tokens, + output_audio_tokens=output_audio_tokens, + cache_audio_read_tokens=cache_audio_read_tokens, + details=details, + ) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 11d5c177df..4bd616846a 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -34,7 +34,6 @@ VideoUrl, ) from pydantic_ai.models import ModelRequestParameters -from pydantic_ai.models._google_common import _GeminiModalityTokenCount from pydantic_ai.models.gemini import ( GeminiModel, GeminiModelSettings, @@ -48,6 +47,7 @@ _GeminiFunctionCall, _GeminiFunctionCallingConfig, _GeminiFunctionCallPart, + _GeminiModalityTokenCount, _GeminiResponse, _GeminiSafetyRating, _GeminiTextPart, From f44714774066475d4d8a1152a81691976bdec1c0 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 14:10:08 +0200 Subject: [PATCH 68/71] rename --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 4 ++-- tests/models/test_gemini.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 92feaa7d1c..0e223cb190 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -803,7 +803,7 @@ class _GeminiResponse(TypedDict): candidates: list[_GeminiCandidates] # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response - usage_metadata: 
NotRequired[Annotated[GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]] + usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]] prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]] model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]] vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]] @@ -832,7 +832,7 @@ class _GeminiModalityTokenCount(TypedDict): token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)] -class GeminiUsageMetaData(TypedDict, total=False): +class _GeminiUsageMetaData(TypedDict, total=False): """See . The docs suggest all fields are required, but some are actually not required, so we assume they are all optional. diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 4bd616846a..a5e7a97fd5 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -37,7 +37,6 @@ from pydantic_ai.models.gemini import ( GeminiModel, GeminiModelSettings, - GeminiUsageMetaData, _content_model_response, _gemini_response_ta, _gemini_streamed_response_ta, @@ -54,6 +53,7 @@ _GeminiThoughtPart, _GeminiToolConfig, _GeminiTools, + _GeminiUsageMetaData, _metadata_as_usage, ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput @@ -605,8 +605,8 @@ def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | No return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage(), model_version='gemini-1.5-flash-123') -def example_usage() -> GeminiUsageMetaData: - return GeminiUsageMetaData(prompt_token_count=1, candidates_token_count=2, total_token_count=3) +def example_usage() -> _GeminiUsageMetaData: + return _GeminiUsageMetaData(prompt_token_count=1, candidates_token_count=2, total_token_count=3) async def test_text_success(get_gemini_client: GetGeminiClient): From 0392c846b528aa7ec5c070adbf108146baec9bc3 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 15:18:53 +0200 Subject: [PATCH 69/71] coverage --- tests/models/test_gemini.py | 2 ++ tests/models/test_google.py | 47 +++++++++++++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a5e7a97fd5..bd0b61545c 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -2163,6 +2163,7 @@ def test_map_usage(): response['usage_metadata']['candidates_tokens_details'] = [ _GeminiModalityTokenCount(modality='AUDIO', token_count=9400) ] + response['usage_metadata']['thoughts_token_count'] = 9500 assert _metadata_as_usage(response) == snapshot( RequestUsage( @@ -2174,6 +2175,7 @@ def test_map_usage(): output_audio_tokens=9400, details={ 'cached_content_tokens': 9100, + 'thoughts_tokens': 9500, 'audio_prompt_tokens': 9200, 'audio_cache_tokens': 9300, 'audio_candidates_tokens': 9400, diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 722423c6cb..d94703e699 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -47,9 +47,19 @@ from ..parts_from_messages import part_types_from_messages with try_import() as imports_successful: - from google.genai.types import CodeExecutionResult, HarmBlockThreshold, HarmCategory, Language, Outcome + from google.genai.types import ( + CodeExecutionResult, + GenerateContentResponse, + GenerateContentResponseUsageMetadata, + HarmBlockThreshold, + HarmCategory, + Language, + MediaModality, + 
ModalityTokenCount, + Outcome, + ) - from pydantic_ai.models.google import GoogleModel, GoogleModelSettings + from pydantic_ai.models.google import GoogleModel, GoogleModelSettings, _metadata_as_usage # type: ignore from pydantic_ai.providers.google import GoogleProvider pytestmark = [ @@ -1663,3 +1673,36 @@ async def get_user_country() -> str: 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True), ) + + +def test_map_usage(): + assert _metadata_as_usage(GenerateContentResponse()) == RequestUsage() + + response = GenerateContentResponse( + usage_metadata=GenerateContentResponseUsageMetadata( + prompt_token_count=1, + candidates_token_count=2, + cached_content_token_count=9100, + thoughts_token_count=9500, + prompt_tokens_details=[ModalityTokenCount(modality=MediaModality.AUDIO, token_count=9200)], + cache_tokens_details=[ModalityTokenCount(modality=MediaModality.AUDIO, token_count=9300)], + candidates_tokens_details=[ModalityTokenCount(modality=MediaModality.AUDIO, token_count=9400)], + ) + ) + assert _metadata_as_usage(response) == snapshot( + RequestUsage( + input_tokens=1, + cache_read_tokens=9100, + output_tokens=2, + input_audio_tokens=9200, + cache_audio_read_tokens=9300, + output_audio_tokens=9400, + details={ + 'cached_content_tokens': 9100, + 'thoughts_tokens': 9500, + 'audio_prompt_tokens': 9200, + 'audio_cache_tokens': 9300, + 'audio_candidates_tokens': 9400, + }, + ) + ) From 8368c117dfc0ca68837105a0e4cf9933c7d65a8f Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 15:25:37 +0200 Subject: [PATCH 70/71] coverage --- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- tests/models/test_gemini.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 1c5152dc0b..d37b8f5a0f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -676,7 +676,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage: if not metadata_details: continue for detail in metadata_details: - if not detail.modality or not detail.token_count: + if not detail.modality or not detail.token_count: # pragma: no cover continue details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count if detail.modality != 'AUDIO': diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index bd0b61545c..73de9b971f 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -2164,6 +2164,7 @@ def test_map_usage(): _GeminiModalityTokenCount(modality='AUDIO', token_count=9400) ] response['usage_metadata']['thoughts_token_count'] = 9500 + response['usage_metadata']['tool_use_prompt_token_count'] = 9600 assert _metadata_as_usage(response) == snapshot( RequestUsage( @@ -2175,10 +2176,11 @@ def test_map_usage(): output_audio_tokens=9400, details={ 'cached_content_tokens': 9100, - 'thoughts_tokens': 9500, 'audio_prompt_tokens': 9200, 'audio_cache_tokens': 9300, 'audio_candidates_tokens': 9400, + 'thoughts_tokens': 9500, + 'tool_use_prompt_tokens': 9600, }, ) ) From f1e9abf3338a99a85a9ecb11b56d2ff1d2f480a8 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Fri, 15 Aug 2025 15:30:18 +0200 Subject: [PATCH 71/71] coverage --- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index d37b8f5a0f..a2f1689c71 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -652,7 +652,7 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict: def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage: metadata = response.usage_metadata if metadata is None: - return usage.RequestUsage() # pragma: no cover + return usage.RequestUsage() details: dict[str, int] = {} if cached_content_token_count := metadata.cached_content_token_count: details['cached_content_tokens'] = cached_content_token_count
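
For illustration, a minimal sketch of how the Gemini usage mapping introduced above behaves, using the private helpers that the tests in this series import from pydantic_ai.models.gemini (names as added by these patches; underscore-prefixed, so not part of the public API). This mirrors the test_map_usage test rather than defining new behaviour:

    from pydantic_ai.models.gemini import (
        _GeminiModalityTokenCount,
        _GeminiResponse,
        _GeminiUsageMetaData,
        _metadata_as_usage,
    )

    # The Gemini TypedDicts are plain dicts, so a response can be built directly,
    # the same way the test suite does.
    response = _GeminiResponse(
        candidates=[],
        usage_metadata=_GeminiUsageMetaData(
            prompt_token_count=10,
            candidates_token_count=3,
            total_token_count=13,
            prompt_tokens_details=[_GeminiModalityTokenCount(modality='AUDIO', token_count=4)],
        ),
    )

    usage = _metadata_as_usage(response)
    # Prompt/candidate counts map onto input/output tokens; each modality entry is
    # recorded in details under a '<modality>_<field>' key, and AUDIO prompt tokens
    # also populate input_audio_tokens.
    assert usage.input_tokens == 10
    assert usage.output_tokens == 3
    assert usage.input_audio_tokens == 4
    assert usage.details == {'audio_prompt_tokens': 4}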