diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py
index 1383f58c66..38c92f71eb 100644
--- a/pydantic_ai_slim/pydantic_ai/direct.py
+++ b/pydantic_ai_slim/pydantic_ai/direct.py
@@ -81,7 +81,7 @@ async def main():
     return await model_instance.request(
         messages,
         model_settings,
-        model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()),
+        model_request_parameters or models.ModelRequestParameters(),
     )


@@ -193,7 +193,7 @@ async def main():
     return model_instance.request_stream(
         messages,
         model_settings,
-        model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()),
+        model_request_parameters or models.ModelRequestParameters(),
     )


diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index c5a0d26724..508c88554b 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -41,7 +41,7 @@
 )
 from ..output import OutputMode
 from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
-from ..settings import ModelSettings
+from ..settings import ModelSettings, merge_model_settings
 from ..tools import ToolDefinition
 from ..usage import RequestUsage

@@ -390,6 +390,23 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar

         return model_request_parameters

+    def prepare_request(
+        self,
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
+        """Prepare request inputs before they are passed to the provider.
+
+        This merges the given ``model_settings`` with the model's own ``settings`` attribute and applies
+        ``customize_request_parameters`` to the resolved
+        [`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters].
+        Subclasses can override this method to customize the preparation flow further, but most
+        implementations should simply call ``self.prepare_request(...)`` at the start of their
+        ``request`` (and related) methods.
+ """ + merged_settings = merge_model_settings(self.settings, model_settings) + customized_parameters = self.customize_request_parameters(model_request_parameters) + return merged_settings, customized_parameters + @property @abstractmethod def model_name(self) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 26cd3af956..6c6c99210e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -205,6 +205,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._messages_create( messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) @@ -220,6 +224,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._messages_create( messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 40da9d262a..2e4a65c6ba 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -264,6 +264,10 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, False, settings, model_request_parameters) model_response = await self._process_response(response) @@ -277,6 +281,10 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) yield BedrockStreamedResponse( diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 18f3b5e084..1a08db4348 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -165,6 +165,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters) model_response = self._process_response(response) return model_response diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index cf24f7002f..c8430f5775 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -11,7 +11,6 @@ from pydantic_ai.models.instrumented import InstrumentedModel 

 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
-from ..settings import merge_model_settings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

 if TYPE_CHECKING:
@@ -78,10 +77,8 @@ async def request(
         exceptions: list[Exception] = []

         for model in self.models:
-            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
-            merged_settings = merge_model_settings(model.settings, model_settings)
             try:
-                response = await model.request(messages, merged_settings, customized_model_request_parameters)
+                response = await model.request(messages, model_settings, model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
                     exceptions.append(exc)
@@ -105,14 +102,10 @@ async def request_stream(
         exceptions: list[Exception] = []

         for model in self.models:
-            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
-            merged_settings = merge_model_settings(model.settings, model_settings)
             async with AsyncExitStack() as stack:
                 try:
                     response = await stack.enter_async_context(
-                        model.request_stream(
-                            messages, merged_settings, customized_model_request_parameters, run_context
-                        )
+                        model.request_stream(messages, model_settings, model_request_parameters, run_context)
                     )
                 except Exception as exc:
                     if self._fallback_on(exc):
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index e9cc927735..c0c0d7d62f 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -125,6 +125,10 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         agent_info = AgentInfo(
             function_tools=model_request_parameters.function_tools,
             allow_text_output=model_request_parameters.allow_text_output,
@@ -154,6 +158,10 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         agent_info = AgentInfo(
             function_tools=model_request_parameters.function_tools,
             allow_text_output=model_request_parameters.allow_text_output,
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 5f17800a78..fc0a2973be 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -155,6 +155,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         async with self._make_request(
             messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
@@ -171,6 +175,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py
index 6e05675fa7..4e01ca18f8 100644
--- a/pydantic_ai_slim/pydantic_ai/models/google.py
+++ b/pydantic_ai_slim/pydantic_ai/models/google.py
@@ -225,6 +225,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)
@@ -236,6 +240,10 @@ async def count_tokens(
         model_request_parameters: ModelRequestParameters,
     ) -> usage.RequestUsage:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
             messages, model_settings, model_request_parameters
@@ -291,6 +299,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore
diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py
index b05913d2d6..81907d3944 100644
--- a/pydantic_ai_slim/pydantic_ai/models/groq.py
+++ b/pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -182,6 +182,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         try:
             response = await self._completions_create(
                 messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
@@ -218,6 +222,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py
index 23473fb7d6..ed5ecb83be 100644
--- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py
+++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -166,6 +166,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
@@ -181,6 +185,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py
index 57dc235da6..1ebc7e0dba 100644
--- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py
+++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py
@@ -352,8 +352,12 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
-        with self._instrument(messages, model_settings, model_request_parameters) as finish:
-            response = await super().request(messages, model_settings, model_request_parameters)
+        prepared_settings, prepared_parameters = self.wrapped.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
+        with self._instrument(messages, prepared_settings, prepared_parameters) as finish:
+            response = await self.wrapped.request(messages, model_settings, model_request_parameters)
             finish(response)
             return response

@@ -365,10 +369,14 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
-        with self._instrument(messages, model_settings, model_request_parameters) as finish:
+        prepared_settings, prepared_parameters = self.wrapped.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
+        with self._instrument(messages, prepared_settings, prepared_parameters) as finish:
             response_stream: StreamedResponse | None = None
             try:
-                async with super().request_stream(
+                async with self.wrapped.request_stream(
                     messages, model_settings, model_request_parameters, run_context
                 ) as response_stream:
                     yield response_stream
diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py
index d1f3ade1e0..7b54ba0f6d 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py
@@ -52,6 +52,8 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         system_prompt, sampling_messages = _mcp.map_from_pai_messages(messages)
+
+        model_settings, _ = self.prepare_request(model_settings, model_request_parameters)
         model_settings = cast(MCPSamplingModelSettings, model_settings or {})

         result = await self.session.create_message(
diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index 0c749f3c60..dec09c0b1c 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -185,6 +185,10 @@ async def request(
     ) -> ModelResponse:
         """Make a non-streaming request to the model from Pydantic AI call."""
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
@@ -201,6 +205,10 @@ async def request_stream(
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._stream_completions_create(
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 843716fcef..3ad09cb579 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -393,6 +393,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
@@ -408,6 +412,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, True, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
@@ -926,6 +934,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._responses_create(
             messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
         )
@@ -940,6 +952,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._responses_create(
             messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
         )
diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py
index bb374c78a3..7a01e5324c 100644
--- a/pydantic_ai_slim/pydantic_ai/models/test.py
+++ b/pydantic_ai_slim/pydantic_ai/models/test.py
@@ -110,6 +110,10 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         self.last_model_request_parameters = model_request_parameters
         model_response = self._request(messages, model_settings, model_request_parameters)
         model_response.usage = _estimate_usage([*messages, model_response])
@@ -123,6 +127,10 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         self.last_model_request_parameters = model_request_parameters
         model_response = self._request(messages, model_settings, model_request_parameters)

diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py
index 4c91991cc1..3260cc7d65 100644
--- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py
+++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py
@@ -46,6 +46,13 @@ async def request_stream(
     def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
         return self.wrapped.customize_request_parameters(model_request_parameters)

+    def prepare_request(
+        self,
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
+        return self.wrapped.prepare_request(model_settings, model_request_parameters)
+
     @property
     def model_name(self) -> str:
         return self.wrapped.model_name
diff --git a/tests/models/test_model_settings.py b/tests/models/test_model_settings.py
index 4db59e741d..3bcbedbafd 100644
--- a/tests/models/test_model_settings.py
+++ b/tests/models/test_model_settings.py
@@ -2,7 +2,12 @@

 from __future__ import annotations

-from pydantic_ai import Agent, ModelMessage, ModelResponse, TextPart
+import asyncio
+
+from pydantic_ai import Agent
+from pydantic_ai.direct import model_request as direct_model_request
+from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart
+from pydantic_ai.models import ModelRequestParameters
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.models.instrumented import InstrumentedModel
 from pydantic_ai.models.test import TestModel
@@ -161,3 +166,39 @@ def capture_settings(messages: list[ModelMessage], agent_info: AgentInfo) -> Mod
     assert captured_settings is not None
     assert captured_settings.get('temperature') == 0.75
     assert len(captured_settings) == 1  # Only one setting should be present
+
+
+def test_direct_model_request_merges_model_settings():
+    """Ensure direct requests merge model defaults with provided run settings."""
+
+    captured_settings = None
+
+    async def capture(messages: list[ModelMessage], agent_info: AgentInfo) -> ModelResponse:
+        nonlocal captured_settings
+        captured_settings = agent_info.model_settings
+        return ModelResponse(parts=[TextPart('ok')])
+
+    model = FunctionModel(
+        capture,
+        settings=ModelSettings(max_tokens=50, temperature=0.3),
+    )
+
+    messages: list[ModelMessage] = [ModelRequest.user_text_prompt('hi')]
+    run_settings = ModelSettings(temperature=0.9, top_p=0.2)
+
+    async def _run() -> ModelResponse:
+        return await direct_model_request(
+            model,
+            messages,
+            model_settings=run_settings,
+            model_request_parameters=ModelRequestParameters(),
+        )
+
+    response = asyncio.run(_run())
+
+    assert response.parts == [TextPart('ok')]
+    assert captured_settings == {
+        'max_tokens': 50,
+        'temperature': 0.9,
+        'top_p': 0.2,
+    }
diff --git a/tests/test_direct.py b/tests/test_direct.py
index b1149e7239..a26c18c0b4 100644
--- a/tests/test_direct.py
+++ b/tests/test_direct.py
@@ -7,8 +7,16 @@
 import pytest
 from inline_snapshot import snapshot

-from pydantic_ai import (
-    Agent,
+from pydantic_ai import Agent
+from pydantic_ai.direct import (
+    StreamedResponseSync,
+    _prepare_model,  # pyright: ignore[reportPrivateUsage]
+    model_request,
+    model_request_stream,
+    model_request_stream_sync,
+    model_request_sync,
+)
+from pydantic_ai.messages import (
     FinalResultEvent,
     ModelMessage,
     ModelRequest,
@@ -19,14 +27,6 @@
     TextPartDelta,
     ToolCallPart,
 )
-from pydantic_ai.direct import (
-    StreamedResponseSync,
-    _prepare_model,  # pyright: ignore[reportPrivateUsage]
-    model_request,
-    model_request_stream,
-    model_request_stream_sync,
-    model_request_sync,
-)
 from pydantic_ai.models import ModelRequestParameters
 from pydantic_ai.models.instrumented import InstrumentedModel
 from pydantic_ai.models.test import TestModel
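
For reviewers, a minimal sketch of the behavior this change enables, mirroring the new `test_direct_model_request_merges_model_settings` test above (the `echo_settings` helper is illustrative, not part of the diff): direct requests now merge the model's own `settings` with the per-request `model_settings`, with per-request values winning on conflict.

```python
from pydantic_ai.direct import model_request_sync
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.settings import ModelSettings


def echo_settings(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # By the time the model function runs, Model.prepare_request has already
    # merged the model-level defaults with the per-request settings.
    return ModelResponse(parts=[TextPart(str(info.model_settings))])


# Model-level defaults, attached to the model instance itself.
model = FunctionModel(echo_settings, settings=ModelSettings(max_tokens=50, temperature=0.3))

response = model_request_sync(
    model,
    [ModelRequest.user_text_prompt('hi')],
    # Per-request settings: temperature overrides the default, top_p is new.
    model_settings=ModelSettings(temperature=0.9, top_p=0.2),
)
print(response.parts[0].content)
#> {'max_tokens': 50, 'temperature': 0.9, 'top_p': 0.2}
```

Previously, `direct.model_request` passed `model_settings` straight through to the provider, so a model-level default like `max_tokens=50` was silently ignored on direct requests.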
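And a sketch of the pattern every provider in this diff now follows, for anyone implementing a custom model. `EchoModel` is hypothetical and reduced to the bare minimum, assuming the abstract `Model` interface requires only `request`, `model_name`, and `system`:

```python
from __future__ import annotations

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings


class EchoModel(Model):
    """Hypothetical custom model showing where prepare_request belongs."""

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        # First thing in request(): merge settings and customize parameters,
        # as the docstring on Model.prepare_request recommends.
        model_settings, model_request_parameters = self.prepare_request(
            model_settings,
            model_request_parameters,
        )
        return ModelResponse(parts=[TextPart('echo')])

    @property
    def model_name(self) -> str:
        return 'echo'

    @property
    def system(self) -> str:
        return 'echo'
```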