diff --git a/docs/agents.md b/docs/agents.md index dbe72301dc..2185faf7c2 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -302,9 +302,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -367,12 +365,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, - request_tokens=56, - response_tokens=7, - total_tokens=63, - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -391,7 +384,7 @@ _(This example is complete, it can be run "as is" — you'll need to add `asynci #### Accessing usage and final output -You can retrieve usage statistics (tokens, requests, etc.) at any time from the [`AgentRun`][pydantic_ai.agent.AgentRun] object via `agent_run.usage()`. This method returns a [`Usage`][pydantic_ai.usage.Usage] object containing the usage data. +You can retrieve usage statistics (tokens, requests, etc.) at any time from the [`AgentRun`][pydantic_ai.agent.AgentRun] object via `agent_run.usage()`. This method returns a [`RunUsage`][pydantic_ai.usage.RunUsage] object containing the usage data. Once the run finishes, `agent_run.result` becomes a [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] object containing the final output (and related metadata). @@ -570,7 +563,7 @@ result_sync = agent.run_sync( print(result_sync.output) #> Rome print(result_sync.usage()) -#> Usage(requests=1, request_tokens=62, response_tokens=1, total_tokens=63) +#> RunUsage(input_tokens=62, output_tokens=1, requests=1) try: result_sync = agent.run_sync( @@ -579,7 +572,7 @@ try: ) except UsageLimitExceeded as e: print(e) - #> Exceeded the response_tokens_limit of 10 (response_tokens=32) + #> Exceeded the output_tokens_limit of 10 (output_tokens=32) ``` Restricting the number of requests can be useful in preventing infinite loops or excessive tool calling: @@ -1018,9 +1011,7 @@ with capture_run_messages() as messages: # (2)! tool_call_id='pyd_ai_tool_call_id', ) ], - usage=Usage( - requests=1, request_tokens=62, response_tokens=4, total_tokens=66 - ), + usage=RequestUsage(input_tokens=62, output_tokens=4), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -1042,9 +1033,7 @@ with capture_run_messages() as messages: # (2)! tool_call_id='pyd_ai_tool_call_id', ) ], - usage=Usage( - requests=1, request_tokens=72, response_tokens=8, total_tokens=80 - ), + usage=RequestUsage(input_tokens=72, output_tokens=8), model_name='gpt-4o', timestamp=datetime.datetime(...), ), diff --git a/docs/direct.md b/docs/direct.md index 28e3ffcbf3..c65e4817bf 100644 --- a/docs/direct.md +++ b/docs/direct.md @@ -28,7 +28,7 @@ model_response = model_request_sync( print(model_response.parts[0].content) #> The capital of France is Paris. 
print(model_response.usage) -#> Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63) +#> RequestUsage(input_tokens=56, output_tokens=7) ``` _(This example is complete, it can be run "as is")_ @@ -83,7 +83,7 @@ async def main(): tool_call_id='pyd_ai_2e0e396768a14fe482df90a29a78dc7b', ) ], - usage=Usage(requests=1, request_tokens=55, response_tokens=7, total_tokens=62), + usage=RequestUsage(input_tokens=55, output_tokens=7), model_name='gpt-4.1-nano', timestamp=datetime.datetime(...), ) diff --git a/docs/message-history.md b/docs/message-history.md index 1a2bcf38b8..33d9f99563 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -58,7 +58,7 @@ print(result.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), + usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -126,7 +126,7 @@ async def main(): content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(request_tokens=50, response_tokens=12, total_tokens=62), + usage=RequestUsage(input_tokens=50, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -180,7 +180,7 @@ print(result2.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), + usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -198,7 +198,7 @@ print(result2.all_messages()) content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' ) ], - usage=Usage(requests=1, request_tokens=61, response_tokens=26, total_tokens=87), + usage=RequestUsage(input_tokens=61, output_tokens=26), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -299,7 +299,7 @@ print(result2.all_messages()) content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], - usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), + usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), ), @@ -317,7 +317,7 @@ print(result2.all_messages()) content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' 
) ], - usage=Usage(requests=1, request_tokens=61, response_tokens=26, total_tokens=87), + usage=RequestUsage(input_tokens=61, output_tokens=26), model_name='gemini-1.5-pro', timestamp=datetime.datetime(...), ), diff --git a/docs/models/index.md b/docs/models/index.md index 06410ec4a9..b7bcb09b71 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -117,7 +117,7 @@ print(response.all_messages()) model_name='claude-3-5-sonnet-latest', timestamp=datetime.datetime(...), kind='response', - vendor_id=None, + provider_request_id=None, ), ] """ diff --git a/docs/models/openai.md b/docs/models/openai.md index f2dc7efa5a..91704c4d13 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -275,7 +275,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65) +#> RunUsage(input_tokens=57, output_tokens=8, requests=1) ``` #### Example using a remote server @@ -304,7 +304,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65) +#> RunUsage(input_tokens=57, output_tokens=8, requests=1) ``` 1. The name of the model running on the remote server diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index a531f87b86..9fda4c1d2b 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -53,7 +53,7 @@ result = joke_selection_agent.run_sync( print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) -#> Usage(requests=3, request_tokens=204, response_tokens=24, total_tokens=228) +#> RunUsage(input_tokens=204, output_tokens=24, requests=3) ``` 1. The "parent" or controlling agent. @@ -144,7 +144,7 @@ async def main(): print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) # (6)! - #> Usage(requests=4, request_tokens=309, response_tokens=32, total_tokens=341) + #> RunUsage(input_tokens=309, output_tokens=32, requests=4) ``` 1. Define a dataclass to hold the client and API key dependencies. @@ -188,7 +188,7 @@ from rich.prompt import Prompt from pydantic_ai import Agent, RunContext from pydantic_ai.messages import ModelMessage -from pydantic_ai.usage import Usage, UsageLimits +from pydantic_ai.usage import RunUsage, UsageLimits class FlightDetails(BaseModel): @@ -221,7 +221,7 @@ async def flight_search( usage_limits = UsageLimits(request_limit=15) # (3)! -async def find_flight(usage: Usage) -> Union[FlightDetails, None]: # (4)! +async def find_flight(usage: RunUsage) -> Union[FlightDetails, None]: # (4)! message_history: Union[list[ModelMessage], None] = None for _ in range(3): prompt = Prompt.ask( @@ -259,7 +259,7 @@ seat_preference_agent = Agent[None, Union[SeatPreference, Failed]]( # (5)! ) -async def find_seat(usage: Usage) -> SeatPreference: # (6)! +async def find_seat(usage: RunUsage) -> SeatPreference: # (6)! message_history: Union[list[ModelMessage], None] = None while True: answer = Prompt.ask('What seat would you like?') @@ -278,7 +278,7 @@ async def find_seat(usage: Usage) -> SeatPreference: # (6)! async def main(): # (7)! 
- usage: Usage = Usage() + usage: RunUsage = RunUsage() opt_flight_details = await find_flight(usage) if opt_flight_details is not None: diff --git a/docs/output.md b/docs/output.md index c04425eb1e..d0ba4ff06a 100644 --- a/docs/output.md +++ b/docs/output.md @@ -1,6 +1,6 @@ "Output" refers to the final value returned from [running an agent](agents.md#running-agents). This can be either plain text, [structured data](#structured-output), or the result of a [function](#output-functions) called with arguments provided by the model. -The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so that you can access other data, like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results). +The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so that you can access other data, like [usage][pydantic_ai.usage.RunUsage] of the run and [message history](message-history.md#accessing-messages-from-results). Both `AgentRunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. @@ -24,7 +24,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65) +#> RunUsage(input_tokens=57, output_tokens=8, requests=1) ``` _(This example is complete, it can be run "as is")_ diff --git a/docs/testing.md b/docs/testing.md index 49b3eba3ca..8d7cd8313d 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -97,7 +97,7 @@ from pydantic_ai.messages import ( UserPromptPart, ModelRequest, ) -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from fake_database import DatabaseConn from weather_app import run_weather_forecast, weather_agent @@ -141,12 +141,9 @@ async def test_forecast(): tool_call_id=IsStr(), ) ], - usage=Usage( - requests=1, - request_tokens=71, - response_tokens=7, - total_tokens=78, - details=None, + usage=RequestUsage( + input_tokens=71, + output_tokens=7, ), model_name='test', timestamp=IsNow(tz=timezone.utc), @@ -167,12 +164,9 @@ async def test_forecast(): content='{"weather_forecast":"Sunny with a chance of rain"}', ) ], - usage=Usage( - requests=1, - request_tokens=77, - response_tokens=16, - total_tokens=93, - details=None, + usage=RequestUsage( + input_tokens=77, + output_tokens=16, ), model_name='test', timestamp=IsNow(tz=timezone.utc), diff --git a/docs/tools.md b/docs/tools.md index 1639cb1ab1..1cd36ee4ef 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -95,7 +95,7 @@ print(dice_result.all_messages()) tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id' ) ], - usage=Usage(requests=1, request_tokens=90, response_tokens=2, total_tokens=92), + usage=RequestUsage(input_tokens=90, output_tokens=2), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), @@ -115,7 +115,7 @@ print(dice_result.all_messages()) tool_name='get_player_name', args={}, tool_call_id='pyd_ai_tool_call_id' ) ], - usage=Usage(requests=1, request_tokens=91, response_tokens=4, total_tokens=95), + usage=RequestUsage(input_tokens=91, output_tokens=4), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), @@ -135,9 +135,7 @@ print(dice_result.all_messages()) content="Congratulations Anne, 
you guessed correctly! You're a winner!" ) ], - usage=Usage( - requests=1, request_tokens=92, response_tokens=12, total_tokens=104 - ), + usage=RequestUsage(input_tokens=92, output_tokens=12), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), diff --git a/examples/pydantic_ai_examples/flight_booking.py b/examples/pydantic_ai_examples/flight_booking.py index 0b180db3f2..d920ded6d6 100644 --- a/examples/pydantic_ai_examples/flight_booking.py +++ b/examples/pydantic_ai_examples/flight_booking.py @@ -13,7 +13,7 @@ from pydantic_ai import Agent, ModelRetry, RunContext from pydantic_ai.messages import ModelMessage -from pydantic_ai.usage import Usage, UsageLimits +from pydantic_ai.usage import RunUsage, UsageLimits # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') @@ -182,7 +182,7 @@ async def main(): req_date=datetime.date(2025, 1, 10), ) message_history: list[ModelMessage] | None = None - usage: Usage = Usage() + usage: RunUsage = RunUsage() # run the agent until a satisfactory flight is found while True: result = await search_agent.run( @@ -213,7 +213,7 @@ async def main(): ) -async def find_seat(usage: Usage) -> SeatPreference: +async def find_seat(usage: RunUsage) -> SeatPreference: message_history: list[ModelMessage] | None = None while True: answer = Prompt.ask('What seat would you like?') diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 132d414b3d..1e8beaec87 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -76,7 +76,7 @@ class GraphAgentState: """State kept across the execution of the agent graph.""" message_history: list[_messages.ModelMessage] - usage: _usage.Usage + usage: _usage.RunUsage retries: int run_step: int @@ -337,7 +337,7 @@ async def _make_request( model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) - ctx.state.usage.incr(_usage.Usage()) + ctx.state.usage.requests += 1 return self._finish_handling(ctx, model_response) diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index afad0e60e6..64694d7768 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from .models import Model - from .result import Usage + from .result import RunUsage AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True) """Type variable for agent dependencies.""" @@ -26,7 +26,7 @@ class RunContext(Generic[AgentDepsT]): """Dependencies for the agent.""" model: Model """The model used in this run.""" - usage: Usage + usage: RunUsage """LLM usage associated with the run.""" prompt: str | Sequence[_messages.UserContent] | None = None """The original user prompt passed to the run.""" diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 6416afd0cd..ca76947a06 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -51,7 +51,7 @@ from .tools import AgentDepsT, ToolDefinition from .toolsets import AbstractToolset from .toolsets.deferred import DeferredToolset -from .usage import Usage, UsageLimits +from .usage import RunUsage, UsageLimits try: from ag_ui.core import ( @@ 
-127,7 +127,7 @@ def __init__( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, - usage: Usage | None = None, + usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, # Starlette parameters. @@ -216,7 +216,7 @@ async def handle_ag_ui_request( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, - usage: Usage | None = None, + usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> Response: @@ -277,7 +277,7 @@ async def run_ag_ui( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, - usage: Usage | None = None, + usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AsyncIterator[str]: diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index fb4c10b632..1eab46d7e8 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -431,7 +431,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -447,7 +447,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... 
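
The run entry points in the hunks above now accept a `RunUsage` accumulator where they previously took `Usage`. As a reviewer's aside, here is a minimal caller-side sketch of the renamed API, assembled only from the docs examples in this diff; the model string and prompt are placeholders and the run needs provider credentials to actually execute:

```python
from pydantic_ai import Agent
from pydantic_ai.usage import RunUsage, UsageLimits

agent = Agent('openai:gpt-4o')  # placeholder model name

shared_usage = RunUsage()  # previously `Usage()`
result = agent.run_sync(
    'What is the capital of France?',
    usage=shared_usage,  # accumulates across runs that share this object
    usage_limits=UsageLimits(request_limit=5),
)
print(result.usage().input_tokens)   # previously `request_tokens`
print(result.usage().output_tokens)  # previously `response_tokens`
```
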
@@ -463,7 +463,7 @@ async def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: @@ -514,9 +514,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -575,7 +573,7 @@ async def main(): ) # Build the initial state - usage = usage or _usage.Usage() + usage = usage or _usage.RunUsage() state = _agent_graph.GraphAgentState( message_history=message_history[:] if message_history else [], usage=usage, @@ -677,7 +675,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: run_span.end() def _run_span_end_attributes( - self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings + self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings ): return { **usage.opentelemetry_attributes(), diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index 8a2c685c19..4da85b039d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -32,7 +32,7 @@ ToolFuncEither, ) from ..toolsets import AbstractToolset -from ..usage import Usage, UsageLimits +from ..usage import RunUsage, UsageLimits # Re-exporting like this improves auto-import behavior in PyCharm capture_run_messages = _agent_graph.capture_run_messages @@ -131,7 +131,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -148,7 +148,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -164,7 +164,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -240,7 +240,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -257,7 +257,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: 
_usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -273,7 +273,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -341,7 +341,7 @@ def run_stream( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -358,7 +358,7 @@ def run_stream( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -375,7 +375,7 @@ async def run_stream( # noqa C901 deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -534,7 +534,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -550,7 +550,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -567,7 +567,7 @@ async def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: @@ -618,9 +618,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -678,7 +676,7 @@ def override( def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. - Usage should be `self._infer_name(inspect.currentframe())`. 
+ RunUsage should be `self._infer_name(inspect.currentframe())`. """ assert self.name is None, 'Name already set' if function_frame is not None: # pragma: no branch @@ -751,7 +749,7 @@ def to_ag_ui( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, - usage: Usage | None = None, + usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, # Starlette diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index 7554f26a2c..a65380b41e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -76,7 +76,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -92,7 +92,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -108,7 +108,7 @@ async def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: @@ -159,9 +159,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 468150321c..5f315fe144 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -16,7 +16,7 @@ from datetime import datetime from types import TracebackType -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from pydantic_graph._utils import get_event_loop as _get_event_loop from . 
import agent, messages, models, settings @@ -57,7 +57,7 @@ async def main(): ''' ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) @@ -110,7 +110,7 @@ def model_request_sync( ''' ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) @@ -366,7 +366,7 @@ def get(self) -> messages.ModelResponse: """Build a ModelResponse from the data received from the stream so far.""" return self._ensure_stream_ready().get() - def usage(self) -> Usage: + def usage(self) -> RequestUsage: """Get the usage of the response so far.""" return self._ensure_stream_ready().usage() diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index 23d8b42366..3c49831e71 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -200,7 +200,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -217,7 +217,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -233,7 +233,7 @@ async def run( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -305,7 +305,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -322,7 +322,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -338,7 +338,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: 
Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -408,7 +408,7 @@ def run_stream( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -425,7 +425,7 @@ def run_stream( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -442,7 +442,7 @@ async def run_stream( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, @@ -513,7 +513,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -530,7 +530,7 @@ def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -547,7 +547,7 @@ async def iter( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, + usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, @@ -599,9 +599,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py index 6607aec077..2e1c1cc095 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py @@ -21,7 +21,7 @@ from pydantic_ai.models.wrapper import WrapperModel from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import AgentDepsT, RunContext -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from ._run_context import TemporalRunContext @@ -48,7 +48,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: def get(self) -> ModelResponse: return self.response - def usage(self) -> Usage: + def usage(self) -> RequestUsage: return self.response.usage # pragma: no cover @property diff --git 
a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 1aaac328e4..28447187ef 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -19,7 +19,7 @@ now_utc as _now_utc, ) from .exceptions import UnexpectedModelBehavior -from .usage import Usage +from .usage import RequestUsage if TYPE_CHECKING: from .models.instrumented import InstrumentationSettings @@ -830,7 +830,7 @@ class ModelResponse: parts: list[ModelResponsePart] """The parts of the model message.""" - usage: Usage = field(default_factory=Usage) + usage: RequestUsage = field(default_factory=RequestUsage) """Usage information for the request. This has a default to make tests easier, and to support loading old messages where usage will be missing. @@ -848,15 +848,15 @@ class ModelResponse: kind: Literal['response'] = 'response' """Message type identifier, this is available on all parts as a discriminator.""" - vendor_details: dict[str, Any] | None = field(default=None) - """Additional vendor-specific details in a serializable format. + provider_details: dict[str, Any] | None = field(default=None) + """Additional provider-specific details in a serializable format. This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields. For OpenAI models, this may include 'logprobs', 'finish_reason', etc. """ - vendor_id: str | None = None - """Vendor ID as specified by the model provider. This can be used to track the specific request to the model.""" + provider_request_id: str | None = None + """request ID as specified by the model provider. This can be used to track the specific request to the model.""" def otel_events(self, settings: InstrumentationSettings) -> list[Event]: """Return OpenTelemetry events for the response.""" @@ -894,6 +894,16 @@ def new_event_body(): return result + @property + @deprecated('`vendor_details` is deprecated, use `provider_details` instead') + def vendor_details(self) -> dict[str, Any] | None: + return self.provider_details + + @property + @deprecated('`vendor_id` is deprecated, use `provider_request_id` instead') + def vendor_id(self) -> str | None: + return self.provider_request_id + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 00cda32c94..ff3d1c7ff6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -42,7 +42,7 @@ from ..profiles._json_schema import JsonSchemaTransformer from ..settings import ModelSettings from ..tools import ToolDefinition -from ..usage import Usage +from ..usage import RequestUsage KnownModelName = TypeAliasType( 'KnownModelName', @@ -418,7 +418,7 @@ async def count_tokens( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, - ) -> Usage: + ) -> RequestUsage: """Make a request to the model for counting tokens.""" # This method is not required, but you need to implement it if you want to support `UsageLimits.count_tokens_before_request`. 
raise NotImplementedError(f'Token counting ahead of the request is not supported by {self.__class__.__name__}') @@ -547,7 +547,7 @@ class StreamedResponse(ABC): _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) - _usage: Usage = field(default_factory=Usage, init=False) + _usage: RequestUsage = field(default_factory=RequestUsage, init=False) def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: """Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. @@ -600,7 +600,7 @@ def get(self) -> ModelResponse: usage=self.usage(), ) - def usage(self) -> Usage: + def usage(self) -> RequestUsage: """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" return self._usage diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index b65e319b4e..dbc35cfd98 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -180,7 +180,6 @@ async def request( messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) - model_response.usage.requests = 1 return model_response @asynccontextmanager @@ -325,7 +324,9 @@ def _process_response(self, response: BetaMessage) -> ModelResponse: ) ) - return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id) + return ModelResponse( + items, usage=_map_usage(response), model_name=response.model, provider_request_id=response.id + ) async def _process_streamed_response( self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters @@ -528,7 +529,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam: } -def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage: +def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage: if isinstance(message, BetaMessage): response_usage = message.usage elif isinstance(message, BetaRawMessageStartEvent): @@ -541,7 +542,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage: # - RawContentBlockStartEvent # - RawContentBlockDeltaEvent # - RawContentBlockStopEvent - return usage.Usage() + return usage.RequestUsage() # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by # `response_tokens` @@ -552,17 +553,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage: # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get` # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens # This approach maintains request_tokens as the count of all input tokens, with cached counts as details - request_tokens = ( - details.get('input_tokens', 0) - + details.get('cache_creation_input_tokens', 0) - + details.get('cache_read_input_tokens', 0) - ) - - return usage.Usage( - request_tokens=request_tokens or None, - response_tokens=response_usage.output_tokens, - total_tokens=request_tokens + response_usage.output_tokens, - details=details or None, + cache_write_tokens = details.get('cache_creation_input_tokens', 0) + cache_read_tokens = details.get('cache_read_input_tokens', 
0) + request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens + + return usage.RequestUsage( + input_tokens=request_tokens, + cache_read_tokens=cache_read_tokens, + cache_write_tokens=cache_write_tokens, + output_tokens=response_usage.output_tokens, + details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index ef367868fd..995e7e8339 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -258,7 +258,6 @@ async def request( settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, False, settings, model_request_parameters) model_response = await self._process_response(response) - model_response.usage.requests = 1 return model_response @asynccontextmanager @@ -299,13 +298,12 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes tool_call_id=tool_use['toolUseId'], ), ) - u = usage.Usage( - request_tokens=response['usage']['inputTokens'], - response_tokens=response['usage']['outputTokens'], - total_tokens=response['usage']['totalTokens'], + u = usage.RequestUsage( + input_tokens=response['usage']['inputTokens'], + output_tokens=response['usage']['outputTokens'], ) vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None) - return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id) + return ModelResponse(items, usage=u, model_name=self.model_name, provider_request_id=vendor_id) @overload async def _messages_create( @@ -670,11 +668,10 @@ def model_name(self) -> str: """Get the model name of the response.""" return self._model_name - def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.Usage: - return usage.Usage( - request_tokens=metadata['usage']['inputTokens'], - response_tokens=metadata['usage']['outputTokens'], - total_tokens=metadata['usage']['totalTokens'], + def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage: + return usage.RequestUsage( + input_tokens=metadata['usage']['inputTokens'], + output_tokens=metadata['usage']['outputTokens'], ) diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 42ac30e321..c9ed3a8566 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -149,7 +149,6 @@ async def request( check_allow_model_requests() response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters) model_response = self._process_response(response) - model_response.usage.requests = 1 return model_response @property @@ -301,10 +300,10 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]: assert_never(part) -def _map_usage(response: V2ChatResponse) -> usage.Usage: +def _map_usage(response: V2ChatResponse) -> usage.RequestUsage: u = response.usage if u is None: - return usage.Usage() + return usage.RequestUsage() else: details: dict[str, int] = {} if u.billed_units is not None: @@ -317,11 +316,10 @@ def _map_usage(response: V2ChatResponse) -> usage.Usage: if u.billed_units.classifications: # pragma: no cover details['classifications'] = int(u.billed_units.classifications) - request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else None - response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else None - 
return usage.Usage( - request_tokens=request_tokens, - response_tokens=response_tokens, - total_tokens=(request_tokens or 0) + (response_tokens or 0), + request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else 0 + response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else 0 + return usage.RequestUsage( + input_tokens=request_tokens, + output_tokens=response_tokens, details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 65ff27c96f..b3062c8c3f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -138,7 +138,6 @@ async def request( # Add usage data if not already present if not response.usage.has_values(): # pragma: no branch response.usage = _estimate_usage(chain(messages, [response])) - response.usage.requests = 1 return response @asynccontextmanager @@ -270,7 +269,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for item in self._iter: if isinstance(item, str): response_tokens = _estimate_string_tokens(item) - self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) + self._usage += usage.RequestUsage(output_tokens=response_tokens) maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) if maybe_event is not None: # pragma: no branch yield maybe_event @@ -279,7 +278,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if isinstance(delta, DeltaThinkingPart): if delta.content: # pragma: no branch response_tokens = _estimate_string_tokens(delta.content) - self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) + self._usage += usage.RequestUsage(output_tokens=response_tokens) yield self._parts_manager.handle_thinking_delta( vendor_part_id=dtc_index, content=delta.content, @@ -288,7 +287,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(delta, DeltaToolCall): if delta.json_args: response_tokens = _estimate_string_tokens(delta.json_args) - self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) + self._usage += usage.RequestUsage(output_tokens=response_tokens) maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=dtc_index, tool_name=delta.name, @@ -311,7 +310,7 @@ def timestamp(self) -> datetime: return self._timestamp -def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage: +def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage: """Very rough guesstimate of the token usage associated with a series of messages. This is designed to be used solely to give plausible numbers for testing! 
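
The provider hunks above (Anthropic, Bedrock, Cohere, and the function model) all converge on the same conversion shape: build a `RequestUsage` from whatever token counts the provider reports and drop the redundant `total_tokens`, since totals are now derived. A minimal sketch of that pattern follows; the helper name `map_provider_usage` and the payload keys `prompt_tokens`/`completion_tokens` are illustrative assumptions, not taken from any specific SDK:

```python
from pydantic_ai.usage import RequestUsage


def map_provider_usage(payload: dict[str, int] | None) -> RequestUsage:
    # Hypothetical provider payload -> RequestUsage, mirroring the
    # `_map_usage` helpers changed in this diff: only input/output
    # counts are mapped, the total is computed by RequestUsage itself.
    if not payload:
        return RequestUsage()
    return RequestUsage(
        input_tokens=payload.get('prompt_tokens', 0),
        output_tokens=payload.get('completion_tokens', 0),
    )
```
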
@@ -349,10 +348,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage: assert_never(part) else: assert_never(message) - return usage.Usage( - request_tokens=request_tokens, - response_tokens=response_tokens, - total_tokens=request_tokens + response_tokens, + return usage.RequestUsage( + input_tokens=request_tokens, + output_tokens=response_tokens, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 48fc811ece..0e223cb190 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -278,7 +278,6 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse: if finish_reason: vendor_details = {'finish_reason': finish_reason} usage = _metadata_as_usage(response) - usage.requests = 1 return _process_response_from_parts( parts, response.get('model_version', self._model_name), @@ -673,7 +672,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart def _process_response_from_parts( parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, - usage: usage.Usage, + usage: usage.RequestUsage, vendor_id: str | None, vendor_details: dict[str, Any] | None = None, ) -> ModelResponse: @@ -693,7 +692,7 @@ def _process_response_from_parts( f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' ) return ModelResponse( - parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details + parts=items, usage=usage, model_name=model_name, provider_request_id=vendor_id, provider_details=vendor_details ) @@ -859,31 +858,45 @@ class _GeminiUsageMetaData(TypedDict, total=False): ] -def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage: +def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage: metadata = response.get('usage_metadata') if metadata is None: - return usage.Usage() # pragma: no cover + return usage.RequestUsage() details: dict[str, int] = {} - if cached_content_token_count := metadata.get('cached_content_token_count'): - details['cached_content_tokens'] = cached_content_token_count # pragma: no cover + if cached_content_token_count := metadata.get('cached_content_token_count', 0): + details['cached_content_tokens'] = cached_content_token_count - if thoughts_token_count := metadata.get('thoughts_token_count'): + if thoughts_token_count := metadata.get('thoughts_token_count', 0): details['thoughts_tokens'] = thoughts_token_count - if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'): - details['tool_use_prompt_tokens'] = tool_use_prompt_token_count # pragma: no cover + if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0): + details['tool_use_prompt_tokens'] = tool_use_prompt_token_count + input_audio_tokens = 0 + output_audio_tokens = 0 + cache_audio_read_tokens = 0 for key, metadata_details in metadata.items(): if key.endswith('_details') and metadata_details: metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) suffix = key.removesuffix('_details') for detail in metadata_details: - details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0) - - return usage.Usage( - request_tokens=metadata.get('prompt_token_count', 0), - response_tokens=metadata.get('candidates_token_count', 0), - total_tokens=metadata.get('total_token_count', 0), + modality = detail['modality'] + details[f'{modality.lower()}_{suffix}'] = value = 
detail.get('token_count', 0) + if value and modality == 'AUDIO': + if key == 'prompt_tokens_details': + input_audio_tokens = value + elif key == 'candidates_tokens_details': + output_audio_tokens = value + elif key == 'cache_tokens_details': # pragma: no branch + cache_audio_read_tokens = value + + return usage.RequestUsage( + input_tokens=metadata.get('prompt_token_count', 0), + output_tokens=metadata.get('candidates_token_count', 0), + cache_read_tokens=cached_content_token_count, + input_audio_tokens=input_audio_tokens, + output_audio_tokens=output_audio_tokens, + cache_audio_read_tokens=cache_audio_read_tokens, details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 4b53a79565..a2f1689c71 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -195,7 +195,7 @@ async def count_tokens( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, - ) -> usage.Usage: + ) -> usage.RequestUsage: check_allow_model_requests() model_settings = cast(GoogleModelSettings, model_settings or {}) contents, generation_config = await self._build_content_and_config( @@ -238,9 +238,8 @@ async def count_tokens( raise UnexpectedModelBehavior( # pragma: no cover 'Total tokens missing from Gemini response', str(response) ) - return usage.Usage( - request_tokens=response.total_tokens, - total_tokens=response.total_tokens, + return usage.RequestUsage( + input_tokens=response.total_tokens, ) @asynccontextmanager @@ -392,9 +391,12 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse: if finish_reason: # pragma: no branch vendor_details = {'finish_reason': finish_reason.value} usage = _metadata_as_usage(response) - usage.requests = 1 return _process_response_from_parts( - parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details + parts, + response.model_version or self._model_name, + usage, + vendor_id=vendor_id, + vendor_details=vendor_details, ) async def _process_streamed_response( @@ -590,7 +592,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict: def _process_response_from_parts( parts: list[Part], model_name: GoogleModelName, - usage: usage.Usage, + usage: usage.RequestUsage, vendor_id: str | None, vendor_details: dict[str, Any] | None = None, ) -> ModelResponse: @@ -627,7 +629,7 @@ def _process_response_from_parts( f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' ) return ModelResponse( - parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details + parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details ) @@ -647,31 +649,51 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict: return ToolConfigDict(function_calling_config=function_calling_config) -def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage: +def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage: metadata = response.usage_metadata if metadata is None: - return usage.Usage() # pragma: no cover - metadata = metadata.model_dump(exclude_defaults=True) - + return usage.RequestUsage() details: dict[str, int] = {} - if cached_content_token_count := metadata.get('cached_content_token_count'): - details['cached_content_tokens'] = cached_content_token_count # pragma: no 
cover + if cached_content_token_count := metadata.cached_content_token_count: + details['cached_content_tokens'] = cached_content_token_count - if thoughts_token_count := metadata.get('thoughts_token_count'): + if thoughts_token_count := metadata.thoughts_token_count: details['thoughts_tokens'] = thoughts_token_count - if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'): + if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count: details['tool_use_prompt_tokens'] = tool_use_prompt_token_count - for key, metadata_details in metadata.items(): - if key.endswith('_details') and metadata_details: - suffix = key.removesuffix('_details') - for detail in metadata_details: - details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0) - - return usage.Usage( - request_tokens=metadata.get('prompt_token_count', 0), - response_tokens=metadata.get('candidates_token_count', 0), - total_tokens=metadata.get('total_token_count', 0), + input_audio_tokens = 0 + output_audio_tokens = 0 + cache_audio_read_tokens = 0 + for prefix, metadata_details in [ + ('prompt', metadata.prompt_tokens_details), + ('cache', metadata.cache_tokens_details), + ('candidates', metadata.candidates_tokens_details), + ('tool_use_prompt', metadata.tool_use_prompt_tokens_details), + ]: + assert getattr(metadata, f'{prefix}_tokens_details') is metadata_details + if not metadata_details: + continue + for detail in metadata_details: + if not detail.modality or not detail.token_count: # pragma: no cover + continue + details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count + if detail.modality != 'AUDIO': + continue + if metadata_details is metadata.prompt_tokens_details: + input_audio_tokens = detail.token_count + elif metadata_details is metadata.candidates_tokens_details: + output_audio_tokens = detail.token_count + elif metadata_details is metadata.cache_tokens_details: # pragma: no branch + cache_audio_read_tokens = detail.token_count + + return usage.RequestUsage( + input_tokens=metadata.prompt_token_count or 0, + output_tokens=metadata.candidates_token_count or 0, + cache_read_tokens=cached_content_token_count or 0, + input_audio_tokens=input_audio_tokens, + output_audio_tokens=output_audio_tokens, + cache_audio_read_tokens=cache_audio_read_tokens, details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index f076fdbe15..1b3307f532 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -162,7 +162,6 @@ async def request( messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) - model_response.usage.requests = 1 return model_response @asynccontextmanager @@ -285,7 +284,11 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: for c in choice.message.tool_calls: items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)) return ModelResponse( - items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id + items, + usage=_map_usage(response), + model_name=response.model, + timestamp=timestamp, + provider_request_id=response.id, ) async def _process_streamed_response( @@ -484,7 +487,7 @@ def timestamp(self) -> datetime: return self._timestamp -def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage: +def 
_map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage: response_usage = None if isinstance(completion, chat.ChatCompletion): response_usage = completion.usage @@ -492,10 +495,9 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us response_usage = completion.x_groq.usage if response_usage is None: - return usage.Usage() + return usage.RequestUsage() - return usage.Usage( - request_tokens=response_usage.prompt_tokens, - response_tokens=response_usage.completion_tokens, - total_tokens=response_usage.total_tokens, + return usage.RequestUsage( + input_tokens=response_usage.prompt_tokens, + output_tokens=response_usage.completion_tokens, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 3f44ac267d..323bee0add 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -152,7 +152,6 @@ async def request( messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) - model_response.usage.requests = 1 return model_response @asynccontextmanager @@ -272,7 +271,7 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: usage=_map_usage(response), model_name=response.model, timestamp=timestamp, - vendor_id=response.id, + provider_request_id=response.id, ) async def _process_streamed_response( @@ -481,14 +480,12 @@ def timestamp(self) -> datetime: return self._timestamp -def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage: +def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.RequestUsage: response_usage = response.usage if response_usage is None: - return usage.Usage() + return usage.RequestUsage() - return usage.Usage( - request_tokens=response_usage.prompt_tokens, - response_tokens=response_usage.completion_tokens, - total_tokens=response_usage.total_tokens, - details=None, + return usage.RequestUsage( + input_tokens=response_usage.prompt_tokens, + output_tokens=response_usage.completion_tokens, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 2d68aa9001..b7bf965b9c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -280,14 +280,14 @@ def _record_metrics(): 'gen_ai.request.model': request_model, 'gen_ai.response.model': response_model, } - if response.usage.request_tokens: # pragma: no branch + if response.usage.input_tokens: # pragma: no branch self.instrumentation_settings.tokens_histogram.record( - response.usage.request_tokens, + response.usage.input_tokens, {**metric_attributes, 'gen_ai.token.type': 'input'}, ) - if response.usage.response_tokens: # pragma: no branch + if response.usage.output_tokens: # pragma: no branch self.instrumentation_settings.tokens_histogram.record( - response.usage.response_tokens, + response.usage.output_tokens, {**metric_attributes, 'gen_ai.token.type': 'output'}, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index a4f6497866..98949327ec 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast -from .. 
import _mcp, exceptions, usage +from .. import _mcp, exceptions from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse from ..settings import ModelSettings @@ -63,7 +63,6 @@ async def request( if result.role == 'assistant': return ModelResponse( parts=[_mcp.map_from_sampling_content(result.content)], - usage=usage.Usage(requests=1), model_name=result.model, ) else: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 18228cdae1..7b5558af11 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -40,7 +40,7 @@ from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition -from ..usage import Usage +from ..usage import RequestUsage from . import ( Model, ModelRequestParameters, @@ -167,7 +167,6 @@ async def request( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) - model_response.usage.requests = 1 return model_response @asynccontextmanager @@ -348,7 +347,11 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes parts.append(tool) return ModelResponse( - parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id + parts, + usage=_map_usage(response), + model_name=response.model, + timestamp=timestamp, + provider_request_id=response.id, ) async def _process_streamed_response( @@ -699,17 +702,15 @@ def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[ } -def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage: +def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> RequestUsage: """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage.""" if response.usage: - return Usage( - request_tokens=response.usage.prompt_tokens, - response_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens, - details=None, + return RequestUsage( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, ) else: - return Usage() # pragma: no cover + return RequestUsage() # pragma: no cover def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 6a4ae2549c..ee1ddfb9c1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -261,7 +261,6 @@ async def request( messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) - model_response.usage.requests = 1 return model_response @asynccontextmanager @@ -445,8 +444,8 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons usage=_map_usage(response), model_name=response.model, timestamp=timestamp, - vendor_details=vendor_details, - vendor_id=response.id, + provider_details=vendor_details, + provider_request_id=response.id, ) async def _process_streamed_response( @@ -747,7 +746,7 @@ def _process_response(self, response: responses.Response) -> ModelResponse: items, usage=_map_usage(response), model_name=response.model, - vendor_id=response.id, + provider_request_id=response.id, timestamp=timestamp, ) @@ -1265,10 +1264,10 @@ def 
timestamp(self) -> datetime: return self._timestamp -def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.Usage: +def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.RequestUsage: response_usage = response.usage if response_usage is None: - return usage.Usage() + return usage.RequestUsage() elif isinstance(response_usage, responses.ResponseUsage): details: dict[str, int] = { key: value @@ -1278,29 +1277,29 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R if isinstance(value, int) } details['reasoning_tokens'] = response_usage.output_tokens_details.reasoning_tokens - details['cached_tokens'] = response_usage.input_tokens_details.cached_tokens - return usage.Usage( - request_tokens=response_usage.input_tokens, - response_tokens=response_usage.output_tokens, - total_tokens=response_usage.total_tokens, + return usage.RequestUsage( + input_tokens=response_usage.input_tokens, + output_tokens=response_usage.output_tokens, + cache_read_tokens=response_usage.input_tokens_details.cached_tokens, details=details, ) else: details = { key: value for key, value in response_usage.model_dump( - exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'} + exclude_none=True, exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'} ).items() if isinstance(value, int) } + u = usage.RequestUsage( + input_tokens=response_usage.prompt_tokens, + output_tokens=response_usage.completion_tokens, + details=details, + ) if response_usage.completion_tokens_details is not None: details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True)) + u.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens or 0 if response_usage.prompt_tokens_details is not None: - details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True)) - return usage.Usage( - requests=1, - request_tokens=response_usage.prompt_tokens, - response_tokens=response_usage.completion_tokens, - total_tokens=response_usage.total_tokens, - details=details, - ) + u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0 + u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0 + return u diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 7685f24ddb..dbb5c68650 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -31,7 +31,7 @@ from ..profiles import ModelProfileSpec from ..settings import ModelSettings from ..tools import ToolDefinition -from ..usage import Usage +from ..usage import RequestUsage from . 
import Model, ModelRequestParameters, StreamedResponse from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage] @@ -113,7 +113,6 @@ async def request( self.last_model_request_parameters = model_request_parameters model_response = self._request(messages, model_settings, model_request_parameters) model_response.usage = _estimate_usage([*messages, model_response]) - model_response.usage.requests = 1 return model_response @asynccontextmanager @@ -468,6 +467,6 @@ def _char(self) -> str: return s -def _get_string_usage(text: str) -> Usage: +def _get_string_usage(text: str) -> RequestUsage: response_tokens = _estimate_string_tokens(text) - return Usage(response_tokens=response_tokens, total_tokens=response_tokens) + return RequestUsage(output_tokens=response_tokens) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 3f43ddb25f..39aa5395f6 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -27,7 +27,7 @@ OutputDataT, ToolOutput, ) -from .usage import Usage, UsageLimits +from .usage import RunUsage, UsageLimits __all__ = ( 'OutputDataT', @@ -52,7 +52,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _tool_manager: ToolManager[AgentDepsT] _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) - _initial_run_ctx_usage: Usage = field(init=False) + _initial_run_ctx_usage: RunUsage = field(init=False) def __post_init__(self): self._initial_run_ctx_usage = copy(self._run_ctx.usage) @@ -110,7 +110,7 @@ def get(self) -> _messages.ModelResponse: """Get the current state of the response.""" return self._raw_stream_response.get() - def usage(self) -> Usage: + def usage(self) -> RunUsage: """Return the usage of the whole run. !!! note @@ -382,7 +382,7 @@ async def get_output(self) -> OutputDataT: await self._marked_completed(self._stream_response.get()) return output - def usage(self) -> Usage: + def usage(self) -> RunUsage: """Return the usage of the whole run. !!! 
note @@ -425,7 +425,7 @@ class FinalResult(Generic[OutputDataT]): def _get_usage_checking_stream_response( stream_response: models.StreamedResponse, limits: UsageLimits | None, - get_usage: Callable[[], Usage], + get_usage: Callable[[], RunUsage], ) -> AsyncIterator[AgentStreamEvent]: if limits is not None and limits.has_token_limits(): diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index f5d5d0ed83..e5908ce535 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -66,9 +66,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, request_tokens=56, response_tokens=7, total_tokens=63 - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -203,12 +201,7 @@ async def main(): CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, - request_tokens=56, - response_tokens=7, - total_tokens=63, - ), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), ) @@ -235,7 +228,7 @@ async def main(): assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' return next_node - def usage(self) -> _usage.Usage: + def usage(self) -> _usage.RunUsage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" return self._graph_run.state.usage @@ -352,6 +345,6 @@ def new_messages_json(self, *, output_tool_return_content: str | None = None) -> self.new_messages(output_tool_return_content=output_tool_return_content) ) - def usage(self) -> _usage.Usage: + def usage(self) -> _usage.RunUsage: """Return the usage of the whole run.""" return self._state.usage diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 9fc3f5bc00..a16f4ee09f 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -1,67 +1,62 @@ from __future__ import annotations as _annotations +import dataclasses from copy import copy -from dataclasses import dataclass +from dataclasses import dataclass, fields + +from typing_extensions import deprecated, overload from . import _utils from .exceptions import UsageLimitExceeded -__all__ = 'Usage', 'UsageLimits' +__all__ = 'RequestUsage', 'RunUsage', 'Usage', 'UsageLimits' @dataclass(repr=False) -class Usage: - """LLM usage associated with a request or run. - - Responsibility for calculating usage is on the model; Pydantic AI simply sums the usage information across requests. +class UsageBase: + input_tokens: int = 0 + """Number of input/prompt tokens.""" - You'll need to look up the documentation of the model you're using to convert usage to monetary costs. 
- """ + cache_write_tokens: int = 0 + """Number of tokens written to the cache.""" + cache_read_tokens: int = 0 + """Number of tokens read from the cache.""" - requests: int = 0 - """Number of requests made to the LLM API.""" - request_tokens: int | None = None - """Tokens used in processing requests.""" - response_tokens: int | None = None - """Tokens used in generating responses.""" - total_tokens: int | None = None - """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`.""" - details: dict[str, int] | None = None - """Any extra details returned by the model.""" + output_tokens: int = 0 + """Number of output/completion tokens.""" - def incr(self, incr_usage: Usage) -> None: - """Increment the usage in place. + input_audio_tokens: int = 0 + """Number of audio input tokens.""" + cache_audio_read_tokens: int = 0 + """Number of audio tokens read from the cache.""" + output_audio_tokens: int = 0 + """Number of audio output tokens.""" - Args: - incr_usage: The usage to increment by. - """ - for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens': - self_value = getattr(self, f) - other_value = getattr(incr_usage, f) - if self_value is not None or other_value is not None: - setattr(self, f, (self_value or 0) + (other_value or 0)) + details: dict[str, int] = dataclasses.field(default_factory=dict) + """Any extra details returned by the model.""" - if incr_usage.details: - self.details = self.details or {} - for key, value in incr_usage.details.items(): - self.details[key] = self.details.get(key, 0) + value + @property + @deprecated('`request_tokens` is deprecated, use `input_tokens` instead') + def request_tokens(self) -> int: + return self.input_tokens - def __add__(self, other: Usage) -> Usage: - """Add two Usages together. + @property + @deprecated('`response_tokens` is deprecated, use `output_tokens` instead') + def response_tokens(self) -> int: + return self.output_tokens - This is provided so it's trivial to sum usage information from multiple requests and runs. - """ - new_usage = copy(self) - new_usage.incr(other) - return new_usage + @property + def total_tokens(self) -> int: + """Sum of `input_tokens + output_tokens`.""" + return self.input_tokens + self.output_tokens def opentelemetry_attributes(self) -> dict[str, int]: - """Get the token limits as OpenTelemetry attributes.""" + """Get the token usage values as OpenTelemetry attributes.""" result: dict[str, int] = {} - if self.request_tokens: - result['gen_ai.usage.input_tokens'] = self.request_tokens - if self.response_tokens: - result['gen_ai.usage.output_tokens'] = self.response_tokens + if self.input_tokens: + result['gen_ai.usage.input_tokens'] = self.input_tokens + if self.output_tokens: + result['gen_ai.usage.output_tokens'] = self.output_tokens details = self.details if details: prefix = 'gen_ai.usage.details.' 
@@ -71,11 +66,118 @@ def opentelemetry_attributes(self) -> dict[str, int]: result[prefix + key] = value return result + def __repr__(self): + kv_pairs = (f'{f.name}={value!r}' for f in fields(self) if (value := getattr(self, f.name))) + return f'{self.__class__.__qualname__}({", ".join(kv_pairs)})' + def has_values(self) -> bool: """Whether any values are set and non-zero.""" - return bool(self.requests or self.request_tokens or self.response_tokens or self.details) + return any(dataclasses.asdict(self).values()) - __repr__ = _utils.dataclasses_no_defaults_repr + +@dataclass(repr=False) +class RequestUsage(UsageBase): + """LLM usage associated with a single request. + + This is an implementation of `genai_prices.types.AbstractUsage` so it can be used to calculate the price of the + request using [genai-prices](https://github.com/pydantic/genai-prices). + """ + + @property + def requests(self): + return 1 + + def incr(self, incr_usage: RequestUsage) -> None: + """Increment the usage in place. + + Args: + incr_usage: The usage to increment by. + """ + return _incr_usage_tokens(self, incr_usage) + + def __add__(self, other: RequestUsage) -> RequestUsage: + """Add two RequestUsages together. + + This is provided so it's trivial to sum usage information from multiple parts of a response. + + **WARNING:** this CANNOT be used to sum multiple requests without breaking some pricing calculations. + """ + new_usage = copy(self) + new_usage.incr(other) + return new_usage + + +@dataclass(repr=False) +class RunUsage(UsageBase): + """LLM usage associated with an agent run. + + Responsibility for calculating request usage is on the model; Pydantic AI simply sums the usage information across requests. + """ + + requests: int = 0 + """Number of requests made to the LLM API.""" + + input_tokens: int = 0 + """Total number of text input/prompt tokens.""" + + cache_write_tokens: int = 0 + """Total number of tokens written to the cache.""" + cache_read_tokens: int = 0 + """Total number of tokens read from the cache.""" + + input_audio_tokens: int = 0 + """Total number of audio input tokens.""" + cache_audio_read_tokens: int = 0 + """Total number of audio tokens read from the cache.""" + + output_tokens: int = 0 + """Total number of text output/completion tokens.""" + + details: dict[str, int] = dataclasses.field(default_factory=dict) + """Any extra details returned by the model.""" + + def incr(self, incr_usage: RunUsage | RequestUsage) -> None: + """Increment the usage in place. + + Args: + incr_usage: The usage to increment by. + """ + if isinstance(incr_usage, RunUsage): + self.requests += incr_usage.requests + return _incr_usage_tokens(self, incr_usage) + + def __add__(self, other: RunUsage | RequestUsage) -> RunUsage: + """Add two RunUsages together. + + This is provided so it's trivial to sum usage information from multiple runs. + """ + new_usage = copy(self) + new_usage.incr(other) + return new_usage + + +def _incr_usage_tokens(slf: RunUsage | RequestUsage, incr_usage: RunUsage | RequestUsage) -> None: + """Increment the usage in place. + + Args: + slf: The usage to increment. + incr_usage: The usage to increment by. 
+ """ + slf.input_tokens += incr_usage.input_tokens + slf.cache_write_tokens += incr_usage.cache_write_tokens + slf.cache_read_tokens += incr_usage.cache_read_tokens + slf.input_audio_tokens += incr_usage.input_audio_tokens + slf.cache_audio_read_tokens += incr_usage.cache_audio_read_tokens + slf.output_tokens += incr_usage.output_tokens + + for key, value in incr_usage.details.items(): + slf.details[key] = slf.details.get(key, 0) + value + + +@dataclass +@deprecated('`Usage` is deprecated, use `RunUsage` instead') +class Usage(RunUsage): + """Deprecated alias for `RunUsage`.""" @dataclass(repr=False) @@ -90,10 +192,10 @@ class UsageLimits: request_limit: int | None = 50 """The maximum number of requests allowed to the model.""" - request_tokens_limit: int | None = None - """The maximum number of tokens allowed in requests to the model.""" - response_tokens_limit: int | None = None - """The maximum number of tokens allowed in responses from the model.""" + input_tokens_limit: int | None = None + """The maximum number of input/prompt tokens allowed.""" + output_tokens_limit: int | None = None + """The maximum number of output/response tokens allowed.""" total_tokens_limit: int | None = None """The maximum number of tokens allowed in requests and responses combined.""" count_tokens_before_request: bool = False @@ -101,6 +203,69 @@ class UsageLimits: to enforce `request_tokens_limit` ahead of time. This may incur additional overhead (from calling the model's `count_tokens` API before making the actual request) and is disabled by default.""" + @property + @deprecated('`request_tokens_limit` is deprecated, use `input_tokens_limit` instead') + def request_tokens_limit(self) -> int | None: + return self.input_tokens_limit + + @property + @deprecated('`response_tokens_limit` is deprecated, use `output_tokens_limit` instead') + def response_tokens_limit(self) -> int | None: + return self.output_tokens_limit + + @overload + def __init__( + self, + *, + request_limit: int | None = 50, + input_tokens_limit: int | None = None, + output_tokens_limit: int | None = None, + total_tokens_limit: int | None = None, + count_tokens_before_request: bool = False, + ) -> None: + self.request_limit = request_limit + self.input_tokens_limit = input_tokens_limit + self.output_tokens_limit = output_tokens_limit + self.total_tokens_limit = total_tokens_limit + self.count_tokens_before_request = count_tokens_before_request + + @overload + @deprecated( + 'Use `input_tokens_limit` instead of `request_tokens_limit` and `output_tokens_limit` and `total_tokens_limit`' + ) + def __init__( + self, + *, + request_limit: int | None = 50, + request_tokens_limit: int | None = None, + response_tokens_limit: int | None = None, + total_tokens_limit: int | None = None, + count_tokens_before_request: bool = False, + ) -> None: + self.request_limit = request_limit + self.input_tokens_limit = request_tokens_limit + self.output_tokens_limit = response_tokens_limit + self.total_tokens_limit = total_tokens_limit + self.count_tokens_before_request = count_tokens_before_request + + def __init__( + self, + *, + request_limit: int | None = 50, + input_tokens_limit: int | None = None, + output_tokens_limit: int | None = None, + total_tokens_limit: int | None = None, + count_tokens_before_request: bool = False, + # deprecated: + request_tokens_limit: int | None = None, + response_tokens_limit: int | None = None, + ): + self.request_limit = request_limit + self.input_tokens_limit = input_tokens_limit or request_tokens_limit + 
self.output_tokens_limit = output_tokens_limit or response_tokens_limit + self.total_tokens_limit = total_tokens_limit + self.count_tokens_before_request = count_tokens_before_request + def has_token_limits(self) -> bool: """Returns `True` if this instance places any limits on token counts. @@ -110,43 +275,40 @@ def has_token_limits(self) -> bool: If there are no limits, we can skip that processing in the streaming response iterator. """ return any( - limit is not None - for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit) + limit is not None for limit in (self.input_tokens_limit, self.output_tokens_limit, self.total_tokens_limit) ) - def check_before_request(self, usage: Usage) -> None: + def check_before_request(self, usage: RunUsage) -> None: """Raises a `UsageLimitExceeded` exception if the next request would exceed any of the limits.""" request_limit = self.request_limit if request_limit is not None and usage.requests >= request_limit: raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}') - request_tokens = usage.request_tokens or 0 - if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit: + input_tokens = usage.input_tokens + if self.input_tokens_limit is not None and input_tokens > self.input_tokens_limit: raise UsageLimitExceeded( - f'The next request would exceed the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})' + f'The next request would exceed the input_tokens_limit of {self.input_tokens_limit} ({input_tokens=})' ) - total_tokens = usage.total_tokens or 0 + total_tokens = usage.total_tokens if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: raise UsageLimitExceeded( f'The next request would exceed the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})' ) - def check_tokens(self, usage: Usage) -> None: + def check_tokens(self, usage: RunUsage) -> None: """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits.""" - request_tokens = usage.request_tokens or 0 - if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit: - raise UsageLimitExceeded( - f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})' - ) + input_tokens = usage.input_tokens + if self.input_tokens_limit is not None and input_tokens > self.input_tokens_limit: + raise UsageLimitExceeded(f'Exceeded the input_tokens_limit of {self.input_tokens_limit} ({input_tokens=})') - response_tokens = usage.response_tokens or 0 - if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit: + output_tokens = usage.output_tokens + if self.output_tokens_limit is not None and output_tokens > self.output_tokens_limit: raise UsageLimitExceeded( - f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})' + f'Exceeded the output_tokens_limit of {self.output_tokens_limit} ({output_tokens=})' ) - total_tokens = usage.total_tokens or 0 + total_tokens = usage.total_tokens if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})') diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 65ef56d789..8a400ecd64 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -46,6 +46,12 @@ classifiers = [ ] requires-python = 
">=3.9" +[project.urls] +Homepage = "https://github.com/pydantic/pydantic-ai/tree/main/pydantic_ai_slim" +Source = "https://github.com/pydantic/pydantic-ai/tree/main/pydantic_ai_slim" +Documentation = "https://ai.pydantic.dev/install/#slim-install" +Changelog = "https://github.com/pydantic/pydantic-ai/releases" + [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ "eval-type-backport>=0.2.0", @@ -56,6 +62,7 @@ dependencies = [ "exceptiongroup; python_version < '3.11'", "opentelemetry-api>=1.28.0", "typing-inspection>=0.4.0", + "genai-prices>=0.0.22", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] @@ -93,7 +100,7 @@ temporal = ["temporalio>=1.15.0"] allow-direct-references = true [project.scripts] -pai = "pydantic_ai._cli:cli_exit" # TODO remove this when clai has been out for a while +pai = "pydantic_ai._cli:cli_exit" # TODO remove this when clai has been out for a while [tool.hatch.build.targets.wheel] packages = ["pydantic_ai"] diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 2f24b95f1b..f1b2f43be3 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -39,8 +39,9 @@ UserPromptPart, ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings +from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, TestEnv, raise_if_exception, try_import from ..parts_from_messages import part_types_from_messages @@ -173,11 +174,10 @@ async def test_sync_request_text_response(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=5, - response_tokens=10, - total_tokens=15, + input_tokens=5, + output_tokens=10, details={'input_tokens': 5, 'output_tokens': 10}, ) ) @@ -187,11 +187,10 @@ async def test_sync_request_text_response(allow_model_requests: None): result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=5, - response_tokens=10, - total_tokens=15, + input_tokens=5, + output_tokens=10, details={'input_tokens': 5, 'output_tokens': 10}, ) ) @@ -200,30 +199,18 @@ async def test_sync_request_text_response(allow_model_requests: None): ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage( - requests=1, - request_tokens=5, - response_tokens=10, - total_tokens=15, - details={'input_tokens': 5, 'output_tokens': 10}, - ), + usage=RequestUsage(input_tokens=5, output_tokens=10, details={'input_tokens': 5, 'output_tokens': 10}), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage( - requests=1, - request_tokens=5, - response_tokens=10, - total_tokens=15, - details={'input_tokens': 5, 'output_tokens': 10}, - ), + usage=RequestUsage(input_tokens=5, output_tokens=10, details={'input_tokens': 5, 'output_tokens': 10}), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -246,11 
+233,12 @@ async def test_async_request_prompt_caching(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=13, - response_tokens=5, - total_tokens=18, + input_tokens=13, + cache_write_tokens=4, + cache_read_tokens=6, + output_tokens=5, details={ 'input_tokens': 3, 'output_tokens': 5, @@ -259,6 +247,8 @@ async def test_async_request_prompt_caching(allow_model_requests: None): }, ) ) + last_message = result.all_messages()[-1] + assert isinstance(last_message, ModelResponse) async def test_async_request_text_response(allow_model_requests: None): @@ -273,11 +263,10 @@ async def test_async_request_text_response(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=3, - response_tokens=5, - total_tokens=8, + input_tokens=3, + output_tokens=5, details={'input_tokens': 3, 'output_tokens': 5}, ) ) @@ -305,16 +294,10 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - usage=Usage( - requests=1, - request_tokens=3, - response_tokens=5, - total_tokens=8, - details={'input_tokens': 3, 'output_tokens': 5}, - ), + usage=RequestUsage(input_tokens=3, output_tokens=5, details={'input_tokens': 3, 'output_tokens': 5}), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -375,16 +358,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage( - requests=1, - request_tokens=2, - response_tokens=1, - total_tokens=3, - details={'input_tokens': 2, 'output_tokens': 1}, - ), + usage=RequestUsage(input_tokens=2, output_tokens=1, details={'input_tokens': 2, 'output_tokens': 1}), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -404,16 +381,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage( - requests=1, - request_tokens=3, - response_tokens=2, - total_tokens=5, - details={'input_tokens': 3, 'output_tokens': 2}, - ), + usage=RequestUsage(input_tokens=3, output_tokens=2, details={'input_tokens': 3, 'output_tokens': 2}), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -427,16 +398,10 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage( - requests=1, - request_tokens=3, - response_tokens=5, - total_tokens=8, - details={'input_tokens': 3, 'output_tokens': 5}, - ), + usage=RequestUsage(input_tokens=3, output_tokens=5, details={'input_tokens': 3, 'output_tokens': 5}), model_name='claude-3-5-haiku-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -690,11 +655,10 @@ async def my_tool(first: str, second: str) -> int: ) assert result.is_complete assert result.usage() == snapshot( - Usage( + RunUsage( requests=2, - request_tokens=20, - response_tokens=5, - total_tokens=25, + input_tokens=20, + output_tokens=5, details={'input_tokens': 20, 'output_tokens': 5}, ) ) @@ -771,11 +735,9 @@ async def get_image() -> BinaryContent: TextPart(content='Let me get the image and check what fruit is shown.'), ToolCallPart(tool_name='get_image', args={}, 
tool_call_id='toolu_01WALUz3dC75yywrmL6dF3Bc'), ], - usage=Usage( - requests=1, - request_tokens=372, - response_tokens=49, - total_tokens=421, + usage=RequestUsage( + input_tokens=372, + output_tokens=49, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -785,7 +747,7 @@ async def get_image() -> BinaryContent: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01Kwjzggomz7bv9og51qGFuH', + provider_request_id='msg_01Kwjzggomz7bv9og51qGFuH', ), ModelRequest( parts=[ @@ -810,11 +772,9 @@ async def get_image() -> BinaryContent: content="The image shows a kiwi fruit that has been cut in half, displaying its characteristic bright green flesh with small black seeds arranged in a circular pattern around a white center core. The kiwi's flesh has the typical fuzzy brown skin visible around the edges. The image is a clean, well-lit close-up shot of the kiwi slice against a white background." ) ], - usage=Usage( - requests=1, - request_tokens=2025, - response_tokens=81, - total_tokens=2106, + usage=RequestUsage( + input_tokens=2025, + output_tokens=81, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -824,7 +784,7 @@ async def get_image() -> BinaryContent: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_015btMBYLTuDnMP7zAeuHQGi', + provider_request_id='msg_015btMBYLTuDnMP7zAeuHQGi', ), ] ) @@ -933,11 +893,9 @@ def simple_instructions(): ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, - request_tokens=20, - response_tokens=10, - total_tokens=30, + usage=RequestUsage( + input_tokens=20, + output_tokens=10, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -947,7 +905,7 @@ def simple_instructions(): ), model_name='claude-3-opus-20240229', timestamp=IsDatetime(), - vendor_id='msg_01Fg1JVgvCYUHWsxrj9GkpEv', + provider_request_id='msg_01Fg1JVgvCYUHWsxrj9GkpEv', ), ] ) @@ -982,11 +940,9 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), TextPart(content=IsStr()), ], - usage=Usage( - requests=1, - request_tokens=42, - response_tokens=363, - total_tokens=405, + usage=RequestUsage( + input_tokens=42, + output_tokens=363, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -996,7 +952,7 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), model_name='claude-3-7-sonnet-20250219', timestamp=IsDatetime(), - vendor_id='msg_01BnZvs3naGorn93wjjCDwbd', + provider_request_id='msg_01BnZvs3naGorn93wjjCDwbd', ), ] ) @@ -1010,11 +966,9 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), ModelResponse( parts=[IsInstance(ThinkingPart), IsInstance(TextPart)], - usage=Usage( - requests=1, - request_tokens=42, - response_tokens=363, - total_tokens=405, + usage=RequestUsage( + input_tokens=42, + output_tokens=363, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1024,7 +978,7 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), model_name='claude-3-7-sonnet-20250219', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1054,11 +1008,9 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), TextPart(content=IsStr()), ], - usage=Usage( - requests=1, - 
request_tokens=291, - response_tokens=471, - total_tokens=762, + usage=RequestUsage( + input_tokens=291, + output_tokens=471, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1068,7 +1020,7 @@ async def test_anthropic_model_thinking_part(allow_model_requests: None, anthrop ), model_name='claude-3-7-sonnet-20250219', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id=IsStr(), ), ] ) @@ -1286,11 +1238,7 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: [ pytest.param( lambda: anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), - snapshot( - Usage( - request_tokens=1, response_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1} - ) - ), + snapshot(RequestUsage(input_tokens=1, output_tokens=1, details={'input_tokens': 1, 'output_tokens': 1})), id='AnthropicMessage', ), pytest.param( @@ -1298,10 +1246,11 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: BetaUsage(input_tokens=1, output_tokens=1, cache_creation_input_tokens=2, cache_read_input_tokens=3) ), snapshot( - Usage( - request_tokens=6, - response_tokens=1, - total_tokens=7, + RequestUsage( + input_tokens=6, + cache_write_tokens=2, + cache_read_tokens=3, + output_tokens=1, details={ 'cache_creation_input_tokens': 2, 'cache_read_input_tokens': 3, @@ -1316,11 +1265,7 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: lambda: BetaRawMessageStartEvent( message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start' ), - snapshot( - Usage( - request_tokens=1, response_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1} - ) - ), + snapshot(RequestUsage(input_tokens=1, output_tokens=1, details={'input_tokens': 1, 'output_tokens': 1})), id='RawMessageStartEvent', ), pytest.param( @@ -1329,13 +1274,15 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta', ), - snapshot(Usage(response_tokens=5, total_tokens=5, details={'output_tokens': 5})), + snapshot(RequestUsage(output_tokens=5, details={'output_tokens': 5})), id='RawMessageDeltaEvent', ), - pytest.param(lambda: BetaRawMessageStopEvent(type='message_stop'), snapshot(Usage()), id='RawMessageStopEvent'), + pytest.param( + lambda: BetaRawMessageStopEvent(type='message_stop'), snapshot(RequestUsage()), id='RawMessageStopEvent' + ), ], ) -def test_usage(message_callback: Callable[[], BetaMessage | BetaRawMessageStreamEvent], usage: Usage): +def test_usage(message_callback: Callable[[], BetaMessage | BetaRawMessageStreamEvent], usage: RunUsage): assert _map_usage(message_callback()) == usage @@ -1490,11 +1437,9 @@ async def test_anthropic_web_search_tool(allow_model_requests: None, anthropic_a content="Mount Lewotobi Laki Laki in Indonesia is experiencing its second consecutive day of eruption, sending volcanic materials and ash up to 18 km into the sky. This is one of Indonesia's largest eruptions since 2010, though fortunately no casualties have been reported." 
), ], - usage=Usage( - requests=1, - request_tokens=14923, - response_tokens=317, - total_tokens=15240, + usage=RequestUsage( + input_tokens=14923, + output_tokens=317, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1504,7 +1449,7 @@ async def test_anthropic_web_search_tool(allow_model_requests: None, anthropic_a ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id='msg_01W2YfD2EF8BbAqLRr8ftH4W', ), ] ) @@ -1547,11 +1492,9 @@ async def test_anthropic_code_execution_tool(allow_model_requests: None, anthrop ), TextPart(content='3 * 12390 = 37,170'), ], - usage=Usage( - requests=1, - request_tokens=1630, - response_tokens=109, - total_tokens=1739, + usage=RequestUsage( + input_tokens=1630, + output_tokens=109, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1561,7 +1504,7 @@ async def test_anthropic_code_execution_tool(allow_model_requests: None, anthrop ), model_name='claude-sonnet-4-20250514', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id='msg_01RJnbK7VMxvS2SyvtyJAQVU', ), ] ) @@ -1607,15 +1550,10 @@ async def test_anthropic_server_tool_pass_history_to_another_provider( ModelRequest(parts=[UserPromptPart(content='What day is tomorrow?', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='Tomorrow will be **Friday, August 15, 2025**.')], - usage=Usage( - request_tokens=458, - response_tokens=17, - total_tokens=475, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=458, output_tokens=17, details={'reasoning_tokens': 0}), model_name='gpt-4.1-2025-04-14', timestamp=IsDatetime(), - vendor_id='resp_689dc4abe31c81968ed493d15d8810fe0afe80ec3d42722e', + provider_request_id='resp_689dc4abe31c81968ed493d15d8810fe0afe80ec3d42722e', ), ] ) @@ -1718,11 +1656,9 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01X9wcHKKAZD9tBC711xipPa') ], - usage=Usage( - requests=1, - request_tokens=445, - response_tokens=23, - total_tokens=468, + usage=RequestUsage( + input_tokens=445, + output_tokens=23, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1732,7 +1668,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_012TXW181edhmR5JCsQRsBKx', + provider_request_id='msg_012TXW181edhmR5JCsQRsBKx', ), ModelRequest( parts=[ @@ -1752,11 +1688,9 @@ async def get_user_country() -> str: tool_call_id='toolu_01LZABsgreMefH2Go8D5PQbW', ) ], - usage=Usage( - requests=1, - request_tokens=497, - response_tokens=56, - total_tokens=553, + usage=RequestUsage( + input_tokens=497, + output_tokens=56, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1766,7 +1700,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01K4Fzcf1bhiyLzHpwLdrefj', + provider_request_id='msg_01K4Fzcf1bhiyLzHpwLdrefj', ), ModelRequest( parts=[ @@ -1818,11 +1752,9 @@ async def get_user_country() -> str: ), ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01JJ8TequDsrEU2pv1QFRWAK'), ], - usage=Usage( - requests=1, - request_tokens=383, - response_tokens=65, - total_tokens=448, + usage=RequestUsage( + input_tokens=383, + output_tokens=65, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1832,7 +1764,7 @@ async def get_user_country() -> str: ), 
model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01MsqUB7ZyhjGkvepS1tCXp3', + provider_request_id='msg_01MsqUB7ZyhjGkvepS1tCXp3', ), ModelRequest( parts=[ @@ -1850,11 +1782,9 @@ async def get_user_country() -> str: content='Based on the result, you are located in Mexico. The largest city in Mexico is Mexico City (Ciudad de México), which is both the capital and the most populous city in the country. With a population of approximately 9.2 million people in the city proper and over 21 million people in its metropolitan area, Mexico City is not only the largest city in Mexico but also one of the largest cities in the world.' ) ], - usage=Usage( - requests=1, - request_tokens=460, - response_tokens=91, - total_tokens=551, + usage=RequestUsage( + input_tokens=460, + output_tokens=91, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1864,7 +1794,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_0142umg4diSckrDtV9vAmmPL', + provider_request_id='msg_0142umg4diSckrDtV9vAmmPL', ), ] ) @@ -1909,11 +1839,9 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01ArHq5f2wxRpRF2PVQcKExM') ], - usage=Usage( - requests=1, - request_tokens=459, - response_tokens=38, - total_tokens=497, + usage=RequestUsage( + input_tokens=459, + output_tokens=38, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1923,7 +1851,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_018YiNXULHGpoKoHkTt6GivG', + provider_request_id='msg_018YiNXULHGpoKoHkTt6GivG', ), ModelRequest( parts=[ @@ -1944,11 +1872,9 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage( - requests=1, - request_tokens=510, - response_tokens=17, - total_tokens=527, + usage=RequestUsage( + input_tokens=510, + output_tokens=17, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -1958,7 +1884,7 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01WiRVmLhCrJbJZRqmAWKv3X', + provider_request_id='msg_01WiRVmLhCrJbJZRqmAWKv3X', ), ] ) @@ -2003,11 +1929,9 @@ class CountryLanguage(BaseModel): content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' ) ], - usage=Usage( - requests=1, - request_tokens=265, - response_tokens=31, - total_tokens=296, + usage=RequestUsage( + input_tokens=265, + output_tokens=31, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, @@ -2017,7 +1941,7 @@ class CountryLanguage(BaseModel): ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01N2PwwVQo2aBtt6UFhMDtEX', + provider_request_id='msg_01N2PwwVQo2aBtt6UFhMDtEX', ), ] ) diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index 5d076eaebb..b3e26776c9 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -34,7 +34,7 @@ ) from pydantic_ai.models import ModelRequestParameters from pydantic_ai.tools import ToolDefinition -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage, RunUsage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -58,7 +58,7 @@ async def test_bedrock_model(allow_model_requests: None, 
bedrock_provider: Bedro assert result.output == snapshot( "Hello! How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help." ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=7, response_tokens=30, total_tokens=37)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=7, output_tokens=30)) assert result.all_messages() == snapshot( [ ModelRequest( @@ -79,7 +79,7 @@ async def test_bedrock_model(allow_model_requests: None, bedrock_provider: Bedro content="Hello! How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help." ) ], - usage=Usage(requests=1, request_tokens=7, response_tokens=30, total_tokens=37), + usage=RequestUsage(input_tokens=7, output_tokens=30), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -111,7 +111,7 @@ async def temperature(city: str, date: datetime.date) -> str: result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response) assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'}) - assert result.usage() == snapshot(Usage(requests=2, request_tokens=1236, response_tokens=298, total_tokens=1534)) + assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=1236, output_tokens=298)) assert result.all_messages() == snapshot( [ ModelRequest( @@ -137,7 +137,7 @@ async def temperature(city: str, date: datetime.date) -> str: tool_call_id='tooluse_5WEci1UmQ8ifMFkUcy2gHQ', ), ], - usage=Usage(requests=1, request_tokens=551, response_tokens=132, total_tokens=683), + usage=RequestUsage(input_tokens=551, output_tokens=132), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -162,7 +162,7 @@ async def temperature(city: str, date: datetime.date) -> str: tool_call_id='tooluse_9AjloJSaQDKmpPFff-2Clg', ), ], - usage=Usage(requests=1, request_tokens=685, response_tokens=166, total_tokens=851), + usage=RequestUsage(input_tokens=685, output_tokens=166), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -262,7 +262,7 @@ async def get_capital(country: str) -> str: tool_call_id='tooluse_F8LnaCMtQ0-chKTnPhNH2g', ), ], - usage=Usage(requests=1, request_tokens=417, response_tokens=69, total_tokens=486), + usage=RequestUsage(input_tokens=417, output_tokens=69), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -286,7 +286,7 @@ async def get_capital(country: str) -> str: """ ) ], - usage=Usage(requests=1, request_tokens=509, response_tokens=108, total_tokens=617), + usage=RequestUsage(input_tokens=509, output_tokens=108), model_name='us.amazon.nova-micro-v1:0', timestamp=IsDatetime(), ), @@ -553,7 +553,7 @@ def instructions() -> str: content='The capital of France is Paris. Paris is not only the political and economic hub of the country but also a major center for culture, fashion, art, and tourism. It is renowned for its rich history, iconic landmarks such as the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum, as well as its influence on global culture and cuisine.' 
) ], - usage=Usage(requests=1, request_tokens=13, response_tokens=71, total_tokens=84), + usage=RequestUsage(input_tokens=13, output_tokens=71), model_name='us.amazon.nova-pro-v1:0', timestamp=IsDatetime(), ), @@ -603,7 +603,7 @@ async def test_bedrock_model_thinking_part(allow_model_requests: None, bedrock_p ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content=IsStr()), ThinkingPart(content=IsStr())], - usage=Usage(requests=1, request_tokens=12, response_tokens=882, total_tokens=894), + usage=RequestUsage(input_tokens=12, output_tokens=882), model_name='us.deepseek.r1-v1:0', timestamp=IsDatetime(), ), @@ -626,7 +626,7 @@ async def test_bedrock_model_thinking_part(allow_model_requests: None, bedrock_p ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), ModelResponse( parts=[IsInstance(TextPart), IsInstance(ThinkingPart)], - usage=Usage(requests=1, request_tokens=12, response_tokens=882, total_tokens=894), + usage=RequestUsage(input_tokens=12, output_tokens=882), model_name='us.deepseek.r1-v1:0', timestamp=IsDatetime(), ), @@ -646,7 +646,7 @@ async def test_bedrock_model_thinking_part(allow_model_requests: None, bedrock_p ), IsInstance(TextPart), ], - usage=Usage(requests=1, request_tokens=636, response_tokens=690, total_tokens=1326), + usage=RequestUsage(input_tokens=636, output_tokens=690), model_name='us.anthropic.claude-3-7-sonnet-20250219-v1:0', timestamp=IsDatetime(), ), diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index cce621325f..739397350e 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -25,7 +25,7 @@ UserPromptPart, ) from pydantic_ai.tools import RunContext -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage, RunUsage from ..conftest import IsDatetime, IsInstance, IsNow, raise_if_exception, try_import @@ -104,27 +104,25 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), @@ -150,11 +148,10 @@ async def test_request_simple_usage(allow_model_requests: None): result = await agent.run('Hello') assert result.output == 'world' assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=1, - response_tokens=1, - total_tokens=2, + input_tokens=1, + output_tokens=1, details={ 'input_tokens': 1, 'output_tokens': 1, @@ -194,7 +191,6 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - 
usage=Usage(requests=1), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), @@ -281,7 +277,6 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, total_tokens=0, details={}), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), @@ -303,13 +298,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage( - requests=1, - request_tokens=5, - response_tokens=3, - total_tokens=8, - details={'input_tokens': 4, 'output_tokens': 2}, - ), + usage=RequestUsage(input_tokens=5, output_tokens=3, details={'input_tokens': 4, 'output_tokens': 2}), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), @@ -325,18 +314,16 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1), model_name='command-r7b-12-2024', timestamp=IsNow(tz=timezone.utc), ), ] ) assert result.usage() == snapshot( - Usage( + RunUsage( requests=3, - request_tokens=5, - response_tokens=3, - total_tokens=8, + input_tokens=5, + output_tokens=3, details={'input_tokens': 4, 'output_tokens': 2}, ) ) @@ -403,12 +390,8 @@ def simple_instructions(ctx: RunContext): content="The capital of France is Paris. It is the country's largest city and serves as the economic, cultural, and political center of France. Paris is known for its rich history, iconic landmarks such as the Eiffel Tower and the Louvre Museum, and its significant influence on fashion, cuisine, and the arts." ) ], - usage=Usage( - requests=1, - request_tokens=542, - response_tokens=63, - total_tokens=605, - details={'input_tokens': 13, 'output_tokens': 61}, + usage=RequestUsage( + input_tokens=542, output_tokens=63, details={'input_tokens': 13, 'output_tokens': 61} ), model_name='command-r7b-12-2024', timestamp=IsDatetime(), @@ -449,15 +432,10 @@ async def test_cohere_model_thinking_part(allow_model_requests: None, co_api_key IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=1909, - total_tokens=1922, - details={'reasoning_tokens': 1472, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=1909, details={'reasoning_tokens': 1472}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', + provider_request_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', ), ] ) @@ -478,15 +456,10 @@ async def test_cohere_model_thinking_part(allow_model_requests: None, co_api_key IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=1909, - total_tokens=1922, - details={'reasoning_tokens': 1472, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=1909, details={'reasoning_tokens': 1472}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', + provider_request_id='resp_680739f4ad748191bd11096967c37c8b048efc3f8b2a068e', ), ModelRequest( parts=[ @@ -498,12 +471,8 @@ async def test_cohere_model_thinking_part(allow_model_requests: None, co_api_key ), ModelResponse( parts=[IsInstance(TextPart)], - usage=Usage( - requests=1, - request_tokens=1457, - response_tokens=807, - total_tokens=2264, - details={'input_tokens': 954, 'output_tokens': 805}, + usage=RequestUsage( + input_tokens=1457, output_tokens=807, details={'input_tokens': 954, 'output_tokens': 805} ), model_name='command-r7b-12-2024', 
timestamp=IsDatetime(), diff --git a/tests/models/test_deepseek.py b/tests/models/test_deepseek.py index 64da89d097..fa2779c434 100644 --- a/tests/models/test_deepseek.py +++ b/tests/models/test_deepseek.py @@ -19,7 +19,7 @@ ThinkingPartDelta, UserPromptPart, ) -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsStr, try_import @@ -44,21 +44,18 @@ async def test_deepseek_model_thinking_part(allow_model_requests: None, deepseek ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), ModelResponse( parts=[ThinkingPart(content=IsStr()), TextPart(content=IsStr())], - usage=Usage( - requests=1, - request_tokens=12, - response_tokens=789, - total_tokens=801, + usage=RequestUsage( + input_tokens=12, + output_tokens=789, details={ 'prompt_cache_hit_tokens': 0, 'prompt_cache_miss_tokens': 12, 'reasoning_tokens': 415, - 'cached_tokens': 0, }, ), model_name='deepseek-reasoner', timestamp=IsDatetime(), - vendor_id='181d9669-2b3a-445e-bd13-2ebff2c378f6', + provider_request_id='181d9669-2b3a-445e-bd13-2ebff2c378f6', ), ] ) diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 4e66ce6726..6074d174c0 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -16,7 +16,7 @@ from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.settings import ModelSettings -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from ..conftest import IsNow, try_import @@ -65,7 +65,7 @@ def test_first_successful() -> None: ), ModelResponse( parts=[TextPart(content='success')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), ), @@ -90,7 +90,7 @@ def test_first_failed() -> None: ), ModelResponse( parts=[TextPart(content='success')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), ), @@ -116,7 +116,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: ), ModelResponse( parts=[TextPart(content='success')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), ), @@ -175,19 +175,19 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None [ ModelResponse( parts=[TextPart(content='hello ')], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), @@ -360,19 +360,19 @@ 
async def test_first_success_streaming() -> None: [ ModelResponse( parts=[TextPart(content='hello ')], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), @@ -389,19 +389,19 @@ async def test_first_failed_streaming() -> None: [ ModelResponse( parts=[TextPart(content='hello ')], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), ), diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 6f0a414a5c..73de9b971f 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -46,6 +46,7 @@ _GeminiFunctionCall, _GeminiFunctionCallingConfig, _GeminiFunctionCallPart, + _GeminiModalityTokenCount, _GeminiResponse, _GeminiSafetyRating, _GeminiTextPart, @@ -53,11 +54,13 @@ _GeminiToolConfig, _GeminiTools, _GeminiUsageMetaData, + _metadata_as_usage, ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.tools import ToolDefinition +from pydantic_ai.usage import RequestUsage from ..conftest import ClientWithHandler, IsDatetime, IsInstance, IsNow, IsStr, TestEnv, try_import @@ -619,14 +622,14 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='Hello world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) result = await agent.run('Hello', message_history=result.new_messages()) assert result.output == 'Hello world' @@ -635,18 +638,18 @@ async def test_text_success(get_gemini_client: GetGeminiClient): 
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='Hello world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='Hello world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -667,10 +670,10 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2, 123]}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -730,10 +733,10 @@ async def get_location(loc_name: str) -> str: parts=[ ToolCallPart(tool_name='get_location', args={'loc_name': 'San Fransisco'}, tool_call_id=IsStr()) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -750,10 +753,10 @@ async def get_location(loc_name: str) -> str: ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()), ToolCallPart(tool_name='get_location', args={'loc_name': 'New York'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -773,14 +776,14 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash-123', timestamp=IsNow(tz=timezone.utc), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) - assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6)) async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): @@ -821,12 +824,12 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): 'Hello world', ] ) - assert result.usage() == 
snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)] assert chunks == snapshot(['Hello ', 'world']) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient): @@ -858,7 +861,7 @@ async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot(['abc', 'abc€def', 'abc€def']) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): @@ -888,7 +891,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot([(1, 2), (1, 2)]) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): @@ -929,7 +932,7 @@ async def bar(y: str) -> str: async with agent.run_stream('Hello') as result: response = await result.get_output() assert response == snapshot((1, 2)) - assert result.usage() == snapshot(Usage(requests=2, request_tokens=2, response_tokens=4, total_tokens=6)) + assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=2, output_tokens=4)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), @@ -938,7 +941,7 @@ async def bar(y: str) -> str: ToolCallPart(tool_name='foo', args={'x': 'a'}, tool_call_id=IsStr()), ToolCallPart(tool_name='bar', args={'y': 'b'}, tool_call_id=IsStr()), ], - usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash', timestamp=IsNow(tz=timezone.utc), ), @@ -954,7 +957,7 @@ async def bar(y: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2]}, tool_call_id=IsStr())], - usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash', timestamp=IsNow(tz=timezone.utc), ), @@ -1022,7 +1025,7 @@ def get_location(loc_name: str) -> str: tool_call_id=IsStr(), ), ], - usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='gemini-1.5-flash', timestamp=IsDatetime(), ), @@ -1223,16 +1226,12 @@ async def get_image() -> BinaryContent: ), ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr()), ], - usage=Usage( - requests=1, - request_tokens=38, - response_tokens=28, - total_tokens=427, - details={'thoughts_tokens': 361, 'text_prompt_tokens': 38}, + 
usage=RequestUsage( + input_tokens=38, output_tokens=28, details={'thoughts_tokens': 361, 'text_prompt_tokens': 38} ), model_name='gemini-2.5-pro-preview-03-25', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1253,16 +1252,14 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The image shows a kiwi fruit, sliced in half.')], - usage=Usage( - requests=1, - request_tokens=360, - response_tokens=11, - total_tokens=572, + usage=RequestUsage( + input_tokens=360, + output_tokens=11, details={'thoughts_tokens': 201, 'text_prompt_tokens': 102, 'image_prompt_tokens': 258}, ), model_name='gemini-2.5-pro-preview-03-25', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1382,16 +1379,12 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_ ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.\n')], - usage=Usage( - requests=1, - request_tokens=13, - response_tokens=8, - total_tokens=21, - details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8}, + usage=RequestUsage( + input_tokens=13, output_tokens=8, details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8} ), model_name='gemini-1.5-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1491,15 +1484,10 @@ async def test_gemini_model_thinking_part(allow_model_requests: None, gemini_api """ ), ], - usage=Usage( - request_tokens=13, - response_tokens=2028, - total_tokens=2041, - details={'reasoning_tokens': 1664, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=2028, details={'reasoning_tokens': 1664}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', + provider_request_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', ), ] ) @@ -1522,15 +1510,10 @@ async def test_gemini_model_thinking_part(allow_model_requests: None, gemini_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=2028, - total_tokens=2041, - details={'reasoning_tokens': 1664, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=2028, details={'reasoning_tokens': 1664}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', + provider_request_id='resp_680393ff82488191a7d0850bf0dd99a004f0817ea037a07b', ), ModelRequest( parts=[ @@ -1581,16 +1564,12 @@ async def test_gemini_model_thinking_part(allow_model_requests: None, gemini_api """ ), ], - usage=Usage( - requests=1, - request_tokens=801, - response_tokens=1519, - total_tokens=2320, - details={'thoughts_tokens': 794, 'text_prompt_tokens': 801}, + usage=RequestUsage( + input_tokens=801, output_tokens=1519, details={'thoughts_tokens': 794, 'text_prompt_tokens': 801} ), model_name='gemini-2.5-flash-preview-04-17', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1620,11 +1599,9 @@ async def test_gemini_youtube_video_url_input(allow_model_requests: None, gemini content='The main content of the URL is an analysis of recent 404 HTTP responses. 
The analysis identifies several patterns, including the most common endpoints with 404 errors, request patterns (such as all requests being GET requests), timeline-related issues, and configuration/authentication problems. The analysis also provides recommendations for addressing the 404 errors.' ) ], - usage=Usage( - requests=1, - request_tokens=9, - response_tokens=72, - total_tokens=81, + usage=RequestUsage( + input_tokens=9, + output_tokens=72, details={ 'text_prompt_tokens': 9, 'video_prompt_tokens': 0, @@ -1634,7 +1611,7 @@ async def test_gemini_youtube_video_url_input(allow_model_requests: None, gemini ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1652,7 +1629,7 @@ async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient): for message in result.all_messages(): if isinstance(message, ModelResponse): - assert message.vendor_details is None + assert message.provider_details is None async def test_response_with_thought_part(get_gemini_client: GetGeminiClient): @@ -1672,7 +1649,7 @@ async def test_response_with_thought_part(get_gemini_client: GetGeminiClient): result = await agent.run('Test with thought') assert result.output == 'Hello from thought test' - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) @pytest.mark.vcr() @@ -1700,17 +1677,13 @@ async def bar() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='bar', args={}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=21, - response_tokens=1, - total_tokens=22, - details={'text_candidates_tokens': 1, 'text_prompt_tokens': 21}, + usage=RequestUsage( + input_tokens=21, output_tokens=1, details={'text_prompt_tokens': 21, 'text_candidates_tokens': 1} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1730,17 +1703,13 @@ async def bar() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( - requests=1, - request_tokens=27, - response_tokens=5, - total_tokens=32, - details={'text_candidates_tokens': 5, 'text_prompt_tokens': 27}, + usage=RequestUsage( + input_tokens=27, output_tokens=5, details={'text_prompt_tokens': 27, 'text_candidates_tokens': 5} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1785,17 +1754,13 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=32, - response_tokens=5, - total_tokens=37, - details={'text_prompt_tokens': 32, 'text_candidates_tokens': 5}, + usage=RequestUsage( + input_tokens=32, output_tokens=5, details={'text_prompt_tokens': 32, 'text_candidates_tokens': 5} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1815,17 +1780,13 @@ async def get_user_country() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( - requests=1, - request_tokens=46, - 
response_tokens=8, - total_tokens=54, - details={'text_prompt_tokens': 46, 'text_candidates_tokens': 8}, + usage=RequestUsage( + input_tokens=46, output_tokens=8, details={'text_prompt_tokens': 46, 'text_candidates_tokens': 8} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -1877,17 +1838,13 @@ def upcase(text: str) -> str: """ ) ], - usage=Usage( - requests=1, - request_tokens=9, - response_tokens=44, - total_tokens=598, - details={'thoughts_tokens': 545, 'text_prompt_tokens': 9}, + usage=RequestUsage( + input_tokens=9, output_tokens=44, details={'thoughts_tokens': 545, 'text_prompt_tokens': 9} ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id='TT9IaNfGN_DmqtsPzKnE4AE', + provider_details={'finish_reason': 'STOP'}, + provider_request_id='TT9IaNfGN_DmqtsPzKnE4AE', ), ] ) @@ -1947,17 +1904,13 @@ class CityLocation(BaseModel): """ ) ], - usage=Usage( - requests=1, - request_tokens=17, - response_tokens=20, - total_tokens=37, - details={'text_prompt_tokens': 17, 'text_candidates_tokens': 20}, + usage=RequestUsage( + input_tokens=17, output_tokens=20, details={'text_prompt_tokens': 17, 'text_candidates_tokens': 20} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -2006,17 +1959,13 @@ class CountryLanguage(BaseModel): """ ) ], - usage=Usage( - requests=1, - request_tokens=46, - response_tokens=46, - total_tokens=92, - details={'text_prompt_tokens': 46, 'text_candidates_tokens': 46}, + usage=RequestUsage( + input_tokens=46, output_tokens=46, details={'text_prompt_tokens': 46, 'text_candidates_tokens': 46} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -2058,17 +2007,13 @@ class CityLocation(BaseModel): content='{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' ) ], - usage=Usage( - requests=1, - request_tokens=80, - response_tokens=56, - total_tokens=136, - details={'text_prompt_tokens': 80, 'text_candidates_tokens': 56}, + usage=RequestUsage( + input_tokens=80, output_tokens=56, details={'text_prompt_tokens': 80, 'text_candidates_tokens': 56} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -2112,17 +2057,13 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=123, - response_tokens=12, - total_tokens=453, - details={'thoughts_tokens': 318, 'text_prompt_tokens': 123}, + usage=RequestUsage( + input_tokens=123, output_tokens=12, details={'thoughts_tokens': 318, 'text_prompt_tokens': 123} ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + 
provider_request_id=IsStr(), ), ModelRequest( parts=[ @@ -2143,17 +2084,13 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage( - requests=1, - request_tokens=154, - response_tokens=13, - total_tokens=261, - details={'thoughts_tokens': 94, 'text_prompt_tokens': 154}, + usage=RequestUsage( + input_tokens=154, output_tokens=13, details={'thoughts_tokens': 94, 'text_prompt_tokens': 154} ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -2199,17 +2136,59 @@ class CountryLanguage(BaseModel): content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' ) ], - usage=Usage( - requests=1, - request_tokens=253, - response_tokens=27, - total_tokens=280, + usage=RequestUsage( + input_tokens=253, + output_tokens=27, details={'text_prompt_tokens': 253, 'text_candidates_tokens': 27}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) + + +def test_map_usage(): + response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')]))) + assert 'usage_metadata' in response + response['usage_metadata']['cached_content_token_count'] = 9100 + response['usage_metadata']['prompt_tokens_details'] = [ + _GeminiModalityTokenCount(modality='AUDIO', token_count=9200) + ] + response['usage_metadata']['cache_tokens_details'] = [ + _GeminiModalityTokenCount(modality='AUDIO', token_count=9300), + ] + response['usage_metadata']['candidates_tokens_details'] = [ + _GeminiModalityTokenCount(modality='AUDIO', token_count=9400) + ] + response['usage_metadata']['thoughts_token_count'] = 9500 + response['usage_metadata']['tool_use_prompt_token_count'] = 9600 + + assert _metadata_as_usage(response) == snapshot( + RequestUsage( + input_tokens=1, + cache_read_tokens=9100, + output_tokens=2, + input_audio_tokens=9200, + cache_audio_read_tokens=9300, + output_audio_tokens=9400, + details={ + 'cached_content_tokens': 9100, + 'audio_prompt_tokens': 9200, + 'audio_cache_tokens': 9300, + 'audio_candidates_tokens': 9400, + 'thoughts_tokens': 9500, + 'tool_use_prompt_tokens': 9600, + }, + ) + ) + + +def test_map_empty_usage(): + response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')]))) + assert 'usage_metadata' in response + del response['usage_metadata'] + + assert _metadata_as_usage(response) == RequestUsage() diff --git a/tests/models/test_gemini_vertex.py b/tests/models/test_gemini_vertex.py index 50b0bc09f1..5760faef7e 100644 --- a/tests/models/test_gemini_vertex.py +++ b/tests/models/test_gemini_vertex.py @@ -20,7 +20,7 @@ VideoUrl, ) from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -145,11 +145,11 @@ async def test_url_input( ), ModelResponse( parts=[TextPart(content=Is(expected_output))], - usage=IsInstance(Usage), + usage=IsInstance(RequestUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + 
provider_request_id=IsStr(), ), ] ) @@ -182,11 +182,11 @@ async def test_url_input_force_download(allow_model_requests: None) -> None: # ), ModelResponse( parts=[TextPart(content=Is(output))], - usage=IsInstance(Usage), + usage=IsInstance(RequestUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 066566ee71..d94703e699 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -40,16 +40,26 @@ VideoUrl, ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput -from pydantic_ai.result import Usage, UsageLimits from pydantic_ai.settings import ModelSettings +from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits from ..conftest import IsDatetime, IsInstance, IsStr, try_import from ..parts_from_messages import part_types_from_messages with try_import() as imports_successful: - from google.genai.types import CodeExecutionResult, HarmBlockThreshold, HarmCategory, Language, Outcome + from google.genai.types import ( + CodeExecutionResult, + GenerateContentResponse, + GenerateContentResponseUsageMetadata, + HarmBlockThreshold, + HarmCategory, + Language, + MediaModality, + ModalityTokenCount, + Outcome, + ) - from pydantic_ai.models.google import GoogleModel, GoogleModelSettings + from pydantic_ai.models.google import GoogleModel, GoogleModelSettings, _metadata_as_usage # type: ignore from pydantic_ai.providers.google import GoogleProvider pytestmark = [ @@ -73,11 +83,10 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP result = await agent.run('Hello!') assert result.output == snapshot('Hello there! How can I help you today?\n') assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=7, - response_tokens=11, - total_tokens=18, + input_tokens=7, + output_tokens=11, details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11}, ) ) @@ -97,16 +106,12 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP ), ModelResponse( parts=[TextPart(content='Hello there! 
How can I help you today?\n')], - usage=Usage( - requests=1, - request_tokens=7, - response_tokens=11, - total_tokens=18, - details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11}, + usage=RequestUsage( + input_tokens=7, output_tokens=11, details={'text_candidates_tokens': 11, 'text_prompt_tokens': 7} ), model_name='gemini-1.5-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -137,11 +142,10 @@ async def temperature(city: str, date: datetime.date) -> str: result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response) assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'}) assert result.usage() == snapshot( - Usage( + RunUsage( requests=2, - request_tokens=224, - response_tokens=35, - total_tokens=259, + input_tokens=224, + output_tokens=35, details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35}, ) ) @@ -165,16 +169,14 @@ async def temperature(city: str, date: datetime.date) -> str: tool_name='temperature', args={'date': '2022-01-01', 'city': 'London'}, tool_call_id=IsStr() ) ], - usage=Usage( - requests=1, - request_tokens=101, - response_tokens=14, - total_tokens=115, - details={'text_prompt_tokens': 101, 'text_candidates_tokens': 14}, + usage=RequestUsage( + input_tokens=101, + output_tokens=14, + details={'text_candidates_tokens': 14, 'text_prompt_tokens': 101}, ), model_name='gemini-1.5-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -191,16 +193,14 @@ async def temperature(city: str, date: datetime.date) -> str: tool_call_id=IsStr(), ) ], - usage=Usage( - requests=1, - request_tokens=123, - response_tokens=21, - total_tokens=144, - details={'text_prompt_tokens': 123, 'text_candidates_tokens': 21}, + usage=RequestUsage( + input_tokens=123, + output_tokens=21, + details={'text_candidates_tokens': 21, 'text_prompt_tokens': 123}, ), model_name='gemini-1.5-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -250,16 +250,12 @@ async def get_capital(country: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_capital', args={'country': 'France'}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=57, - response_tokens=15, - total_tokens=227, - details={'thoughts_tokens': 155, 'text_prompt_tokens': 57}, + usage=RequestUsage( + input_tokens=57, output_tokens=15, details={'thoughts_tokens': 155, 'text_prompt_tokens': 57} ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -277,16 +273,12 @@ async def get_capital(country: str) -> str: content='I am sorry, I cannot fulfill this request. The country "France" is not supported by my system.' 
) ], - usage=Usage( - requests=1, - request_tokens=104, - response_tokens=22, - total_tokens=304, - details={'thoughts_tokens': 178, 'text_prompt_tokens': 104}, + usage=RequestUsage( + input_tokens=104, output_tokens=22, details={'thoughts_tokens': 178, 'text_prompt_tokens': 104} ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -554,16 +546,12 @@ def instructions() -> str: ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.\n')], - usage=Usage( - requests=1, - request_tokens=13, - response_tokens=8, - total_tokens=21, - details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8}, + usage=RequestUsage( + input_tokens=13, output_tokens=8, details={'text_candidates_tokens': 8, 'text_prompt_tokens': 13} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -657,11 +645,9 @@ async def test_google_model_code_execution_tool(allow_model_requests: None, goog ), TextPart(content='Today is Thursday in Utrecht.\n'), ], - usage=Usage( - requests=1, - request_tokens=13, - response_tokens=95, - total_tokens=209, + usage=RequestUsage( + input_tokens=13, + output_tokens=95, details={ 'tool_use_prompt_tokens': 101, 'text_candidates_tokens': 95, @@ -671,7 +657,7 @@ async def test_google_model_code_execution_tool(allow_model_requests: None, goog ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -717,11 +703,9 @@ async def test_google_model_code_execution_tool(allow_model_requests: None, goog ), TextPart(content='Today is Thursday in Utrecht.\n'), ], - usage=Usage( - requests=1, - request_tokens=13, - response_tokens=95, - total_tokens=209, + usage=RequestUsage( + input_tokens=13, + output_tokens=95, details={ 'tool_use_prompt_tokens': 101, 'text_candidates_tokens': 95, @@ -731,7 +715,7 @@ async def test_google_model_code_execution_tool(allow_model_requests: None, goog ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest(parts=[UserPromptPart(content='What day is tomorrow?', timestamp=IsDatetime())]), ModelResponse( @@ -767,11 +751,9 @@ async def test_google_model_code_execution_tool(allow_model_requests: None, goog ), TextPart(content='Tomorrow is Friday.\n'), ], - usage=Usage( - requests=1, - request_tokens=113, - response_tokens=95, - total_tokens=411, + usage=RequestUsage( + input_tokens=113, + output_tokens=95, details={ 'tool_use_prompt_tokens': 203, 'text_candidates_tokens': 95, @@ -781,7 +763,7 @@ async def test_google_model_code_execution_tool(allow_model_requests: None, goog ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -929,16 +911,12 @@ async def test_google_model_thinking_part(allow_model_requests: None, google_pro ), ModelResponse( parts=[IsInstance(ThinkingPart), IsInstance(TextPart)], - usage=Usage( - requests=1, - request_tokens=15, - response_tokens=1041, - total_tokens=2703, - details={'thoughts_tokens': 1647, 'text_prompt_tokens': 15}, + usage=RequestUsage( + input_tokens=15, output_tokens=1041, details={'thoughts_tokens': 1647, 'text_prompt_tokens': 15} ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - 
vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1087,11 +1065,11 @@ async def test_google_url_input( ), ModelResponse( parts=[TextPart(content=Is(expected_output))], - usage=IsInstance(Usage), + usage=IsInstance(RequestUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -1124,11 +1102,11 @@ async def test_google_url_input_force_download(allow_model_requests: None) -> No ), ModelResponse( parts=[TextPart(content=Is(output))], - usage=IsInstance(Usage), + usage=IsInstance(RequestUsage), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id=IsStr(), + provider_details={'finish_reason': 'STOP'}, + provider_request_id=IsStr(), ), ] ) @@ -1170,16 +1148,12 @@ async def bar() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='bar', args={}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=21, - response_tokens=1, - total_tokens=22, - details={'text_candidates_tokens': 1, 'text_prompt_tokens': 21}, + usage=RequestUsage( + input_tokens=21, output_tokens=1, details={'text_candidates_tokens': 1, 'text_prompt_tokens': 21} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1199,16 +1173,12 @@ async def bar() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( - requests=1, - request_tokens=27, - response_tokens=5, - total_tokens=32, - details={'text_candidates_tokens': 5, 'text_prompt_tokens': 27}, + usage=RequestUsage( + input_tokens=27, output_tokens=5, details={'text_candidates_tokens': 5, 'text_prompt_tokens': 27} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1263,16 +1233,12 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=33, - response_tokens=5, - total_tokens=38, - details={'text_candidates_tokens': 5, 'text_prompt_tokens': 33}, + usage=RequestUsage( + input_tokens=33, output_tokens=5, details={'text_candidates_tokens': 5, 'text_prompt_tokens': 33} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1292,16 +1258,12 @@ async def get_user_country() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( - requests=1, - request_tokens=47, - response_tokens=8, - total_tokens=55, - details={'text_candidates_tokens': 8, 'text_prompt_tokens': 47}, + usage=RequestUsage( + input_tokens=47, output_tokens=8, details={'text_candidates_tokens': 8, 'text_prompt_tokens': 47} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1346,16 +1308,12 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=49, - response_tokens=12, - total_tokens=325, - details={'thoughts_tokens': 264, 'text_prompt_tokens': 49}, + usage=RequestUsage( + input_tokens=49, output_tokens=12, 
details={'thoughts_tokens': 264, 'text_prompt_tokens': 49} ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1369,16 +1327,12 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='The largest city in Mexico is Mexico City.')], - usage=Usage( - requests=1, - request_tokens=80, - response_tokens=9, - total_tokens=239, - details={'thoughts_tokens': 150, 'text_prompt_tokens': 80}, + usage=RequestUsage( + input_tokens=80, output_tokens=9, details={'thoughts_tokens': 150, 'text_prompt_tokens': 80} ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1436,16 +1390,12 @@ class CityLocation(BaseModel): """ ) ], - usage=Usage( - requests=1, - request_tokens=25, - response_tokens=20, - total_tokens=45, - details={'text_candidates_tokens': 20, 'text_prompt_tokens': 25}, + usage=RequestUsage( + input_tokens=25, output_tokens=20, details={'text_candidates_tokens': 20, 'text_prompt_tokens': 25} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1493,16 +1443,12 @@ class CountryLanguage(BaseModel): """ ) ], - usage=Usage( - requests=1, - request_tokens=50, - response_tokens=46, - total_tokens=96, - details={'text_candidates_tokens': 46, 'text_prompt_tokens': 50}, + usage=RequestUsage( + input_tokens=50, output_tokens=46, details={'text_candidates_tokens': 46, 'text_prompt_tokens': 50} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1539,16 +1485,12 @@ class CityLocation(BaseModel): ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage( - requests=1, - request_tokens=80, - response_tokens=13, - total_tokens=93, - details={'text_candidates_tokens': 13, 'text_prompt_tokens': 80}, + usage=RequestUsage( + input_tokens=80, output_tokens=13, details={'text_candidates_tokens': 13, 'text_prompt_tokens': 80} ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1591,16 +1533,12 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=123, - response_tokens=12, - total_tokens=267, - details={'thoughts_tokens': 132, 'text_prompt_tokens': 123}, + usage=RequestUsage( + input_tokens=123, output_tokens=12, details={'thoughts_tokens': 132, 'text_prompt_tokens': 123} ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -1621,16 +1559,12 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage( - requests=1, - request_tokens=154, - response_tokens=13, - total_tokens=320, - details={'thoughts_tokens': 153, 'text_prompt_tokens': 154}, + usage=RequestUsage( + input_tokens=154, output_tokens=13, details={'thoughts_tokens': 153, 'text_prompt_tokens': 154} ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 
'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1675,16 +1609,14 @@ class CountryLanguage(BaseModel): content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' ) ], - usage=Usage( - requests=1, - request_tokens=240, - response_tokens=27, - total_tokens=267, + usage=RequestUsage( + input_tokens=240, + output_tokens=27, details={'text_candidates_tokens': 27, 'text_prompt_tokens': 240}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -1695,11 +1627,12 @@ async def test_google_model_usage_limit_exceeded(allow_model_requests: None, goo agent = Agent(model=model) with pytest.raises( - UsageLimitExceeded, match='The next request would exceed the request_tokens_limit of 9 \\(request_tokens=12\\)' + UsageLimitExceeded, + match='The next request would exceed the input_tokens_limit of 9 \\(input_tokens=12\\)', ): await agent.run( 'The quick brown fox jumps over the lazydog.', - usage_limits=UsageLimits(request_tokens_limit=9, count_tokens_before_request=True), + usage_limits=UsageLimits(input_tokens_limit=9, count_tokens_before_request=True), ) @@ -1709,7 +1642,7 @@ async def test_google_model_usage_limit_not_exceeded(allow_model_requests: None, result = await agent.run( 'The quick brown fox jumps over the lazydog.', - usage_limits=UsageLimits(request_tokens_limit=15, count_tokens_before_request=True), + usage_limits=UsageLimits(input_tokens_limit=15, count_tokens_before_request=True), ) assert result.output == snapshot("""\ That's a classic! It's famously known as a **pangram**, which means it's a sentence that contains every letter of the alphabet. @@ -1740,3 +1673,36 @@ async def get_user_country() -> str: 'What is the largest city in the user country? 
Use the get_user_country tool and then your own world knowledge.', usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True), ) + + +def test_map_usage(): + assert _metadata_as_usage(GenerateContentResponse()) == RequestUsage() + + response = GenerateContentResponse( + usage_metadata=GenerateContentResponseUsageMetadata( + prompt_token_count=1, + candidates_token_count=2, + cached_content_token_count=9100, + thoughts_token_count=9500, + prompt_tokens_details=[ModalityTokenCount(modality=MediaModality.AUDIO, token_count=9200)], + cache_tokens_details=[ModalityTokenCount(modality=MediaModality.AUDIO, token_count=9300)], + candidates_tokens_details=[ModalityTokenCount(modality=MediaModality.AUDIO, token_count=9400)], + ) + ) + assert _metadata_as_usage(response) == snapshot( + RequestUsage( + input_tokens=1, + cache_read_tokens=9100, + output_tokens=2, + input_audio_tokens=9200, + cache_audio_read_tokens=9300, + output_audio_tokens=9400, + details={ + 'cached_content_tokens': 9100, + 'thoughts_tokens': 9500, + 'audio_prompt_tokens': 9200, + 'audio_cache_tokens': 9300, + 'audio_candidates_tokens': 9400, + }, + ) + ) diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 43b2ac2df5..f8b123afeb 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -37,7 +37,7 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage, RunUsage from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import from .mock_async_stream import MockAsyncStream @@ -142,31 +142,29 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -216,10 +214,9 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - usage=Usage(requests=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -304,10 +301,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RequestUsage(input_tokens=2, output_tokens=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), 
ModelRequest( parts=[ @@ -327,10 +324,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6), + usage=RequestUsage(input_tokens=3, output_tokens=2), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -344,10 +341,9 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1), model_name='llama-3.3-70b-versatile-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -452,7 +448,7 @@ async def test_stream_structured(allow_model_requests: None): ) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]), @@ -579,10 +575,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='call_wkpd')], - usage=Usage(requests=1, request_tokens=192, response_tokens=8, total_tokens=200), + usage=RequestUsage(input_tokens=192, output_tokens=8), model_name='meta-llama/llama-4-scout-17b-16e-instruct', timestamp=IsDatetime(), - vendor_id='chatcmpl-3c327c89-e9f5-4aac-a5d5-190e6f6f25c9', + provider_request_id='chatcmpl-3c327c89-e9f5-4aac-a5d5-190e6f6f25c9', ), ModelRequest( parts=[ @@ -603,10 +599,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The fruit in the image is a kiwi.')], - usage=Usage(requests=1, request_tokens=2552, response_tokens=11, total_tokens=2563), + usage=RequestUsage(input_tokens=2552, output_tokens=11), model_name='meta-llama/llama-4-scout-17b-16e-instruct', timestamp=IsDatetime(), - vendor_id='chatcmpl-82dfad42-6a28-4089-82c3-c8633f626c0d', + provider_request_id='chatcmpl-82dfad42-6a28-4089-82c3-c8633f626c0d', ), ] ) @@ -681,10 +677,10 @@ async def test_groq_model_instructions(allow_model_requests: None, groq_api_key: ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage(requests=1, request_tokens=48, response_tokens=8, total_tokens=56), + usage=RequestUsage(input_tokens=48, output_tokens=8), model_name='llama-3.3-70b-versatile', timestamp=IsDatetime(), - vendor_id='chatcmpl-7586b6a9-fb4b-4ec7-86a0-59f0a77844cf', + provider_request_id='chatcmpl-7586b6a9-fb4b-4ec7-86a0-59f0a77844cf', ), ] ) @@ -883,10 +879,10 @@ async def test_groq_model_web_search_tool(allow_model_requests: None, groq_api_k ), TextPart(content='The current day is Tuesday.'), ], - usage=Usage(requests=1, request_tokens=4287, response_tokens=117, total_tokens=4404), + usage=RequestUsage(input_tokens=4287, output_tokens=117), model_name='compound-beta', timestamp=IsDatetime(), - vendor_id='stub', + provider_request_id='stub', ), ] ) @@ -906,10 +902,10 @@ async def test_groq_model_thinking_part(allow_model_requests: None, groq_api_key ), ModelResponse( parts=[IsInstance(ThinkingPart), IsInstance(TextPart)], - usage=Usage(requests=1, request_tokens=21, response_tokens=1414, total_tokens=1435), + usage=RequestUsage(input_tokens=21, output_tokens=1414), model_name='deepseek-r1-distill-llama-70b', timestamp=IsDatetime(), - vendor_id=IsStr(), + provider_request_id=IsStr(), ), ] ) @@ -927,10 +923,10 @@ async def 
test_groq_model_thinking_part(allow_model_requests: None, groq_api_key ), ModelResponse( parts=[IsInstance(ThinkingPart), IsInstance(TextPart)], - usage=Usage(requests=1, request_tokens=21, response_tokens=1414, total_tokens=1435), + usage=RequestUsage(input_tokens=21, output_tokens=1414), model_name='deepseek-r1-distill-llama-70b', timestamp=IsDatetime(), - vendor_id='chatcmpl-9748c1af-1065-410a-969a-d7fb48039fbb', + provider_request_id='chatcmpl-9748c1af-1065-410a-969a-d7fb48039fbb', ), ModelRequest( parts=[ @@ -943,10 +939,10 @@ async def test_groq_model_thinking_part(allow_model_requests: None, groq_api_key ), ModelResponse( parts=[IsInstance(ThinkingPart), IsInstance(TextPart)], - usage=Usage(requests=1, request_tokens=524, response_tokens=1590, total_tokens=2114), + usage=RequestUsage(input_tokens=524, output_tokens=1590), model_name='deepseek-r1-distill-llama-70b', timestamp=IsDatetime(), - vendor_id='chatcmpl-994aa228-883a-498c-8b20-9655d770b697', + provider_request_id='chatcmpl-994aa228-883a-498c-8b20-9655d770b697', ), ] ) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index f16c1420e9..5a77d58cf5 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -36,9 +36,10 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import RunContext +from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import from .mock_async_stream import MockAsyncStream @@ -160,16 +161,18 @@ async def test_simple_completion(allow_model_requests: None, huggingface_api_key request = messages[0] response = messages[1] assert request.parts[0].content == 'hello' # type: ignore - assert response == ModelResponse( - parts=[ - TextPart( - content='Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' - ) - ], - usage=Usage(requests=1, request_tokens=30, response_tokens=29, total_tokens=59), - model_name='Qwen/Qwen2.5-72B-Instruct-fast', - timestamp=datetime(2025, 7, 8, 13, 42, 33, tzinfo=timezone.utc), - vendor_id='chatcmpl-d445c0d473a84791af2acf356cc00df7', + assert response == snapshot( + ModelResponse( + parts=[ + TextPart( + content='Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' + ) + ], + usage=RequestUsage(input_tokens=30, output_tokens=29), + model_name='Qwen/Qwen2.5-72B-Instruct-fast', + timestamp=datetime(2025, 7, 8, 13, 42, 33, tzinfo=timezone.utc), + provider_request_id='chatcmpl-d445c0d473a84791af2acf356cc00df7', + ) ) @@ -186,7 +189,7 @@ async def test_request_simple_usage(allow_model_requests: None, huggingface_api_ result.output == "Hello! It's great to meet you. How can I assist you today? Whether you have any questions, need some advice, or just want to chat, feel free to let me know!" 
) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=30, response_tokens=40, total_tokens=70)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=30, output_tokens=40)) async def test_request_structured_response( @@ -224,18 +227,19 @@ async def test_request_structured_response( assert result.output == [1, 2, 123] messages = result.all_messages() assert messages[0].parts[0].content == 'Hello' # type: ignore - assert messages[1] == ModelResponse( - parts=[ - ToolCallPart( - tool_name='final_result', - args='{"response": [1, 2, 123]}', - tool_call_id='123', - ) - ], - usage=Usage(requests=1), - model_name='hf-model', - timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + assert messages[1] == snapshot( + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"response": [1, 2, 123]}', + tool_call_id='123', + ) + ], + model_name='hf-model', + timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), + provider_request_id='123', + ) ) @@ -363,10 +367,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=2), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='hf-model', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -386,10 +390,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RequestUsage(input_tokens=2, output_tokens=1), model_name='hf-model', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -403,10 +407,9 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1), model_name='hf-model', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -447,7 +450,7 @@ async def test_stream_text(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3)) async def test_stream_text_finish_reason(allow_model_requests: None): @@ -557,9 +560,9 @@ async def test_stream_structured(allow_model_requests: None): ] ) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=20, output_tokens=10)) # double check usage matches stream count - assert result.usage().response_tokens == len(stream) + assert result.usage().output_tokens == len(stream) async def test_stream_structured_finish_reason(allow_model_requests: None): @@ -616,7 +619,7 @@ async def test_no_delta(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3)) 
 @pytest.mark.vcr()
@@ -650,10 +653,10 @@ async def test_image_url_input(allow_model_requests: None, huggingface_api_key:
             ),
             ModelResponse(
                 parts=[TextPart(content='Hello! How can I assist you with this image of a potato?')],
-                usage=Usage(requests=1, request_tokens=269, response_tokens=15, total_tokens=284),
+                usage=RequestUsage(input_tokens=269, output_tokens=15),
                 model_name='Qwen/Qwen2.5-VL-72B-Instruct',
                 timestamp=datetime(2025, 7, 8, 14, 4, 39, tzinfo=timezone.utc),
-                vendor_id='chatcmpl-49aa100effab4ca28514d5ccc00d7944',
+                provider_request_id='chatcmpl-49aa100effab4ca28514d5ccc00d7944',
             ),
         ]
     )
@@ -716,10 +719,10 @@ def simple_instructions(ctx: RunContext):
             ),
             ModelResponse(
                 parts=[TextPart(content='Paris')],
-                usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28),
+                usage=RequestUsage(input_tokens=26, output_tokens=2),
                 model_name='Qwen/Qwen2.5-72B-Instruct-fast',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-b3936940372c481b8d886e596dc75524',
+                provider_request_id='chatcmpl-b3936940372c481b8d886e596dc75524',
             ),
         ]
     )
@@ -808,10 +811,9 @@ def response_validator(value: str) -> str:
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='invalid-response')],
-                usage=Usage(requests=1),
                 model_name='hf-model',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
-                vendor_id='123',
+                provider_request_id='123',
             ),
             ModelRequest(
                 parts=[
@@ -825,10 +827,9 @@ def response_validator(value: str) -> str:
             ),
             ModelResponse(
                 parts=[TextPart(content='final-response')],
-                usage=Usage(requests=1),
                 model_name='hf-model',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
-                vendor_id='123',
+                provider_request_id='123',
             ),
         ]
     )
@@ -955,10 +956,10 @@ async def test_hf_model_thinking_part(allow_model_requests: None, huggingface_ap
                     IsInstance(ThinkingPart),
                     IsInstance(TextPart),
                 ],
-                usage=Usage(requests=1, request_tokens=15, response_tokens=1090, total_tokens=1105),
+                usage=RequestUsage(input_tokens=15, output_tokens=1090),
                 model_name='Qwen/Qwen3-235B-A22B',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9',
+                provider_request_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9',
             ),
         ]
     )
@@ -978,10 +979,10 @@ async def test_hf_model_thinking_part(allow_model_requests: None, huggingface_ap
                     IsInstance(ThinkingPart),
                     IsInstance(TextPart),
                 ],
-                usage=Usage(requests=1, request_tokens=15, response_tokens=1090, total_tokens=1105),
+                usage=RequestUsage(input_tokens=15, output_tokens=1090),
                 model_name='Qwen/Qwen3-235B-A22B',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9',
+                provider_request_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9',
             ),
             ModelRequest(
                 parts=[
@@ -996,10 +997,10 @@ async def test_hf_model_thinking_part(allow_model_requests: None, huggingface_ap
                     IsInstance(ThinkingPart),
                     TextPart(content=IsStr()),
                 ],
-                usage=Usage(requests=1, request_tokens=691, response_tokens=1860, total_tokens=2551),
+                usage=RequestUsage(input_tokens=691, output_tokens=1860),
                 model_name='Qwen/Qwen3-235B-A22B',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-35fdec1307634f94a39f7e26f52e12a7',
+                provider_request_id='chatcmpl-35fdec1307634f94a39f7e26f52e12a7',
             ),
         ]
     )
diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py
index 1f321c4533..b5704c7b94 100644
--- a/tests/models/test_instrumented.py
+++ b/tests/models/test_instrumented.py
@@ -37,7 +37,7 @@
 from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
 from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel
 from pydantic_ai.settings import ModelSettings
-from pydantic_ai.usage import Usage
+from pydantic_ai.usage import RequestUsage

 from ..conftest import IsStr, try_import

@@ -82,7 +82,7 @@ async def request(
                 TextPart('text2'),
                 {},  # test unexpected parts  # type: ignore
             ],
-            usage=Usage(request_tokens=100, response_tokens=200),
+            usage=RequestUsage(input_tokens=100, output_tokens=200),
             model_name='my_model_123',
         )

@@ -99,7 +99,7 @@ async def request_stream(

 class MyResponseStream(StreamedResponse):
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        self._usage = Usage(request_tokens=300, response_tokens=400)
+        self._usage = RequestUsage(input_tokens=300, output_tokens=400)
         maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1')
         if maybe_event is not None:  # pragma: no branch
             yield maybe_event
diff --git a/tests/models/test_mcp_sampling.py b/tests/models/test_mcp_sampling.py
index 0336ce2f73..650ec34cbd 100644
--- a/tests/models/test_mcp_sampling.py
+++ b/tests/models/test_mcp_sampling.py
@@ -10,7 +10,6 @@
 from pydantic_ai.agent import Agent
 from pydantic_ai.exceptions import UnexpectedModelBehavior
 from pydantic_ai.messages import BinaryContent, ModelRequest, ModelResponse, SystemPromptPart, TextPart, UserPromptPart
-from pydantic_ai.usage import Usage

 from ..conftest import IsNow, try_import

@@ -59,7 +58,6 @@ def test_assistant_text():
             ),
             ModelResponse(
                 parts=[TextPart(content='text content')],
-                usage=Usage(requests=1),
                 model_name='test-model',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -93,14 +91,12 @@ def test_assistant_text_history():
             ModelRequest(parts=[UserPromptPart(content='1', timestamp=IsNow(tz=timezone.utc))], instructions='testing'),
             ModelResponse(
                 parts=[TextPart(content='text content')],
-                usage=Usage(requests=1),
                 model_name='test-model',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='2', timestamp=IsNow(tz=timezone.utc))], instructions='testing'),
             ModelResponse(
                 parts=[TextPart(content='text content')],
-                usage=Usage(requests=1),
                 model_name='test-model',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -121,7 +117,6 @@ def test_assistant_text_history_complex():
             ),
             ModelResponse(
                 parts=[TextPart(content='text content')],
-                usage=Usage(requests=1),
                 model_name='test-model',
             ),
         ]
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 6736959726..c75fe176e0 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -29,7 +29,7 @@
     UserPromptPart,
     VideoUrl,
 )
-from pydantic_ai.usage import Usage
+from pydantic_ai.usage import RequestUsage

 from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import
 from .mock_async_stream import MockAsyncStream
@@ -205,32 +205,30 @@ async def test_multiple_completions(allow_model_requests: None):
     result = await agent.run('hello')
     assert result.output == 'world'
-    assert result.usage().request_tokens == 1
-    assert result.usage().response_tokens == 1
-    assert result.usage().total_tokens == 1
+    assert result.usage().input_tokens == 1
+    assert result.usage().output_tokens == 1

     result = await agent.run('hello again', message_history=result.new_messages())
     assert result.output == 'hello again'
-    assert result.usage().request_tokens == 1
-    assert result.usage().response_tokens == 1
-    assert result.usage().total_tokens == 1
+    assert result.usage().input_tokens == 1
+    assert result.usage().output_tokens == 1

     assert
result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=IsNow(tz=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='hello again')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -252,46 +250,43 @@ async def test_three_completions(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage().request_tokens == 1 - assert result.usage().response_tokens == 1 - assert result.usage().total_tokens == 1 + assert result.usage().input_tokens == 1 + assert result.usage().output_tokens == 1 result = await agent.run('hello again', message_history=result.all_messages()) assert result.output == 'hello again' - assert result.usage().request_tokens == 1 - assert result.usage().response_tokens == 1 - assert result.usage().total_tokens == 1 + assert result.usage().input_tokens == 1 + assert result.usage().output_tokens == 1 result = await agent.run('final message', message_history=result.all_messages()) assert result.output == 'final message' - assert result.usage().request_tokens == 1 - assert result.usage().response_tokens == 1 - assert result.usage().total_tokens == 1 + assert result.usage().input_tokens == 1 + assert result.usage().output_tokens == 1 assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='hello again')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='final message')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -320,9 +315,8 @@ async def test_stream_text(allow_model_requests: None): ['hello ', 'hello world ', 'hello world welcome ', 'hello world welcome mistral'] ) assert result.is_complete - assert result.usage().request_tokens == 5 - assert result.usage().response_tokens == 5 - assert result.usage().total_tokens == 5 + assert 
result.usage().input_tokens == 5 + assert result.usage().output_tokens == 5 async def test_stream_text_finish_reason(allow_model_requests: None): @@ -357,9 +351,8 @@ async def test_no_delta(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage().request_tokens == 3 - assert result.usage().response_tokens == 3 - assert result.usage().total_tokens == 3 + assert result.usage().input_tokens == 3 + assert result.usage().output_tokens == 3 ##################### @@ -393,9 +386,8 @@ class CityLocation(BaseModel): result = await agent.run('User prompt value') assert result.output == CityLocation(city='paris', country='france') - assert result.usage().request_tokens == 1 - assert result.usage().response_tokens == 2 - assert result.usage().total_tokens == 3 + assert result.usage().input_tokens == 1 + assert result.usage().output_tokens == 2 assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))]), @@ -407,10 +399,10 @@ class CityLocation(BaseModel): tool_call_id='123', ) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3), + usage=RequestUsage(input_tokens=1, output_tokens=2), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -453,10 +445,9 @@ class CityLocation(BaseModel): result = await agent.run('User prompt value') assert result.output == CityLocation(city='paris', country='france') - assert result.usage().request_tokens == 1 - assert result.usage().response_tokens == 1 - assert result.usage().total_tokens == 1 - assert result.usage().details is None + assert result.usage().input_tokens == 1 + assert result.usage().output_tokens == 1 + assert result.usage().details == {} assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))]), @@ -468,10 +459,10 @@ class CityLocation(BaseModel): tool_call_id='123', ) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -508,10 +499,9 @@ async def test_request_output_type_with_arguments_str_response(allow_model_reque result = await agent.run('User prompt value') assert result.output == 42 - assert result.usage().request_tokens == 1 - assert result.usage().response_tokens == 1 - assert result.usage().total_tokens == 1 - assert result.usage().details is None + assert result.usage().input_tokens == 1 + assert result.usage().output_tokens == 1 + assert result.usage().details == {} assert result.all_messages() == snapshot( [ ModelRequest( @@ -528,10 +518,10 @@ async def test_request_output_type_with_arguments_str_response(allow_model_reque tool_call_id='123', ) ], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -646,12 +636,11 @@ class MyTypedDict(TypedDict, total=False): ] ) assert result.is_complete - 
assert result.usage().request_tokens == 10 - assert result.usage().response_tokens == 10 - assert result.usage().total_tokens == 10 + assert result.usage().input_tokens == 10 + assert result.usage().output_tokens == 10 # double check usage matches stream count - assert result.usage().response_tokens == len(stream) + assert result.usage().output_tokens == len(stream) async def test_stream_result_type_primitif_dict(allow_model_requests: None): @@ -733,12 +722,11 @@ class MyTypedDict(TypedDict, total=False): ] ) assert result.is_complete - assert result.usage().request_tokens == 34 - assert result.usage().response_tokens == 34 - assert result.usage().total_tokens == 34 + assert result.usage().input_tokens == 34 + assert result.usage().output_tokens == 34 # double check usage matches stream count - assert result.usage().response_tokens == len(stream) + assert result.usage().output_tokens == len(stream) async def test_stream_result_type_primitif_int(allow_model_requests: None): @@ -763,12 +751,11 @@ async def test_stream_result_type_primitif_int(allow_model_requests: None): v = [c async for c in result.stream(debounce_by=None)] assert v == snapshot([1, 1, 1]) assert result.is_complete - assert result.usage().request_tokens == 6 - assert result.usage().response_tokens == 6 - assert result.usage().total_tokens == 6 + assert result.usage().input_tokens == 6 + assert result.usage().output_tokens == 6 # double check usage matches stream count - assert result.usage().response_tokens == len(stream) + assert result.usage().output_tokens == len(stream) async def test_stream_result_type_primitif_array(allow_model_requests: None): @@ -856,12 +843,11 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None): ] ) assert result.is_complete - assert result.usage().request_tokens == 35 - assert result.usage().response_tokens == 35 - assert result.usage().total_tokens == 35 + assert result.usage().input_tokens == 35 + assert result.usage().output_tokens == 35 # double check usage matches stream count - assert result.usage().response_tokens == len(stream) + assert result.usage().output_tokens == len(stream) async def test_stream_result_type_basemodel_with_default_params(allow_model_requests: None): @@ -941,12 +927,11 @@ class MyTypedBaseModel(BaseModel): ] ) assert result.is_complete - assert result.usage().request_tokens == 34 - assert result.usage().response_tokens == 34 - assert result.usage().total_tokens == 34 + assert result.usage().input_tokens == 34 + assert result.usage().output_tokens == 34 # double check usage matches stream count - assert result.usage().response_tokens == len(stream) + assert result.usage().output_tokens == len(stream) async def test_stream_result_type_basemodel_with_required_params(allow_model_requests: None): @@ -1010,12 +995,11 @@ class MyTypedBaseModel(BaseModel): ] ) assert result.is_complete - assert result.usage().request_tokens == 34 - assert result.usage().response_tokens == 34 - assert result.usage().total_tokens == 34 + assert result.usage().input_tokens == 34 + assert result.usage().output_tokens == 34 # double check cost matches stream count - assert result.usage().response_tokens == len(stream) + assert result.usage().output_tokens == len(stream) ##################### @@ -1077,8 +1061,8 @@ async def get_location(loc_name: str) -> str: result = await agent.run('Hello') assert result.output == 'final response' - assert result.usage().request_tokens == 6 - assert result.usage().response_tokens == 4 + assert result.usage().input_tokens == 6 + assert 
result.usage().output_tokens == 4 assert result.usage().total_tokens == 10 assert result.all_messages() == snapshot( [ @@ -1096,10 +1080,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RequestUsage(input_tokens=2, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1119,10 +1103,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6), + usage=RequestUsage(input_tokens=3, output_tokens=2), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1136,10 +1120,10 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -1220,9 +1204,8 @@ async def get_location(loc_name: str) -> str: result = await agent.run('Hello') assert result.output == {'lat': 51, 'lng': 0} - assert result.usage().request_tokens == 7 - assert result.usage().response_tokens == 4 - assert result.usage().total_tokens == 12 + assert result.usage().input_tokens == 7 + assert result.usage().output_tokens == 4 assert result.all_messages() == snapshot( [ ModelRequest( @@ -1239,10 +1222,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RequestUsage(input_tokens=2, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1262,10 +1245,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6), + usage=RequestUsage(input_tokens=3, output_tokens=2), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1285,10 +1268,10 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + usage=RequestUsage(input_tokens=2, output_tokens=1), model_name='mistral-large-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -1362,12 +1345,11 @@ async def get_location(loc_name: str) -> str: assert v == snapshot([{'won': True}, {'won': True}]) assert result.is_complete assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc) - assert result.usage().request_tokens == 4 - assert result.usage().response_tokens == 4 - assert result.usage().total_tokens == 4 + assert result.usage().input_tokens == 4 + assert result.usage().output_tokens == 4 # double check usage matches stream count - assert result.usage().response_tokens == 4 + assert result.usage().output_tokens == 4 assert result.all_messages() == snapshot( [ @@ -1385,7 +1367,7 @@ 
async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(request_tokens=2, response_tokens=2, total_tokens=2), + usage=RequestUsage(input_tokens=2, output_tokens=2), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1401,7 +1383,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"won": true}', tool_call_id='1')], - usage=Usage(request_tokens=2, response_tokens=2, total_tokens=2), + usage=RequestUsage(input_tokens=2, output_tokens=2), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1464,12 +1446,11 @@ async def get_location(loc_name: str) -> str: assert v == snapshot(['final ', 'final response', 'final response']) assert result.is_complete assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc) - assert result.usage().request_tokens == 6 - assert result.usage().response_tokens == 6 - assert result.usage().total_tokens == 6 + assert result.usage().input_tokens == 6 + assert result.usage().output_tokens == 6 # double check usage matches stream count - assert result.usage().response_tokens == 6 + assert result.usage().output_tokens == 6 assert result.all_messages() == snapshot( [ @@ -1487,7 +1468,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(request_tokens=2, response_tokens=2, total_tokens=2), + usage=RequestUsage(input_tokens=2, output_tokens=2), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1503,7 +1484,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(request_tokens=4, response_tokens=4, total_tokens=4), + usage=RequestUsage(input_tokens=4, output_tokens=4), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1569,12 +1550,11 @@ async def get_location(loc_name: str) -> str: assert v == snapshot(['final ', 'final response']) assert result.is_complete assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc) - assert result.usage().request_tokens == 7 - assert result.usage().response_tokens == 7 - assert result.usage().total_tokens == 7 + assert result.usage().input_tokens == 7 + assert result.usage().output_tokens == 7 # double check usage matches stream count - assert result.usage().response_tokens == 7 + assert result.usage().output_tokens == 7 assert result.all_messages() == snapshot( [ @@ -1592,7 +1572,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='1', ) ], - usage=Usage(request_tokens=2, response_tokens=2, total_tokens=2), + usage=RequestUsage(input_tokens=2, output_tokens=2), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1614,7 +1594,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage(request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -1630,7 +1610,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(request_tokens=4, response_tokens=4, total_tokens=4), + usage=RequestUsage(input_tokens=4, output_tokens=4), model_name='mistral-large-latest', timestamp=datetime(2024, 1, 1, 0, 0, 
tzinfo=timezone.utc), ), @@ -1804,10 +1784,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='utZJMAZN4')], - usage=Usage(requests=1, request_tokens=65, response_tokens=16, total_tokens=81), + usage=RequestUsage(input_tokens=65, output_tokens=16), model_name='pixtral-12b-latest', timestamp=IsDatetime(), - vendor_id='fce6d16a4e5940edb24ae16dd0369947', + provider_request_id='fce6d16a4e5940edb24ae16dd0369947', ), ModelRequest( parts=[ @@ -1832,10 +1812,10 @@ async def get_image() -> BinaryContent: content='The image you\'re referring to, labeled as "file 1c8566," shows a kiwi. Kiwis are small, brown, oval-shaped fruits with a bright green flesh inside that is dotted with tiny black seeds. They have a sweet and tangy flavor and are known for being rich in vitamin C and fiber.' ) ], - usage=Usage(requests=1, request_tokens=2931, response_tokens=70, total_tokens=3001), + usage=RequestUsage(input_tokens=2931, output_tokens=70), model_name='pixtral-12b-latest', timestamp=IsDatetime(), - vendor_id='26e7de193646460e8904f8e604a60dc1', + provider_request_id='26e7de193646460e8904f8e604a60dc1', ), ] ) @@ -1870,10 +1850,10 @@ async def test_image_url_input(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -1903,10 +1883,10 @@ async def test_image_as_binary_content_input(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -1939,10 +1919,10 @@ async def test_pdf_url_input(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -1969,10 +1949,10 @@ async def test_pdf_as_binary_content_input(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -2045,10 +2025,10 @@ async def test_mistral_model_instructions(allow_model_requests: None, mistral_ap ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1), + usage=RequestUsage(input_tokens=1, output_tokens=1), model_name='mistral-large-123', timestamp=IsDatetime(), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -2072,15 +2052,10 @@ async def test_mistral_model_thinking_part(allow_model_requests: None, openai_ap ThinkingPart(content=IsStr(), id='rs_68079ad7f0588191af64f067e7314d840493b22e4095129c'), TextPart(content=IsStr()), ], - usage=Usage( - request_tokens=13, - response_tokens=1789, - total_tokens=1802, - details={'reasoning_tokens': 1344, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, 
output_tokens=1789, details={'reasoning_tokens': 1344}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', + provider_request_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', ), ] ) @@ -2135,15 +2110,10 @@ async def test_mistral_model_thinking_part(allow_model_requests: None, openai_ap """ ), ], - usage=Usage( - request_tokens=13, - response_tokens=1789, - total_tokens=1802, - details={'reasoning_tokens': 1344, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=1789, details={'reasoning_tokens': 1344}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', + provider_request_id='resp_68079acebbfc819189ec20e1e5bf525d0493b22e4095129c', ), ModelRequest( parts=[ @@ -2155,10 +2125,10 @@ async def test_mistral_model_thinking_part(allow_model_requests: None, openai_ap ), ModelResponse( parts=[TextPart(content=IsStr())], - usage=Usage(requests=1, request_tokens=1036, response_tokens=691, total_tokens=1727), + usage=RequestUsage(input_tokens=1036, output_tokens=691), model_name='mistral-large-latest', timestamp=IsDatetime(), - vendor_id='a088e80a476e44edaaa959a1ff08f358', + provider_request_id='a088e80a476e44edaaa959a1ff08f358', ), ] ) diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 7e1e820af9..c51d9edf1f 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -22,7 +22,8 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage +from pydantic_ai.usage import RequestUsage from ..conftest import IsNow, IsStr @@ -66,7 +67,7 @@ def test_simple(): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content="content='Hello' part_kind='user-prompt' message_count=1")], - usage=Usage(requests=1, request_tokens=51, response_tokens=3, total_tokens=54), + usage=RequestUsage(input_tokens=51, output_tokens=3), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), ), @@ -80,14 +81,14 @@ def test_simple(): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content="content='Hello' part_kind='user-prompt' message_count=1")], - usage=Usage(requests=1, request_tokens=51, response_tokens=3, total_tokens=54), + usage=RequestUsage(input_tokens=51, output_tokens=3), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content="content='World' part_kind='user-prompt' message_count=3")], - usage=Usage(requests=1, request_tokens=52, response_tokens=6, total_tokens=58), + usage=RequestUsage(input_tokens=52, output_tokens=6), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), ), @@ -157,7 +158,7 @@ def test_weather(): tool_name='get_location', args='{"location_description": "London"}', tool_call_id=IsStr() ) ], - usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -173,7 +174,7 @@ def test_weather(): ), ModelResponse( 
parts=[ToolCallPart(tool_name='get_weather', args='{"lat": 51, "lng": 0}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=56, response_tokens=11, total_tokens=67), + usage=RequestUsage(input_tokens=56, output_tokens=11), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -189,7 +190,7 @@ def test_weather(): ), ModelResponse( parts=[TextPart(content='Raining in London')], - usage=Usage(requests=1, request_tokens=57, response_tokens=14, total_tokens=71), + usage=RequestUsage(input_tokens=57, output_tokens=14), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -357,7 +358,7 @@ def test_call_all(): ToolCallPart(tool_name='qux', args={'x': 0}, tool_call_id=IsStr()), ToolCallPart(tool_name='quz', args={'x': 'a'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=52, response_tokens=21, total_tokens=73), + usage=RequestUsage(input_tokens=52, output_tokens=21), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -382,7 +383,7 @@ def test_call_all(): ), ModelResponse( parts=[TextPart(content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')], - usage=Usage(requests=1, request_tokens=57, response_tokens=33, total_tokens=90), + usage=RequestUsage(input_tokens=57, output_tokens=33), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -451,13 +452,13 @@ async def test_stream_text(): ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='hello world')], - usage=Usage(request_tokens=50, response_tokens=2, total_tokens=52), + usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::stream_text_function', timestamp=IsNow(tz=timezone.utc), ), ] ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=50, response_tokens=2, total_tokens=52)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=50, output_tokens=2)) class Foo(BaseModel): @@ -480,11 +481,10 @@ async def stream_structured_function( async with agent.run_stream('') as result: assert await result.get_output() == snapshot(Foo(x=1)) assert result.usage() == snapshot( - Usage( + RunUsage( requests=1, - request_tokens=50, - response_tokens=4, - total_tokens=54, + input_tokens=50, + output_tokens=4, ) ) diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 02aafd259f..4daf179f05 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -30,7 +30,7 @@ VideoUrl, ) from pydantic_ai.models.test import TestModel, _chars, _JsonSchemaTestData # pyright: ignore[reportPrivateUsage] -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage, RunUsage from ..conftest import IsNow, IsStr @@ -106,7 +106,7 @@ async def my_ret(x: int) -> str: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='my_ret', args={'x': 0}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RequestUsage(input_tokens=51, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -122,7 +122,7 @@ async def my_ret(x: int) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='my_ret', args={'x': 0}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=61, response_tokens=8, total_tokens=69), + usage=RequestUsage(input_tokens=61, output_tokens=8), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -135,7 
+135,7 @@ async def my_ret(x: int) -> str: ), ModelResponse( parts=[TextPart(content='{"my_ret":"1"}')], - usage=Usage(requests=1, request_tokens=62, response_tokens=12, total_tokens=74), + usage=RequestUsage(input_tokens=62, output_tokens=12), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -339,4 +339,4 @@ def test_different_content_input(content: AudioUrl | VideoUrl | ImageUrl | Binar agent = Agent() result = agent.run_sync(['x', content], model=TestModel(custom_output_text='custom')) assert result.output == snapshot('custom') - assert result.usage() == snapshot(Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=1)) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index fb23282a5b..db0a3eaacc 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -42,9 +42,10 @@ from pydantic_ai.profiles import ModelProfile from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition +from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, TestEnv, raise_if_exception, try_import from .mock_async_stream import MockAsyncStream @@ -175,31 +176,29 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(RunUsage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) @@ -234,7 +233,13 @@ async def test_request_simple_usage(allow_model_requests: None): result = await agent.run('Hello') assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3)) + assert result.usage() == snapshot( + RunUsage( + requests=1, + input_tokens=2, + output_tokens=1, + ) + ) async def test_request_structured_response(allow_model_requests: None): @@ -268,10 +273,9 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - usage=Usage(requests=1), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -358,12 +362,14 @@ async def get_location(loc_name: str) -> str: 
tool_call_id='1', ) ], - usage=Usage( - requests=1, request_tokens=2, response_tokens=1, total_tokens=3, details={'cached_tokens': 1} + usage=RequestUsage( + input_tokens=2, + cache_read_tokens=1, + output_tokens=1, ), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -383,12 +389,14 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - usage=Usage( - requests=1, request_tokens=3, response_tokens=2, total_tokens=6, details={'cached_tokens': 2} + usage=RequestUsage( + input_tokens=3, + cache_read_tokens=2, + output_tokens=2, ), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ModelRequest( parts=[ @@ -402,22 +410,13 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1), model_name='gpt-4o-123', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), - vendor_id='123', + provider_request_id='123', ), ] ) - assert result.usage() == snapshot( - Usage( - requests=3, - request_tokens=5, - response_tokens=3, - total_tokens=9, - details={'cached_tokens': 3}, - ) - ) + assert result.usage() == snapshot(RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3)) FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] @@ -450,7 +449,7 @@ async def test_stream_text(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=4, request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3)) async def test_stream_text_finish_reason(allow_model_requests: None): @@ -522,9 +521,9 @@ async def test_stream_structured(allow_model_requests: None): ] ) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=11, request_tokens=20, response_tokens=10, total_tokens=30)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=20, output_tokens=10)) # double check usage matches stream count - assert result.usage().response_tokens == len(stream) + assert result.usage().output_tokens == len(stream) async def test_stream_structured_finish_reason(allow_model_requests: None): @@ -671,7 +670,7 @@ async def test_no_delta(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(requests=4, request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3)) @pytest.mark.parametrize('system_prompt_role', ['system', 'developer', 'user', None]) @@ -786,6 +785,21 @@ async def test_openai_audio_url_input(allow_model_requests: None, openai_api_key assert result.output == snapshot( 'Yes, the phenomenon of the sun rising in the east and setting in the west is due to the rotation of the Earth. The Earth rotates on its axis from west to east, making the sun appear to rise on the eastern horizon and set in the west. This is a daily occurrence and has been a fundamental aspect of human observation and timekeeping throughout history.' 
) + assert result.usage() == snapshot( + RunUsage( + input_tokens=81, + output_tokens=72, + input_audio_tokens=69, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'text_tokens': 72, + }, + requests=1, + ) + ) async def test_document_url_input(allow_model_requests: None, openai_api_key: str): @@ -820,22 +834,19 @@ async def get_image() -> ImageUrl: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='call_4hrT4QP9jfojtK69vGiFCFjG')], - usage=Usage( - requests=1, - request_tokens=46, - response_tokens=11, - total_tokens=57, + usage=RequestUsage( + input_tokens=46, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRmTHlrARTzAHK1na9s80xDlQGYPX', + provider_request_id='chatcmpl-BRmTHlrARTzAHK1na9s80xDlQGYPX', ), ModelRequest( parts=[ @@ -858,22 +869,19 @@ async def get_image() -> ImageUrl: ), ModelResponse( parts=[TextPart(content='The image shows a potato.')], - usage=Usage( - requests=1, - request_tokens=503, - response_tokens=8, - total_tokens=511, + usage=RequestUsage( + input_tokens=503, + output_tokens=8, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRmTI0Y2zmkGw27kLarhsmiFQTGxR', + provider_request_id='chatcmpl-BRmTI0Y2zmkGw27kLarhsmiFQTGxR', ), ] ) @@ -902,22 +910,19 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='call_Btn0GIzGr4ugNlLmkQghQUMY')], - usage=Usage( - requests=1, - request_tokens=46, - response_tokens=11, - total_tokens=57, + usage=RequestUsage( + input_tokens=46, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlkLhPc87BdohVobEJJCGq3rUAG2', + provider_request_id='chatcmpl-BRlkLhPc87BdohVobEJJCGq3rUAG2', ), ModelRequest( parts=[ @@ -938,22 +943,19 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The image shows a kiwi fruit.')], - usage=Usage( - requests=1, - request_tokens=1185, - response_tokens=9, - total_tokens=1194, + usage=RequestUsage( + input_tokens=1185, + output_tokens=9, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlkORPA5rXMV3uzcOcgK4eQFKCVW', + provider_request_id='chatcmpl-BRlkORPA5rXMV3uzcOcgK4eQFKCVW', ), ] ) @@ -977,6 +979,21 @@ async def test_audio_as_binary_content_input( result = await agent.run(['Whose name is mentioned in the audio?', audio_content]) assert result.output == snapshot('The name mentioned in the audio is Marcelo.') + assert result.usage() == snapshot( + RunUsage( + input_tokens=64, + output_tokens=9, + input_audio_tokens=44, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'text_tokens': 9, + }, + requests=1, + ) + ) async def test_document_as_binary_content_input( @@ -1824,22 +1841,19 @@ 
async def test_openai_instructions(allow_model_requests: None, openai_api_key: s ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - requests=1, - request_tokens=24, - response_tokens=8, - total_tokens=32, + usage=RequestUsage( + input_tokens=24, + output_tokens=8, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BJjf61mLb9z5H45ClJzbx0UWKwjo1', + provider_request_id='chatcmpl-BJjf61mLb9z5H45ClJzbx0UWKwjo1', ), ] ) @@ -1871,22 +1885,19 @@ async def get_temperature(city: str) -> float: ), ModelResponse( parts=[ToolCallPart(tool_name='get_temperature', args='{"city":"Tokyo"}', tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=50, - response_tokens=15, - total_tokens=65, + usage=RequestUsage( + input_tokens=50, + output_tokens=15, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4.1-mini-2025-04-14', timestamp=IsDatetime(), - vendor_id='chatcmpl-BMxEwRA0p0gJ52oKS7806KAlfMhqq', + provider_request_id='chatcmpl-BMxEwRA0p0gJ52oKS7806KAlfMhqq', ), ModelRequest( parts=[ @@ -1898,22 +1909,19 @@ async def get_temperature(city: str) -> float: ), ModelResponse( parts=[TextPart(content='The temperature in Tokyo is currently 20.0 degrees Celsius.')], - usage=Usage( - requests=1, - request_tokens=75, - response_tokens=15, - total_tokens=90, + usage=RequestUsage( + input_tokens=75, + output_tokens=15, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4.1-mini-2025-04-14', timestamp=IsDatetime(), - vendor_id='chatcmpl-BMxEx6B8JEj6oDC45MOWKp0phg8UP', + provider_request_id='chatcmpl-BMxEx6B8JEj6oDC45MOWKp0phg8UP', ), ] ) @@ -1936,15 +1944,10 @@ async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034841ab2881918a8c210e3d988b9208c845d2be9bcdd8'), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=2050, - total_tokens=2063, - details={'reasoning_tokens': 1664, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=2050, details={'reasoning_tokens': 1664}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', + provider_request_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', ), ] ) @@ -1964,15 +1967,10 @@ async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034841ab2881918a8c210e3d988b9208c845d2be9bcdd8'), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=2050, - total_tokens=2063, - details={'reasoning_tokens': 1664, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=2050, details={'reasoning_tokens': 1664}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', + provider_request_id='resp_68034835d12481919c80a7fd8dbe6f7e08c845d2be9bcdd8', ), ModelRequest( parts=[ @@ -1989,15 +1987,10 @@ async def test_openai_responses_model_thinking_part(allow_model_requests: None, ThinkingPart(content=IsStr(), id='rs_68034858dc588191bc3a6801c23e728f08c845d2be9bcdd8'), 
IsInstance(TextPart), ], - usage=Usage( - request_tokens=424, - response_tokens=2033, - total_tokens=2457, - details={'reasoning_tokens': 1408, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=424, output_tokens=2033, details={'reasoning_tokens': 1408}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_6803484f19a88191b9ea975d7cfbbe8408c845d2be9bcdd8', + provider_request_id='resp_6803484f19a88191b9ea975d7cfbbe8408c845d2be9bcdd8', ), ] ) @@ -2021,15 +2014,10 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=1900, - total_tokens=1913, - details={'reasoning_tokens': 1536, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=1900, details={'reasoning_tokens': 1536}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', + provider_request_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', ), ] ) @@ -2050,15 +2038,10 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api IsInstance(ThinkingPart), IsInstance(TextPart), ], - usage=Usage( - request_tokens=13, - response_tokens=1900, - total_tokens=1913, - details={'reasoning_tokens': 1536, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=13, output_tokens=1900, details={'reasoning_tokens': 1536}), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', + provider_request_id='resp_680797310bbc8191971fff5a405113940ed3ec3064b5efac', ), ModelRequest( parts=[ @@ -2070,22 +2053,19 @@ async def test_openai_model_thinking_part(allow_model_requests: None, openai_api ), ModelResponse( parts=[TextPart(content=IsStr())], - usage=Usage( - requests=1, - request_tokens=822, - response_tokens=2437, - total_tokens=3259, + usage=RequestUsage( + input_tokens=822, + output_tokens=2437, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 1792, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='o3-mini-2025-01-31', timestamp=IsDatetime(), - vendor_id='chatcmpl-BP7ocN6qxho4C1UzUJWnU5tPJno55', + provider_request_id='chatcmpl-BP7ocN6qxho4C1UzUJWnU5tPJno55', ), ] ) @@ -2144,8 +2124,8 @@ async def test_openai_instructions_with_logprobs(allow_model_requests: None): ) messages = result.all_messages() response = cast(Any, messages[1]) - assert response.vendor_details is not None - assert response.vendor_details['logprobs'] == [ + assert response.provider_details is not None + assert response.provider_details['logprobs'] == [ { 'token': 'world', 'logprob': -0.6931, @@ -2372,22 +2352,19 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=68, - response_tokens=12, - total_tokens=80, + usage=RequestUsage( + input_tokens=68, + output_tokens=12, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I', + provider_request_id='chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I', ), ModelRequest( parts=[ @@ -2407,22 +2384,19 @@ async def get_user_country() -> str: tool_call_id=IsStr(), ) ], - usage=Usage( - requests=1, - 
request_tokens=89, - response_tokens=36, - total_tokens=125, + usage=RequestUsage( + input_tokens=89, + output_tokens=36, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s', + provider_request_id='chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s', ), ModelRequest( parts=[ @@ -2467,22 +2441,19 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_J1YabdC7G7kzEZNbbZopwenH') ], - usage=Usage( - requests=1, - request_tokens=42, - response_tokens=11, - total_tokens=53, + usage=RequestUsage( + input_tokens=42, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BgeDFS85bfHosRFEEAvq8reaCPCZ8', + provider_request_id='chatcmpl-BgeDFS85bfHosRFEEAvq8reaCPCZ8', ), ModelRequest( parts=[ @@ -2496,22 +2467,19 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='The largest city in Mexico is Mexico City.')], - usage=Usage( - requests=1, - request_tokens=63, - response_tokens=10, - total_tokens=73, + usage=RequestUsage( + input_tokens=63, + output_tokens=10, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BgeDGX9eDyVrEI56aP2vtIHahBzFH', + provider_request_id='chatcmpl-BgeDGX9eDyVrEI56aP2vtIHahBzFH', ), ] ) @@ -2549,22 +2517,19 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR') ], - usage=Usage( - requests=1, - request_tokens=71, - response_tokens=12, - total_tokens=83, + usage=RequestUsage( + input_tokens=71, + output_tokens=12, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3', + provider_request_id='chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3', ), ModelRequest( parts=[ @@ -2578,22 +2543,19 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage( - requests=1, - request_tokens=92, - response_tokens=15, - total_tokens=107, + usage=RequestUsage( + input_tokens=92, + output_tokens=15, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO', + provider_request_id='chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO', ), ] ) @@ -2633,22 +2595,19 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_SIttSeiOistt33Htj4oiHOOX') ], - usage=Usage( - requests=1, - request_tokens=160, - response_tokens=11, - total_tokens=171, + usage=RequestUsage( + input_tokens=160, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), 
model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD', + provider_request_id='chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD', ), ModelRequest( parts=[ @@ -2666,22 +2625,19 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=Usage( - requests=1, - request_tokens=181, - response_tokens=25, - total_tokens=206, + usage=RequestUsage( + input_tokens=181, + output_tokens=25, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd', + provider_request_id='chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd', ), ] ) @@ -2724,22 +2680,19 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB') ], - usage=Usage( - requests=1, - request_tokens=109, - response_tokens=11, - total_tokens=120, + usage=RequestUsage( + input_tokens=109, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw', + provider_request_id='chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw', ), ModelRequest( parts=[ @@ -2760,22 +2713,19 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage( - requests=1, - request_tokens=130, - response_tokens=11, - total_tokens=141, + usage=RequestUsage( + input_tokens=130, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0', + provider_request_id='chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0', ), ] ) @@ -2822,22 +2772,19 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09') ], - usage=Usage( - requests=1, - request_tokens=273, - response_tokens=11, - total_tokens=284, + usage=RequestUsage( + input_tokens=273, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR', + provider_request_id='chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR', ), ModelRequest( parts=[ @@ -2862,22 +2809,19 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=Usage( - requests=1, - request_tokens=294, - response_tokens=21, - total_tokens=315, + usage=RequestUsage( + input_tokens=294, + output_tokens=21, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC', + provider_request_id='chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC', ), ] ) diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py 
index 7921addb71..b7c1a8aecd 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -29,7 +29,7 @@ from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.profiles.openai import openai_model_profile from pydantic_ai.tools import ToolDefinition -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsStr, TestEnv, try_import from ..parts_from_messages import part_types_from_messages @@ -197,15 +197,10 @@ async def get_location(loc_name: str) -> str: tool_call_id=IsStr(), ), ], - usage=Usage( - request_tokens=0, - response_tokens=0, - total_tokens=0, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_67e547c48c9481918c5c4394464ce0c60ae6111e84dd5c08', + provider_request_id='resp_67e547c48c9481918c5c4394464ce0c60ae6111e84dd5c08', ), ModelRequest( parts=[ @@ -233,15 +228,10 @@ async def get_location(loc_name: str) -> str: """ ) ], - usage=Usage( - request_tokens=335, - response_tokens=44, - total_tokens=379, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=335, output_tokens=44, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_67e547c5a2f08191802a1f43620f348503a2086afed73b47', + provider_request_id='resp_67e547c5a2f08191802a1f43620f348503a2086afed73b47', ), ] ) @@ -271,15 +261,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=40, - response_tokens=11, - total_tokens=51, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=40, output_tokens=11, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_681134d3aa3481919ca581a267db1e510fe7a5a4e2123dc3', + provider_request_id='resp_681134d3aa3481919ca581a267db1e510fe7a5a4e2123dc3', ), ModelRequest( parts=[ @@ -300,15 +285,10 @@ async def get_image() -> BinaryContent: ), ModelResponse( parts=[TextPart(content='The fruit in the image is a kiwi.')], - usage=Usage( - request_tokens=1185, - response_tokens=11, - total_tokens=1196, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=1185, output_tokens=11, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_681134d53c48819198ce7b89db78dffd02cbfeaababb040c', + provider_request_id='resp_681134d53c48819198ce7b89db78dffd02cbfeaababb040c', ), ] ) @@ -436,15 +416,10 @@ async def test_openai_responses_model_builtin_tools(allow_model_requests: None, """ ) ], - usage=Usage( - request_tokens=320, - response_tokens=159, - total_tokens=479, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=320, output_tokens=159, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_689b7c90010c8196ac0efd68b021490f07450cfc2d48b975', + provider_request_id='resp_689b7c90010c8196ac0efd68b021490f07450cfc2d48b975', ), ] ) @@ -464,15 +439,10 @@ async def test_openai_responses_model_instructions(allow_model_requests: None, o ), ModelResponse( parts=[TextPart(content='The capital of France is Paris.')], - usage=Usage( - request_tokens=24, - 
response_tokens=8, - total_tokens=32, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=24, output_tokens=8, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_67f3fdfd9fa08191a3d5825db81b8df6003bc73febb56d77', + provider_request_id='resp_67f3fdfd9fa08191a3d5825db81b8df6003bc73febb56d77', ), ] ) @@ -712,15 +682,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=62, - response_tokens=12, - total_tokens=74, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=62, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0b40a8819cb8d55594bc2c232a001fd29e2d5573f7', + provider_request_id='resp_68477f0b40a8819cb8d55594bc2c232a001fd29e2d5573f7', ), ModelRequest( parts=[ @@ -740,15 +705,10 @@ async def get_user_country() -> str: tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', ) ], - usage=Usage( - request_tokens=85, - response_tokens=20, - total_tokens=105, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=85, output_tokens=20, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0bfda8819ea65458cd7cc389b801dc81d4bc91f560', + provider_request_id='resp_68477f0bfda8819ea65458cd7cc389b801dc81d4bc91f560', ), ModelRequest( parts=[ @@ -794,15 +754,10 @@ async def get_user_country() -> str: parts=[ ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr') ], - usage=Usage( - request_tokens=36, - response_tokens=12, - total_tokens=48, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=36, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0d9494819ea4f123bba707c9ee0356a60c98816d6a', + provider_request_id='resp_68477f0d9494819ea4f123bba707c9ee0356a60c98816d6a', ), ModelRequest( parts=[ @@ -816,15 +771,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='The largest city in Mexico is Mexico City.')], - usage=Usage( - request_tokens=59, - response_tokens=11, - total_tokens=70, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=59, output_tokens=11, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0e2b28819d9c828ef4ee526d6a03434b607c02582d', + provider_request_id='resp_68477f0e2b28819d9c828ef4ee526d6a03434b607c02582d', ), ] ) @@ -861,15 +811,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=66, - response_tokens=12, - total_tokens=78, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=66, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329', + provider_request_id='resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329', ), ModelRequest( parts=[ @@ -883,15 +828,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico 
City","country":"Mexico"}')], - usage=Usage( - request_tokens=89, - response_tokens=16, - total_tokens=105, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=89, output_tokens=16, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f0fde708192989000a62809c6e5020197534e39cc1f', + provider_request_id='resp_68477f0fde708192989000a62809c6e5020197534e39cc1f', ), ] ) @@ -930,15 +870,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=153, - response_tokens=12, - total_tokens=165, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=153, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563', + provider_request_id='resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563', ), ModelRequest( parts=[ @@ -956,15 +891,10 @@ async def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=Usage( - request_tokens=176, - response_tokens=26, - total_tokens=202, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=176, output_tokens=26, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68477f119830819da162aa6e10552035061ad97e2eef7871', + provider_request_id='resp_68477f119830819da162aa6e10552035061ad97e2eef7871', ), ] ) @@ -1006,15 +936,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=107, - response_tokens=12, - total_tokens=119, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=107, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50', + provider_request_id='resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50', ), ModelRequest( parts=[ @@ -1035,15 +960,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage( - request_tokens=130, - response_tokens=12, - total_tokens=142, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=130, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115', + provider_request_id='resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115', ), ] ) @@ -1089,15 +1009,10 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], - usage=Usage( - request_tokens=283, - response_tokens=12, - total_tokens=295, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=283, output_tokens=12, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee', + provider_request_id='resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee', ), ModelRequest( parts=[ @@ -1122,15 +1037,10 @@ async 
def get_user_country() -> str: content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' ) ], - usage=Usage( - request_tokens=306, - response_tokens=22, - total_tokens=328, - details={'reasoning_tokens': 0, 'cached_tokens': 0}, - ), + usage=RequestUsage(input_tokens=306, output_tokens=22, details={'reasoning_tokens': 0}), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024', + provider_request_id='resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024', ), ] ) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index c13d1da54d..f72227f8bc 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -19,7 +19,7 @@ UserPromptPart, ) from pydantic_ai.models.function import AgentInfo, FunctionModel -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from .conftest import IsDatetime, IsStr, try_import @@ -612,7 +612,7 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr() ) ], - usage=Usage(requests=1, request_tokens=52, response_tokens=7, total_tokens=59), + usage=RequestUsage(input_tokens=52, output_tokens=7), model_name='function:track_messages:', timestamp=IsDatetime(), ), diff --git a/tests/test_agent.py b/tests/test_agent.py index d063b10117..20b40b13a7 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -48,12 +48,13 @@ from pydantic_ai.models.test import TestModel from pydantic_ai.output import DeferredToolCalls, StructuredDict, ToolOutput from pydantic_ai.profiles import ModelProfile -from pydantic_ai.result import Usage +from pydantic_ai.result import RunUsage from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.toolsets.combined import CombinedToolset from pydantic_ai.toolsets.function import FunctionToolset from pydantic_ai.toolsets.prefixed import PrefixedToolset +from pydantic_ai.usage import RequestUsage from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -112,7 +113,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58), + usage=RequestUsage(input_tokens=51, output_tokens=7), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -135,7 +136,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=87, response_tokens=14, total_tokens=101), + usage=RequestUsage(input_tokens=87, output_tokens=14), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -235,7 +236,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"a": 41, "b": "foo"}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58), + usage=RequestUsage(input_tokens=51, output_tokens=7), model_name='function:return_model:', 
timestamp=IsNow(tz=timezone.utc), ), @@ -251,7 +252,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=63, response_tokens=14, total_tokens=77), + usage=RequestUsage(input_tokens=63, output_tokens=14), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -293,7 +294,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='hello')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:return_tuple:', timestamp=IsNow(tz=timezone.utc), ), @@ -310,7 +311,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: parts=[ ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr()) ], - usage=Usage(requests=1, request_tokens=74, response_tokens=8, total_tokens=82), + usage=RequestUsage(input_tokens=74, output_tokens=8), model_name='function:return_tuple:', timestamp=IsNow(tz=timezone.utc), ), @@ -856,7 +857,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=53, response_tokens=7, total_tokens=60), + usage=RequestUsage(input_tokens=53, output_tokens=7), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -878,7 +879,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=68, response_tokens=13, total_tokens=81), + usage=RequestUsage(input_tokens=68, output_tokens=13), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -932,7 +933,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='New York City')], - usage=Usage(requests=1, request_tokens=53, response_tokens=3, total_tokens=56), + usage=RequestUsage(input_tokens=53, output_tokens=3), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -947,7 +948,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='Mexico City')], - usage=Usage(requests=1, request_tokens=70, response_tokens=5, total_tokens=75), + usage=RequestUsage(input_tokens=70, output_tokens=5), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -1120,7 +1121,7 @@ def say_world(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='world')], - usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:say_world:', timestamp=IsDatetime(), ), @@ -1180,7 +1181,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=52, response_tokens=6, total_tokens=58), + usage=RequestUsage(input_tokens=52, output_tokens=6), model_name='function:call_handoff_tool:', timestamp=IsDatetime(), ), @@ -1215,7 +1216,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=57, response_tokens=6, 
total_tokens=63), + usage=RequestUsage(input_tokens=57, output_tokens=6), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -1428,7 +1429,7 @@ class CityLocation(BaseModel): ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], - usage=Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63), + usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='function:return_city_location:', timestamp=IsDatetime(), ), @@ -1467,7 +1468,7 @@ class Foo(BaseModel): ), ModelResponse( parts=[TextPart(content='{"bar":"baz"}')], - usage=Usage(requests=1, request_tokens=56, response_tokens=4, total_tokens=60), + usage=RequestUsage(input_tokens=56, output_tokens=4), model_name='function:return_foo:', timestamp=IsDatetime(), ), @@ -1541,7 +1542,7 @@ def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: content='{"result": {"kind": "FooBar", "data": {"foo": {"foo": "foo"}, "bar": {"bar": "bar"}}}}' ) ], - usage=Usage(requests=1, request_tokens=53, response_tokens=17, total_tokens=70), + usage=RequestUsage(input_tokens=53, output_tokens=17), model_name='function:return_foo_bar:', timestamp=IsDatetime(), ), @@ -1582,7 +1583,7 @@ class CityLocation(BaseModel): ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City"}')], - usage=Usage(requests=1, request_tokens=56, response_tokens=5, total_tokens=61), + usage=RequestUsage(input_tokens=56, output_tokens=5), model_name='function:return_city_location:', timestamp=IsDatetime(), ), @@ -1604,7 +1605,7 @@ class CityLocation(BaseModel): ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], - usage=Usage(requests=1, request_tokens=85, response_tokens=12, total_tokens=97), + usage=RequestUsage(input_tokens=85, output_tokens=12), model_name='function:return_city_location:', timestamp=IsDatetime(), ), @@ -1665,7 +1666,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='{"city": "New York City"}')], - usage=Usage(requests=1, request_tokens=53, response_tokens=6, total_tokens=59), + usage=RequestUsage(input_tokens=53, output_tokens=6), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -1680,7 +1681,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City"}')], - usage=Usage(requests=1, request_tokens=70, response_tokens=11, total_tokens=81), + usage=RequestUsage(input_tokens=70, output_tokens=11), model_name='function:call_tool:', timestamp=IsDatetime(), ), @@ -1708,7 +1709,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1721,7 +1722,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1740,7 +1741,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + 
usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1753,14 +1754,14 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68), + usage=RequestUsage(input_tokens=55, output_tokens=13), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1769,9 +1770,7 @@ async def ret_a(x: str) -> str: assert result2._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] assert result2.output == snapshot('{"ret_a":"a-apple"}') assert result2._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] - assert result2.usage() == snapshot( - Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) - ) + assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13)) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ @@ -1799,7 +1798,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1812,14 +1811,14 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], - usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68), + usage=RequestUsage(input_tokens=55, output_tokens=13), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1828,9 +1827,7 @@ async def ret_a(x: str) -> str: assert result3._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] assert result3.output == snapshot('{"ret_a":"a-apple"}') assert result3._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] - assert result3.usage() == snapshot( - Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) - ) + assert result3.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13)) def test_run_with_history_new_structured(): @@ -1856,7 +1853,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1875,7 +1872,7 @@ async def ret_a(x: str) -> str: tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RequestUsage(input_tokens=53, 
output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1903,7 +1900,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=52, response_tokens=5, total_tokens=57), + usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1916,7 +1913,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=53, response_tokens=9, total_tokens=62), + usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1938,7 +1935,7 @@ async def ret_a(x: str) -> str: ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72), + usage=RequestUsage(input_tokens=59, output_tokens=13), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -1957,9 +1954,7 @@ async def ret_a(x: str) -> str: assert result2.output == snapshot(Response(a=0)) assert result2._new_message_index == snapshot(5) # pyright: ignore[reportPrivateUsage] assert result2._output_tool_name == snapshot('final_result') # pyright: ignore[reportPrivateUsage] - assert result2.usage() == snapshot( - Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None) - ) + assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=59, output_tokens=13)) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ @@ -1998,7 +1993,7 @@ def test_run_with_history_ending_on_model_request_and_no_user_prompt(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RequestUsage(input_tokens=51, output_tokens=4), model_name='test', timestamp=IsDatetime(), ), @@ -2051,7 +2046,7 @@ def test_tool() -> str: ), ModelResponse( parts=[TextPart(content='Final response')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='function:simple_response:', timestamp=IsDatetime(), ), @@ -2135,7 +2130,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=2, total_tokens=53), + usage=RequestUsage(input_tokens=51, output_tokens=2), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), @@ -2151,7 +2146,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=65, response_tokens=4, total_tokens=69), + usage=RequestUsage(input_tokens=65, output_tokens=4), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), @@ -2175,7 +2170,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', 
tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=2, total_tokens=53), + usage=RequestUsage(input_tokens=51, output_tokens=2), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), @@ -2191,7 +2186,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='success')], - usage=Usage(requests=1, request_tokens=65, response_tokens=3, total_tokens=68), + usage=RequestUsage(input_tokens=65, output_tokens=3), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), @@ -2462,7 +2457,7 @@ def another_tool(y: int) -> int: ToolCallPart(tool_name='final_result', args={'value': 'second'}, tool_call_id=IsStr()), ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=53, response_tokens=23, total_tokens=76), + usage=RequestUsage(input_tokens=53, output_tokens=23), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -2551,7 +2546,7 @@ def another_tool(y: int) -> int: # pragma: no cover ToolCallPart(tool_name='another_tool', args={'y': 2}, tool_call_id=IsStr()), ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=58, response_tokens=18, total_tokens=76), + usage=RequestUsage(input_tokens=58, output_tokens=18), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -2703,7 +2698,7 @@ async def get_location(loc_name: str) -> str: TextPart(content='foo'), ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=51, response_tokens=6, total_tokens=57), + usage=RequestUsage(input_tokens=51, output_tokens=6), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -2719,7 +2714,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - usage=Usage(requests=1, request_tokens=56, response_tokens=8, total_tokens=64), + usage=RequestUsage(input_tokens=56, output_tokens=8), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), @@ -2743,7 +2738,7 @@ def test_nested_capture_run_messages() -> None: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RequestUsage(input_tokens=51, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -2766,7 +2761,7 @@ def test_double_capture_run_messages() -> None: ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RequestUsage(input_tokens=51, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -2816,7 +2811,7 @@ async def func() -> str: ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2844,7 +2839,7 @@ async def func() -> str: ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - 
usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2855,7 +2850,7 @@ async def func() -> str: ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=54, response_tokens=8, total_tokens=62), + usage=RequestUsage(input_tokens=54, output_tokens=8), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2897,7 +2892,7 @@ async def func(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2926,7 +2921,7 @@ async def func(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=53, response_tokens=4, total_tokens=57), + usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2937,7 +2932,7 @@ async def func(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], - usage=Usage(requests=1, request_tokens=54, response_tokens=8, total_tokens=62), + usage=RequestUsage(input_tokens=54, output_tokens=8), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', @@ -2964,7 +2959,7 @@ async def foobar(x: str) -> str: ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args={'x': 'a'}, tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -2980,7 +2975,7 @@ async def foobar(x: str) -> str: ), ModelResponse( parts=[TextPart(content='{"foobar":"inner agent result"}')], - usage=Usage(requests=1, request_tokens=54, response_tokens=11, total_tokens=65), + usage=RequestUsage(input_tokens=54, output_tokens=11), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -3057,17 +3052,20 @@ def test_binary_content_serializable(): { 'parts': [{'content': 'success (no tool calls)', 'part_kind': 'text'}], 'usage': { - 'requests': 1, - 'request_tokens': 56, - 'response_tokens': 4, - 'total_tokens': 60, - 'details': None, + 'input_tokens': 56, + 'cache_write_tokens': 0, + 'cache_read_tokens': 0, + 'output_tokens': 4, + 'input_audio_tokens': 0, + 'cache_audio_read_tokens': 0, + 'output_audio_tokens': 0, + 'details': {}, }, 'model_name': 'test', - 'vendor_id': None, + 'provider_details': None, + 'provider_request_id': None, 'timestamp': IsStr(), 'kind': 'response', - 'vendor_details': None, }, ] ) @@ -3108,17 +3106,20 @@ def test_image_url_serializable(): { 'parts': [{'content': 'success (no tool calls)', 'part_kind': 'text'}], 'usage': { - 'requests': 1, - 'request_tokens': 51, - 'response_tokens': 4, - 'total_tokens': 55, - 'details': None, + 'input_tokens': 51, + 'cache_write_tokens': 0, + 'cache_read_tokens': 0, + 'output_tokens': 4, + 'input_audio_tokens': 0, + 'cache_audio_read_tokens': 0, + 'output_audio_tokens': 0, + 'details': {}, }, 'model_name': 'test', 'timestamp': IsStr(), + 'provider_details': None, + 'provider_request_id': None, 'kind': 'response', - 
'vendor_details': None, - 'vendor_id': None, }, ] ) @@ -3320,7 +3321,7 @@ def test_instructions_with_message_history(): ), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=Usage(requests=1, request_tokens=56, response_tokens=4, total_tokens=60), + usage=RequestUsage(input_tokens=56, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), ), @@ -3375,7 +3376,7 @@ def my_tool(x: int) -> int: TextPart(content='foo'), ToolCallPart(tool_name='my_tool', args={'x': 1}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3391,7 +3392,7 @@ def my_tool(x: int) -> int: TextPart(content='bar'), ToolCallPart(tool_name='my_tool', args={'x': 2}, tool_call_id=IsStr()), ], - usage=Usage(requests=1, request_tokens=52, response_tokens=10, total_tokens=62), + usage=RequestUsage(input_tokens=52, output_tokens=10), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3404,7 +3405,7 @@ def my_tool(x: int) -> int: ), ModelResponse( parts=[], - usage=Usage(requests=1, request_tokens=53, response_tokens=10, total_tokens=63), + usage=RequestUsage(input_tokens=53, output_tokens=10), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3543,7 +3544,7 @@ def analyze_data() -> ToolReturn: tool_call_id=IsStr(), ), ], - usage=Usage(requests=1, request_tokens=54, response_tokens=4, total_tokens=58), + usage=RequestUsage(input_tokens=54, output_tokens=4), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3568,7 +3569,7 @@ def analyze_data() -> ToolReturn: ), ModelResponse( parts=[TextPart(content='Analysis completed')], - usage=Usage(requests=1, request_tokens=70, response_tokens=6, total_tokens=76), + usage=RequestUsage(input_tokens=70, output_tokens=6), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3618,7 +3619,7 @@ def analyze_data() -> ToolReturn: tool_call_id=IsStr(), ), ], - usage=Usage(requests=1, request_tokens=54, response_tokens=4, total_tokens=58), + usage=RequestUsage(input_tokens=54, output_tokens=4), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3635,7 +3636,7 @@ def analyze_data() -> ToolReturn: ), ModelResponse( parts=[TextPart(content='Analysis completed')], - usage=Usage(requests=1, request_tokens=58, response_tokens=6, total_tokens=64), + usage=RequestUsage(input_tokens=58, output_tokens=6), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), @@ -3911,7 +3912,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[ToolCallPart(tool_name='add_foo_tool', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=57, response_tokens=2, total_tokens=59), + usage=RequestUsage(input_tokens=57, output_tokens=2), model_name='function:respond:', timestamp=IsDatetime(), ), @@ -3927,7 +3928,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[ToolCallPart(tool_name='foo', tool_call_id=IsStr())], - usage=Usage(requests=1, request_tokens=60, response_tokens=4, total_tokens=64), + usage=RequestUsage(input_tokens=60, output_tokens=4), model_name='function:respond:', timestamp=IsDatetime(), ), @@ -3943,7 +3944,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ModelResponse( parts=[TextPart(content='Done')], - usage=Usage(requests=1, request_tokens=63, 
response_tokens=5, total_tokens=68), + usage=RequestUsage(input_tokens=63, output_tokens=5), model_name='function:respond:', timestamp=IsDatetime(), ), @@ -4002,7 +4003,7 @@ async def only_if_plan_presented( tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='test', timestamp=IsDatetime(), ), @@ -4024,7 +4025,7 @@ async def only_if_plan_presented( tool_call_id=IsStr(), ) ], - usage=Usage(requests=1, request_tokens=52, response_tokens=12, total_tokens=64), + usage=RequestUsage(input_tokens=52, output_tokens=12), model_name='test', timestamp=IsDatetime(), ), @@ -4239,7 +4240,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon ), ModelResponse( parts=[ThinkingPart(content='Let me think about this...')], - usage=Usage(requests=1, request_tokens=57, response_tokens=6, total_tokens=63), + usage=RequestUsage(input_tokens=57, output_tokens=6), model_name='function:model_function:', timestamp=IsDatetime(), ), @@ -4254,7 +4255,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon ), ModelResponse( parts=[TextPart(content='Final answer')], - usage=Usage(requests=1, request_tokens=75, response_tokens=8, total_tokens=83), + usage=RequestUsage(input_tokens=75, output_tokens=8), model_name='function:model_function:', timestamp=IsDatetime(), ), @@ -4323,7 +4324,7 @@ def delete_file(ctx: RunContext[ApprovableToolsDeps], path: str) -> str: tool_name='delete_file', args={'path': 'never_delete.py'}, tool_call_id='never_delete' ), ], - usage=Usage(requests=1, request_tokens=57, response_tokens=12, total_tokens=69), + usage=RequestUsage(input_tokens=57, output_tokens=12), model_name='function:model_function:', timestamp=IsDatetime(), ), @@ -4372,7 +4373,7 @@ def delete_file(ctx: RunContext[ApprovableToolsDeps], path: str) -> str: tool_name='delete_file', args={'path': 'never_delete.py'}, tool_call_id='never_delete' ), ], - usage=Usage(requests=1, request_tokens=57, response_tokens=12, total_tokens=69), + usage=RequestUsage(input_tokens=57, output_tokens=12), model_name='function:model_function:', timestamp=IsDatetime(), ), @@ -4394,7 +4395,7 @@ def delete_file(ctx: RunContext[ApprovableToolsDeps], path: str) -> str: ), ModelResponse( parts=[TextPart(content='OK')], - usage=Usage(requests=1, request_tokens=76, response_tokens=13, total_tokens=89), + usage=RequestUsage(input_tokens=76, output_tokens=13), model_name='function:model_function:', timestamp=IsDatetime(), ), diff --git a/tests/test_direct.py b/tests/test_direct.py index 3e57679174..a26c18c0b4 100644 --- a/tests/test_direct.py +++ b/tests/test_direct.py @@ -31,7 +31,7 @@ from pydantic_ai.models.instrumented import InstrumentedModel from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from .conftest import IsNow, IsStr @@ -45,7 +45,7 @@ async def test_model_request(): parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RequestUsage(input_tokens=51, output_tokens=4), ) ) @@ -64,7 +64,7 @@ async def test_model_request_tool_call(): parts=[ToolCallPart(tool_name='tool_name', args={}, tool_call_id=IsStr(regex='pyd_ai_.*'))], model_name='test', timestamp=IsNow(tz=timezone.utc), - usage=Usage(requests=1, 
request_tokens=51, response_tokens=2, total_tokens=53), + usage=RequestUsage(input_tokens=51, output_tokens=2), ) ) @@ -76,7 +76,7 @@ def test_model_request_sync(): parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc), - usage=Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55), + usage=RequestUsage(input_tokens=51, output_tokens=4), ) ) diff --git a/tests/test_history_processor.py b/tests/test_history_processor.py index 913542fbb0..1aa138935e 100644 --- a/tests/test_history_processor.py +++ b/tests/test_history_processor.py @@ -16,7 +16,7 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.tools import RunContext -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage from .conftest import IsDatetime @@ -63,7 +63,7 @@ def no_op_history_processor(messages: list[ModelMessage]) -> list[ModelMessage]: ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='Provider response')], - usage=Usage(requests=1, request_tokens=54, response_tokens=4, total_tokens=58), + usage=RequestUsage(input_tokens=54, output_tokens=4), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), ), @@ -101,7 +101,7 @@ def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage] ModelRequest(parts=[SystemPromptPart(content='Processed answer', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='Provider response')], - usage=Usage(requests=1, request_tokens=54, response_tokens=2, total_tokens=56), + usage=RequestUsage(input_tokens=54, output_tokens=2), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), ), @@ -135,7 +135,7 @@ def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage] ModelRequest(parts=[SystemPromptPart(content='Processed answer', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='hello')], - usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51), + usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), ), @@ -166,7 +166,7 @@ def capture_messages_processor(messages: list[ModelMessage]) -> list[ModelMessag ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]), ModelResponse( parts=[TextPart(content='Provider response')], - usage=Usage(requests=1, request_tokens=54, response_tokens=2, total_tokens=56), + usage=RequestUsage(input_tokens=54, output_tokens=2), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), ), diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 178d0b3627..d4146740b3 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -28,7 +28,7 @@ from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext -from pydantic_ai.usage import Usage +from pydantic_ai.usage import RequestUsage, RunUsage from .conftest import IsDatetime, IsNow, IsStr, try_import @@ -67,7 +67,7 @@ def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: @pytest.fixture def run_context(model: Model) -> RunContext[int]: - return RunContext(deps=0, model=model, usage=Usage()) + return RunContext(deps=0, model=model, usage=RunUsage()) async def test_stdio_server(run_context: RunContext[int]): @@ -201,22 
+201,19 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) tool_call_id='call_QssdxTGkPblTYHmyVES1tKBj', ) ], - usage=Usage( - requests=1, - request_tokens=195, - response_tokens=19, - total_tokens=214, + usage=RequestUsage( + input_tokens=195, + output_tokens=19, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlnvvqIPFofAtKqtQKMWZkgXhzlT', + provider_request_id='chatcmpl-BRlnvvqIPFofAtKqtQKMWZkgXhzlT', ), ModelRequest( parts=[ @@ -230,22 +227,19 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) ), ModelResponse( parts=[TextPart(content='0 degrees Celsius is equal to 32 degrees Fahrenheit.')], - usage=Usage( - requests=1, - request_tokens=227, - response_tokens=13, - total_tokens=240, + usage=RequestUsage( + input_tokens=227, + output_tokens=13, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlnyjUo5wlyqvdNdM5I8vIWjo1qF', + provider_request_id='chatcmpl-BRlnyjUo5wlyqvdNdM5I8vIWjo1qF', ), ] ) @@ -337,22 +331,19 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): tool_call_id='call_m9goNwaHBbU926w47V7RtWPt', ) ], - usage=Usage( - requests=1, - request_tokens=194, - response_tokens=18, - total_tokens=212, + usage=RequestUsage( + input_tokens=194, + output_tokens=18, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlo3e1Ud2lnvkddMilmwC7LAemiy', + provider_request_id='chatcmpl-BRlo3e1Ud2lnvkddMilmwC7LAemiy', ), ModelRequest( parts=[ @@ -370,22 +361,19 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): content='The weather in Mexico City is currently sunny with a temperature of 26 degrees Celsius.' 
) ], - usage=Usage( - requests=1, - request_tokens=234, - response_tokens=19, - total_tokens=253, + usage=RequestUsage( + input_tokens=234, + output_tokens=19, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlo41LxqBYgGKWgGrQn67fQacOLp', + provider_request_id='chatcmpl-BRlo41LxqBYgGKWgGrQn67fQacOLp', ), ] ) @@ -414,22 +402,19 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A tool_call_id='call_LaiWltzI39sdquflqeuF0EyE', ) ], - usage=Usage( - requests=1, - request_tokens=200, - response_tokens=12, - total_tokens=212, + usage=RequestUsage( + input_tokens=200, + output_tokens=12, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRmhyweJVYonarb7s9ckIMSHf2vHo', + provider_request_id='chatcmpl-BRmhyweJVYonarb7s9ckIMSHf2vHo', ), ModelRequest( parts=[ @@ -443,22 +428,19 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A ), ModelResponse( parts=[TextPart(content='The product name is "Pydantic AI".')], - usage=Usage( - requests=1, - request_tokens=224, - response_tokens=12, - total_tokens=236, + usage=RequestUsage( + input_tokens=224, + output_tokens=12, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRmhzqXFObpYwSzREMpJvX9kbDikR', + provider_request_id='chatcmpl-BRmhzqXFObpYwSzREMpJvX9kbDikR', ), ] ) @@ -487,22 +469,19 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age tool_call_id='call_qi5GtBeIEyT7Y3yJvVFIi062', ) ], - usage=Usage( - requests=1, - request_tokens=305, - response_tokens=12, - total_tokens=317, + usage=RequestUsage( + input_tokens=305, + output_tokens=12, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BwdHSFe0EykAOpf0LWZzsWAodIQzb', + provider_request_id='chatcmpl-BwdHSFe0EykAOpf0LWZzsWAodIQzb', ), ModelRequest( parts=[ @@ -516,22 +495,19 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age ), ModelResponse( parts=[TextPart(content='The product name is "Pydantic AI".')], - usage=Usage( - requests=1, - request_tokens=332, - response_tokens=11, - total_tokens=343, + usage=RequestUsage( + input_tokens=332, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BwdHTIlBZWzXJPBR8VTOdC4O57ZQA', + provider_request_id='chatcmpl-BwdHTIlBZWzXJPBR8VTOdC4O57ZQA', ), ] ) @@ -562,22 +538,19 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: tool_call_id='call_nFsDHYDZigO0rOHqmChZ3pmt', ) ], - usage=Usage( - requests=1, - request_tokens=191, - response_tokens=12, - total_tokens=203, + usage=RequestUsage( + input_tokens=191, + output_tokens=12, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 
- 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRlo7KYJVXuNZ5lLLdYcKZDsX2CHb', + provider_request_id='chatcmpl-BRlo7KYJVXuNZ5lLLdYcKZDsX2CHb', ), ModelRequest( parts=[ @@ -596,22 +569,19 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: content='This is an image of a sliced kiwi with a vibrant green interior and black seeds.' ) ], - usage=Usage( - requests=1, - request_tokens=1332, - response_tokens=19, - total_tokens=1351, + usage=RequestUsage( + input_tokens=1332, + output_tokens=19, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloBGHh27w3fQKwxq4fX2cPuZJa9', + provider_request_id='chatcmpl-BRloBGHh27w3fQKwxq4fX2cPuZJa9', ), ] ) @@ -644,22 +614,19 @@ async def test_tool_returning_image_resource_link( tool_call_id='call_eVFgn54V9Nuh8Y4zvuzkYjUp', ) ], - usage=Usage( - requests=1, - request_tokens=305, - response_tokens=12, - total_tokens=317, + usage=RequestUsage( + input_tokens=305, + output_tokens=12, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BwdHygYePH1mZgHo2Xxzib0Y7sId7', + provider_request_id='chatcmpl-BwdHygYePH1mZgHo2Xxzib0Y7sId7', ), ModelRequest( parts=[ @@ -678,22 +645,19 @@ async def test_tool_returning_image_resource_link( content='This is an image of a sliced kiwi fruit. It shows the green, seed-speckled interior with fuzzy brown skin around the edges.' ) ], - usage=Usage( - requests=1, - request_tokens=1452, - response_tokens=29, - total_tokens=1481, + usage=RequestUsage( + input_tokens=1452, + output_tokens=29, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BwdI2D2r9dvqq3pbsA0qgwKDEdTtD', + provider_request_id='chatcmpl-BwdI2D2r9dvqq3pbsA0qgwKDEdTtD', ), ] ) @@ -714,16 +678,12 @@ async def test_tool_returning_audio_resource( ), ModelResponse( parts=[ToolCallPart(tool_name='get_audio_resource', args={}, tool_call_id=IsStr())], - usage=Usage( - requests=1, - request_tokens=383, - response_tokens=12, - total_tokens=520, - details={'thoughts_tokens': 125, 'text_prompt_tokens': 383}, + usage=RequestUsage( + input_tokens=383, output_tokens=12, details={'thoughts_tokens': 125, 'text_prompt_tokens': 383} ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -738,16 +698,15 @@ async def test_tool_returning_audio_resource( ), ModelResponse( parts=[TextPart(content='The audio resource contains a voice saying "Hello, my name is Marcelo."')], - usage=Usage( - requests=1, - request_tokens=575, - response_tokens=15, - total_tokens=590, + usage=RequestUsage( + input_tokens=575, + output_tokens=15, + input_audio_tokens=144, details={'text_prompt_tokens': 431, 'audio_prompt_tokens': 144}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -778,16 +737,12 @@ async def test_tool_returning_audio_resource_link( ), 
ToolCallPart(tool_name='get_audio_resource_link', args={}, tool_call_id=IsStr()), ], - usage=Usage( - requests=1, - request_tokens=561, - response_tokens=41, - total_tokens=797, - details={'thoughts_tokens': 195, 'text_prompt_tokens': 561}, + usage=RequestUsage( + input_tokens=561, output_tokens=41, details={'thoughts_tokens': 195, 'text_prompt_tokens': 561} ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ModelRequest( parts=[ @@ -802,16 +757,15 @@ async def test_tool_returning_audio_resource_link( ), ModelResponse( parts=[TextPart(content='00:05')], - usage=Usage( - requests=1, - request_tokens=784, - response_tokens=5, - total_tokens=789, + usage=RequestUsage( + input_tokens=784, + output_tokens=5, + input_audio_tokens=144, details={'text_prompt_tokens': 640, 'audio_prompt_tokens': 144}, ), model_name='models/gemini-2.5-pro', timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, + provider_details={'finish_reason': 'STOP'}, ), ] ) @@ -840,22 +794,19 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im tool_call_id='call_Q7xG8CCG0dyevVfUS0ubsDdN', ) ], - usage=Usage( - requests=1, - request_tokens=190, - response_tokens=11, - total_tokens=201, + usage=RequestUsage( + input_tokens=190, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloGQJWIX0Qk7gtNzF4s2Fez0O29', + provider_request_id='chatcmpl-BRloGQJWIX0Qk7gtNzF4s2Fez0O29', ), ModelRequest( parts=[ @@ -876,22 +827,19 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im ), ModelResponse( parts=[TextPart(content='Here is an image of a sliced kiwi on a white background.')], - usage=Usage( - requests=1, - request_tokens=1329, - response_tokens=15, - total_tokens=1344, + usage=RequestUsage( + input_tokens=1329, + output_tokens=15, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloJHR654fSD0fcvLWZxtKtn0pag', + provider_request_id='chatcmpl-BRloJHR654fSD0fcvLWZxtKtn0pag', ), ] ) @@ -914,22 +862,19 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): ), ModelResponse( parts=[ToolCallPart(tool_name='get_dict', args='{}', tool_call_id='call_oqKviITBj8PwpQjGyUu4Zu5x')], - usage=Usage( - requests=1, - request_tokens=195, - response_tokens=11, - total_tokens=206, + usage=RequestUsage( + input_tokens=195, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, - 'cached_tokens': 0, }, ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BRloOs7Bb2tq8wJyy9Rv7SQ7L65a7', + provider_request_id='chatcmpl-BRloOs7Bb2tq8wJyy9Rv7SQ7L65a7', ), ModelRequest( parts=[ @@ -943,22 +888,19 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): ), ModelResponse( parts=[TextPart(content='{"foo":"bar","baz":123}')], - usage=Usage( - requests=1, - request_tokens=222, - response_tokens=11, - total_tokens=233, + usage=RequestUsage( + input_tokens=222, + output_tokens=11, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 
                         'rejected_prediction_tokens': 0,
-                        'cached_tokens': 0,
                     },
                 ),
                 model_name='gpt-4o-2024-08-06',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-BRloPczU1HSCWnreyo21DdNtdOM7L',
+                provider_request_id='chatcmpl-BRloPczU1HSCWnreyo21DdNtdOM7L',
             ),
         ]
     )
@@ -989,22 +931,19 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent):
                         tool_call_id='call_rETXZWddAGZSHyVHAxptPGgc',
                     )
                 ],
-                usage=Usage(
-                    requests=1,
-                    request_tokens=203,
-                    response_tokens=15,
-                    total_tokens=218,
+                usage=RequestUsage(
+                    input_tokens=203,
+                    output_tokens=15,
                     details={
                         'accepted_prediction_tokens': 0,
                         'audio_tokens': 0,
                         'reasoning_tokens': 0,
                         'rejected_prediction_tokens': 0,
-                        'cached_tokens': 0,
                     },
                 ),
                 model_name='gpt-4o-2024-08-06',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-BRloSNg7aGSp1rXDkhInjMIUHKd7A',
+                provider_request_id='chatcmpl-BRloSNg7aGSp1rXDkhInjMIUHKd7A',
             ),
             ModelRequest(
                 parts=[
@@ -1024,22 +963,19 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent):
                         tool_call_id='call_4xGyvdghYKHN8x19KWkRtA5N',
                     )
                 ],
-                usage=Usage(
-                    requests=1,
-                    request_tokens=250,
-                    response_tokens=15,
-                    total_tokens=265,
+                usage=RequestUsage(
+                    input_tokens=250,
+                    output_tokens=15,
                     details={
                         'accepted_prediction_tokens': 0,
                         'audio_tokens': 0,
                         'reasoning_tokens': 0,
                         'rejected_prediction_tokens': 0,
-                        'cached_tokens': 0,
                     },
                 ),
                 model_name='gpt-4o-2024-08-06',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-BRloTvSkFeX4DZKQLqfH9KbQkWlpt',
+                provider_request_id='chatcmpl-BRloTvSkFeX4DZKQLqfH9KbQkWlpt',
             ),
             ModelRequest(
                 parts=[
@@ -1057,22 +993,19 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent):
                         content='I called the tool with the correct parameter, and it returned: "This is not an error."'
                     )
                 ],
-                usage=Usage(
-                    requests=1,
-                    request_tokens=277,
-                    response_tokens=22,
-                    total_tokens=299,
+                usage=RequestUsage(
+                    input_tokens=277,
+                    output_tokens=22,
                     details={
                         'accepted_prediction_tokens': 0,
                         'audio_tokens': 0,
                         'reasoning_tokens': 0,
                         'rejected_prediction_tokens': 0,
-                        'cached_tokens': 0,
                     },
                 ),
                 model_name='gpt-4o-2024-08-06',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-BRloU3MhnqNEqujs28a3ofRbs7VPF',
+                provider_request_id='chatcmpl-BRloU3MhnqNEqujs28a3ofRbs7VPF',
             ),
         ]
     )
@@ -1095,22 +1028,19 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent):
             ),
             ModelResponse(
                 parts=[ToolCallPart(tool_name='get_none', args='{}', tool_call_id='call_mJTuQ2Cl5SaHPTJbIILEUhJC')],
-                usage=Usage(
-                    requests=1,
-                    request_tokens=193,
-                    response_tokens=11,
-                    total_tokens=204,
+                usage=RequestUsage(
+                    input_tokens=193,
+                    output_tokens=11,
                     details={
                         'accepted_prediction_tokens': 0,
                         'audio_tokens': 0,
                         'reasoning_tokens': 0,
                         'rejected_prediction_tokens': 0,
-                        'cached_tokens': 0,
                     },
                 ),
                 model_name='gpt-4o-2024-08-06',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-BRloX2RokWc9j9PAXAuNXGR73WNqY',
+                provider_request_id='chatcmpl-BRloX2RokWc9j9PAXAuNXGR73WNqY',
             ),
             ModelRequest(
                 parts=[
@@ -1124,22 +1054,19 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent):
             ),
             ModelResponse(
                 parts=[TextPart(content='Hello! How can I assist you today?')],
-                usage=Usage(
-                    requests=1,
-                    request_tokens=212,
-                    response_tokens=11,
-                    total_tokens=223,
+                usage=RequestUsage(
+                    input_tokens=212,
+                    output_tokens=11,
                     details={
                         'accepted_prediction_tokens': 0,
                         'audio_tokens': 0,
                         'reasoning_tokens': 0,
                         'rejected_prediction_tokens': 0,
-                        'cached_tokens': 0,
                     },
                 ),
                 model_name='gpt-4o-2024-08-06',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-BRloYWGujk8yE94gfVSsM1T1Ol2Ej',
+                provider_request_id='chatcmpl-BRloYWGujk8yE94gfVSsM1T1Ol2Ej',
             ),
         ]
     )
@@ -1170,22 +1097,19 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent:
                         tool_call_id='call_kL0TvjEVQBDGZrn1Zv7iNYOW',
                     )
                 ],
-                usage=Usage(
-                    requests=1,
-                    request_tokens=195,
-                    response_tokens=12,
-                    total_tokens=207,
+                usage=RequestUsage(
+                    input_tokens=195,
+                    output_tokens=12,
                     details={
                         'accepted_prediction_tokens': 0,
                         'audio_tokens': 0,
                         'reasoning_tokens': 0,
                         'rejected_prediction_tokens': 0,
-                        'cached_tokens': 0,
                     },
                 ),
                 model_name='gpt-4o-2024-08-06',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-BRlobKLgm6vf79c9O8sloZaYx3coC',
+                provider_request_id='chatcmpl-BRlobKLgm6vf79c9O8sloZaYx3coC',
             ),
             ModelRequest(
                 parts=[
@@ -1215,22 +1139,19 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent:
                         content='The data includes two strings, a dictionary with a key-value pair, and an image of a sliced kiwi.'
                     )
                 ],
-                usage=Usage(
-                    requests=1,
-                    request_tokens=1355,
-                    response_tokens=24,
-                    total_tokens=1379,
+                usage=RequestUsage(
+                    input_tokens=1355,
+                    output_tokens=24,
                     details={
                         'accepted_prediction_tokens': 0,
                         'audio_tokens': 0,
                         'reasoning_tokens': 0,
                         'rejected_prediction_tokens': 0,
-                        'cached_tokens': 0,
                     },
                 ),
                 model_name='gpt-4o-2024-08-06',
                 timestamp=IsDatetime(),
-                vendor_id='chatcmpl-BRloepWR5NJpTgSqFBGTSPeM1SWm8',
+                provider_request_id='chatcmpl-BRloepWR5NJpTgSqFBGTSPeM1SWm8',
             ),
         ]
     )
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index bd483d068b..139937a030 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -36,8 +36,9 @@ from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.output import DeferredToolCalls, PromptedOutput, TextOutput
-from pydantic_ai.result import AgentStream, FinalResult, Usage
+from pydantic_ai.result import AgentStream, FinalResult, RunUsage
 from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.usage import RequestUsage
 from pydantic_graph import End

 from .conftest import IsInt, IsNow, IsStr
@@ -63,7 +64,7 @@ async def ret_a(x: str) -> str:
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
-                usage=Usage(request_tokens=51, response_tokens=0, total_tokens=51),
+                usage=RequestUsage(input_tokens=51),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -77,11 +78,10 @@ async def ret_a(x: str) -> str:
         ]
     )
     assert result.usage() == snapshot(
-        Usage(
+        RunUsage(
             requests=2,
-            request_tokens=103,
-            response_tokens=5,
-            total_tokens=108,
+            input_tokens=103,
+            output_tokens=5,
         )
     )
     response = await result.get_output()
@@ -93,7 +93,7 @@ async def ret_a(x: str) -> str:
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
-                usage=Usage(request_tokens=51, response_tokens=0, total_tokens=51),
+                usage=RequestUsage(input_tokens=51),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -106,18 +106,17 @@ async def ret_a(x: str) -> str:
             ),
             ModelResponse(
                 parts=[TextPart(content='{"ret_a":"a-apple"}')],
-                usage=Usage(request_tokens=52, response_tokens=11, total_tokens=63),
+                usage=RequestUsage(input_tokens=52, output_tokens=11),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
         ]
     )
     assert result.usage() == snapshot(
-        Usage(
+        RunUsage(
             requests=2,
-            request_tokens=103,
-            response_tokens=11,
-            total_tokens=114,
+            input_tokens=103,
+            output_tokens=11,
         )
     )
@@ -224,43 +223,43 @@ def upcase(text: str) -> str:
         [
             ModelResponse(
                 parts=[TextPart(content='The ')],
-                usage=Usage(request_tokens=51, response_tokens=1, total_tokens=52),
+                usage=RequestUsage(input_tokens=51, output_tokens=1),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelResponse(
                 parts=[TextPart(content='The cat ')],
-                usage=Usage(request_tokens=51, response_tokens=2, total_tokens=53),
+                usage=RequestUsage(input_tokens=51, output_tokens=2),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelResponse(
                 parts=[TextPart(content='The cat sat ')],
-                usage=Usage(request_tokens=51, response_tokens=3, total_tokens=54),
+                usage=RequestUsage(input_tokens=51, output_tokens=3),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelResponse(
                 parts=[TextPart(content='The cat sat on ')],
-                usage=Usage(request_tokens=51, response_tokens=4, total_tokens=55),
+                usage=RequestUsage(input_tokens=51, output_tokens=4),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelResponse(
                 parts=[TextPart(content='The cat sat on the ')],
-                usage=Usage(request_tokens=51, response_tokens=5, total_tokens=56),
+                usage=RequestUsage(input_tokens=51, output_tokens=5),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelResponse(
                 parts=[TextPart(content='The cat sat on the mat.')],
-                usage=Usage(request_tokens=51, response_tokens=7, total_tokens=58),
+                usage=RequestUsage(input_tokens=51, output_tokens=7),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelResponse(
                 parts=[TextPart(content='The cat sat on the mat.')],
-                usage=Usage(request_tokens=51, response_tokens=7, total_tokens=58),
+                usage=RequestUsage(input_tokens=51, output_tokens=7),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -327,7 +326,7 @@ async def ret_a(x: str) -> str:
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[ToolCallPart(tool_name='ret_a', args='{"x": "hello"}', tool_call_id=IsStr())],
-                usage=Usage(request_tokens=50, response_tokens=5, total_tokens=55),
+                usage=RequestUsage(input_tokens=50, output_tokens=5),
                 model_name='function::stream_structured_function',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -349,7 +348,7 @@ async def ret_a(x: str) -> str:
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[ToolCallPart(tool_name='ret_a', args='{"x": "hello"}', tool_call_id=IsStr())],
-                usage=Usage(request_tokens=50, response_tokens=5, total_tokens=55),
+                usage=RequestUsage(input_tokens=50, output_tokens=5),
                 model_name='function::stream_structured_function',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -371,7 +370,7 @@ async def ret_a(x: str) -> str:
                     tool_call_id=IsStr(),
                 )
             ],
-            usage=Usage(request_tokens=50, response_tokens=7, total_tokens=57),
+            usage=RequestUsage(input_tokens=50, output_tokens=7),
             model_name='function::stream_structured_function',
             timestamp=IsNow(tz=timezone.utc),
         ),
@@ -424,7 +423,7 @@ async def ret_a(x: str) -> str:  # pragma: no cover
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
-                usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51),
+                usage=RequestUsage(input_tokens=50, output_tokens=1),
                 model_name='function::stream_structured_function',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -480,7 +479,7 @@ def another_tool(y: int) -> int:  # pragma: no cover
                 ToolCallPart(tool_name='regular_tool', args='{"x": 1}', tool_call_id=IsStr()),
                 ToolCallPart(tool_name='another_tool', args='{"y": 2}', tool_call_id=IsStr()),
             ],
-            usage=Usage(request_tokens=50, response_tokens=10, total_tokens=60),
+            usage=RequestUsage(input_tokens=50, output_tokens=10),
             model_name='function::sf',
             timestamp=IsNow(tz=timezone.utc),
         ),
@@ -536,7 +535,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
                 ToolCallPart(tool_name='final_result', args='{"value": "first"}', tool_call_id=IsStr()),
                 ToolCallPart(tool_name='final_result', args='{"value": "second"}', tool_call_id=IsStr()),
             ],
-            usage=Usage(request_tokens=50, response_tokens=8, total_tokens=58),
+            usage=RequestUsage(input_tokens=50, output_tokens=8),
             model_name='function::sf',
             timestamp=IsNow(tz=timezone.utc),
         ),
@@ -603,7 +602,7 @@ def another_tool(y: int) -> int:
                 ToolCallPart(tool_name='final_result', args='{"value": "second"}', tool_call_id=IsStr()),
                 ToolCallPart(tool_name='unknown_tool', args='{"value": "???"}', tool_call_id=IsStr()),
             ],
-            usage=Usage(request_tokens=50, response_tokens=18, total_tokens=68),
+            usage=RequestUsage(input_tokens=50, output_tokens=18),
             model_name='function::sf',
             timestamp=IsNow(tz=timezone.utc),
         ),
@@ -712,7 +711,7 @@ def another_tool(y: int) -> int:  # pragma: no cover
                     part_kind='tool-call',
                 ),
             ],
-            usage=Usage(request_tokens=50, response_tokens=14, total_tokens=64),
+            usage=RequestUsage(input_tokens=50, output_tokens=14),
             model_name='function::sf',
             timestamp=IsNow(tz=datetime.timezone.utc),
             kind='response',
@@ -783,7 +782,7 @@ def regular_tool(x: int) -> int:
             ),
             ModelResponse(
                 parts=[ToolCallPart(tool_name='regular_tool', args={'x': 0}, tool_call_id=IsStr())],
-                usage=Usage(request_tokens=57, response_tokens=0, total_tokens=57),
+                usage=RequestUsage(input_tokens=57),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -796,7 +795,7 @@ def regular_tool(x: int) -> int:
             ),
             ModelResponse(
                 parts=[ToolCallPart(tool_name='final_result', args={'value': 'a'}, tool_call_id=IsStr())],
-                usage=Usage(request_tokens=58, response_tokens=4, total_tokens=62),
+                usage=RequestUsage(input_tokens=58, output_tokens=4),
                 model_name='test',
                 timestamp=IsNow(tz=timezone.utc),
             ),
@@ -852,7 +851,7 @@ def output_validator_simple(data: str) -> str:
     stream: AgentStream
     messages: list[str] = []
-    stream_usage: Usage | None = None
+    stream_usage: RunUsage | None = None
     async with agent.iter('Hello') as run:
         async for node in run:
             if agent.is_model_request_node(node):
@@ -861,11 +860,7 @@ def output_validator_simple(data: str) -> str:
                         messages.append(chunk)
                 stream_usage = deepcopy(stream.usage())
     assert run.next_node == End(data=FinalResult(output='The bat sat on the mat.', tool_name=None, tool_call_id=None))
-    assert (
-        run.usage()
-        == stream_usage
-        == Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58, details=None)
-    )
+    assert run.usage() == stream_usage == RunUsage(requests=1, input_tokens=51, output_tokens=7)

     assert messages == [
         '',
@@ -902,9 +897,7 @@ def output_validator_simple(data: str) -> str:
     assert messages == [
         ModelResponse(
             parts=[TextPart(content=text, part_kind='text')],
-            usage=Usage(
-                requests=0, request_tokens=IsInt(), response_tokens=IsInt(), total_tokens=IsInt(), details=None
-            ),
+            usage=RequestUsage(input_tokens=IsInt(), output_tokens=IsInt()),
             model_name='test',
             timestamp=IsNow(tz=timezone.utc),
             kind='response',
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 730f535ea0..8e1c3361b5 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -31,7 +31,7 @@ from pydantic_ai.toolsets.deferred import DeferredToolset
 from pydantic_ai.toolsets.function import FunctionToolset
 from pydantic_ai.toolsets.prefixed import PrefixedToolset
-from pydantic_ai.usage import Usage
+from pydantic_ai.usage import RequestUsage

 from .conftest import IsDatetime, IsStr
@@ -1385,7 +1385,7 @@ def get_price(fruit: str) -> ToolReturn:
                     tool_call_id=IsStr(),
                 ),
             ],
-            usage=Usage(requests=1, request_tokens=58, response_tokens=10, total_tokens=68),
+            usage=RequestUsage(input_tokens=58, output_tokens=10),
             model_name='function:llm:',
             timestamp=IsDatetime(),
         ),
@@ -1417,7 +1417,7 @@ def get_price(fruit: str) -> ToolReturn:
             ),
             ModelResponse(
                 parts=[TextPart(content='Done!')],
-                usage=Usage(requests=1, request_tokens=76, response_tokens=11, total_tokens=87),
+                usage=RequestUsage(input_tokens=76, output_tokens=11),
                 model_name='function:llm:',
                 timestamp=IsDatetime(),
             ),
diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py
index 81f6863bf5..7ee3282041 100644
--- a/tests/test_toolsets.py
+++ b/tests/test_toolsets.py
@@ -24,7 +24,7 @@ from pydantic_ai.toolsets.prefixed import PrefixedToolset
 from pydantic_ai.toolsets.prepared import PreparedToolset
 from pydantic_ai.toolsets.wrapper import WrapperToolset
-from pydantic_ai.usage import Usage
+from pydantic_ai.usage import RunUsage

 pytestmark = pytest.mark.anyio
@@ -35,7 +35,7 @@ def build_run_context(deps: T, run_step: int = 0) -> RunContext[T]:
     return RunContext(
         deps=deps,
         model=TestModel(),
-        usage=Usage(),
+        usage=RunUsage(),
         prompt=None,
         messages=[],
         run_step=run_step,
diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py
index 95c292e9d4..ef86989923 100644
--- a/tests/test_usage_limits.py
+++ b/tests/test_usage_limits.py
@@ -7,6 +7,7 @@ import pytest
 from genai_prices import Usage as GenaiPricesUsage, calc_price
 from inline_snapshot import snapshot
+from inline_snapshot.extra import warns

 from pydantic_ai import Agent, RunContext, UsageLimitExceeded
 from pydantic_ai.messages import (
@@ -17,7 +18,7 @@ UserPromptPart,
 )
 from pydantic_ai.models.test import TestModel
-from pydantic_ai.usage import Usage, UsageLimits
+from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits

 from .conftest import IsNow, IsStr
@@ -32,11 +33,9 @@ def test_genai_prices():


 def test_request_token_limit() -> None:
     test_agent = Agent(TestModel())

-    with pytest.raises(
-        UsageLimitExceeded, match=re.escape('Exceeded the request_tokens_limit of 5 (request_tokens=59)')
-    ):
+    with pytest.raises(UsageLimitExceeded, match=re.escape('Exceeded the input_tokens_limit of 5 (input_tokens=59)')):
         test_agent.run_sync(
-            'Hello, this prompt exceeds the request tokens limit.', usage_limits=UsageLimits(request_tokens_limit=5)
+            'Hello, this prompt exceeds the request tokens limit.', usage_limits=UsageLimits(input_tokens_limit=5)
         )
@@ -45,10 +44,8 @@ def test_response_token_limit() -> None:
         TestModel(custom_output_text='Unfortunately, this response exceeds the response tokens limit by a few!')
     )

-    with pytest.raises(
-        UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 5 (response_tokens=11)')
-    ):
-        test_agent.run_sync('Hello', usage_limits=UsageLimits(response_tokens_limit=5))
+    with pytest.raises(UsageLimitExceeded, match=re.escape('Exceeded the output_tokens_limit of 5 (output_tokens=11)')):
+        test_agent.run_sync('Hello', usage_limits=UsageLimits(output_tokens_limit=5))


 def test_total_token_limit() -> None:
@@ -86,9 +83,9 @@ async def ret_a(x: str) -> str:
     succeeded = False

     with pytest.raises(
-        UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)')
+        UsageLimitExceeded, match=re.escape('Exceeded the output_tokens_limit of 10 (output_tokens=11)')
     ):
-        async with test_agent.run_stream('Hello', usage_limits=UsageLimits(response_tokens_limit=10)) as result:
+        async with test_agent.run_stream('Hello', usage_limits=UsageLimits(output_tokens_limit=10)) as result:
             assert test_agent.name == 'test_agent'
             assert not result.is_complete
             assert result.all_messages() == snapshot(
@@ -102,11 +99,7 @@ async def ret_a(x: str) -> str:
                             tool_call_id=IsStr(),
                         )
                     ],
-                    usage=Usage(
-                        request_tokens=51,
-                        response_tokens=0,
-                        total_tokens=51,
-                    ),
+                    usage=RequestUsage(input_tokens=51),
                     model_name='test',
                     timestamp=IsNow(tz=timezone.utc),
                 ),
@@ -123,11 +116,10 @@ async def ret_a(x: str) -> str:
                 ]
             )
             assert result.usage() == snapshot(
-                Usage(
+                RunUsage(
                     requests=2,
-                    request_tokens=103,
-                    response_tokens=5,
-                    total_tokens=108,
+                    input_tokens=103,
+                    output_tokens=5,
                 )
             )
             succeeded = True
@@ -144,7 +136,7 @@ def test_usage_so_far() -> None:
         test_agent.run_sync(
             'Hello, this prompt exceeds the request tokens limit.',
             usage_limits=UsageLimits(total_tokens_limit=105),
-            usage=Usage(total_tokens=100),
+            usage=RunUsage(input_tokens=50, output_tokens=50),
         )
@@ -152,20 +144,20 @@ async def test_multi_agent_usage_no_incr():
     delegate_agent = Agent(TestModel(), output_type=int)
     controller_agent1 = Agent(TestModel())
-    run_1_usages: list[Usage] = []
+    run_1_usages: list[RunUsage] = []

     @controller_agent1.tool
     async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int:
         delegate_result = await delegate_agent.run(sentence)
         delegate_usage = delegate_result.usage()
         run_1_usages.append(delegate_usage)
-        assert delegate_usage == snapshot(Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55))
+        assert delegate_usage == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=4))
         return delegate_result.output

     result1 = await controller_agent1.run('foobar')
     assert result1.output == snapshot('{"delegate_to_other_agent1":0}')
     run_1_usages.append(result1.usage())
-    assert result1.usage() == snapshot(Usage(requests=2, request_tokens=103, response_tokens=13, total_tokens=116))
+    assert result1.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=13))

     controller_agent2 = Agent(TestModel())
@@ -173,12 +165,12 @@ async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int:
     async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int:
         delegate_result = await delegate_agent.run(sentence, usage=ctx.usage)
         delegate_usage = delegate_result.usage()
-        assert delegate_usage == snapshot(Usage(requests=2, request_tokens=102, response_tokens=9, total_tokens=111))
+        assert delegate_usage == snapshot(RunUsage(requests=2, input_tokens=102, output_tokens=9))
         return delegate_result.output

     result2 = await controller_agent2.run('foobar')
     assert result2.output == snapshot('{"delegate_to_other_agent2":0}')
-    assert result2.usage() == snapshot(Usage(requests=3, request_tokens=154, response_tokens=17, total_tokens=171))
+    assert result2.usage() == snapshot(RunUsage(requests=3, input_tokens=154, output_tokens=17))

     # confirm the usage from result2 is the sum of the usage from result1
     assert result2.usage() == functools.reduce(operator.add, run_1_usages)
@@ -199,10 +191,58 @@ async def test_multi_agent_usage_sync():

     @controller_agent.tool
     def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
-        new_usage = Usage(requests=5, request_tokens=2, response_tokens=3, total_tokens=4)
+        new_usage = RunUsage(requests=5, input_tokens=2, output_tokens=3)
         ctx.usage.incr(new_usage)
         return 0

     result = await controller_agent.run('foobar')
     assert result.output == snapshot('{"delegate_to_other_agent":0}')
-    assert result.usage() == snapshot(Usage(requests=7, request_tokens=105, response_tokens=16, total_tokens=120))
+    assert result.usage() == snapshot(RunUsage(requests=7, input_tokens=105, output_tokens=16))
+
+
+def test_request_usage_basics():
+    usage = RequestUsage()
+    assert usage.output_audio_tokens == 0
+    assert usage.requests == 1
+
+
+def test_add_usages():
+    usage = RunUsage(
+        requests=2,
+        input_tokens=10,
+        output_tokens=20,
+        cache_read_tokens=30,
+        cache_write_tokens=40,
+        input_audio_tokens=50,
+        cache_audio_read_tokens=60,
+        details={
+            'custom1': 10,
+            'custom2': 20,
+        },
+    )
+    assert usage + usage == snapshot(
+        RunUsage(
+            requests=4,
+            input_tokens=20,
+            output_tokens=40,
+            cache_write_tokens=80,
+            cache_read_tokens=60,
+            input_audio_tokens=100,
+            cache_audio_read_tokens=120,
+            details={'custom1': 20, 'custom2': 40},
+        )
+    )
+    assert usage + RunUsage() == usage
+    assert RunUsage() + RunUsage() == RunUsage()
+
+
+def test_deprecated_usage_limits():
+    with warns(
+        snapshot(['DeprecationWarning: `request_tokens_limit` is deprecated, use `input_tokens_limit` instead'])
+    ):
+        assert UsageLimits(input_tokens_limit=100).request_tokens_limit == 100  # type: ignore
+
+    with warns(
+        snapshot(['DeprecationWarning: `response_tokens_limit` is deprecated, use `output_tokens_limit` instead'])
+    ):
+        assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100  # type: ignore
diff --git a/uv.lock b/uv.lock
index e8c1988de8..d8c22f8b69 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3466,6 +3466,7 @@ source = { editable = "pydantic_ai_slim" }
 dependencies = [
     { name = "eval-type-backport" },
     { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
+    { name = "genai-prices" },
     { name = "griffe" },
     { name = "httpx" },
     { name = "opentelemetry-api" },
@@ -3548,6 +3549,7 @@ requires-dist = [
     { name = "eval-type-backport", specifier = ">=0.2.0" },
     { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
     { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" },
+    { name = "genai-prices", specifier = ">=0.0.22" },
     { name = "google-auth", marker = "extra == 'vertexai'", specifier = ">=2.36.0" },
     { name = "google-genai", marker = "extra == 'google'", specifier = ">=1.28.0" },
     { name = "griffe", specifier = ">=1.3.2" },