From 6c32c832307d25e105e80bc7c519b3a017bd42ed Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Sun, 21 Sep 2025 08:25:08 +0000 Subject: [PATCH 1/9] init --- pydantic_ai_slim/pydantic_ai/direct.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 1383f58c66..313933cd17 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -78,9 +78,15 @@ async def main(): The model response and token usage associated with the request. """ model_instance = _prepare_model(model, instrument) + + merged_model_settings = settings.merge_model_settings( + base=model_instance.settings, + overrides=model_settings, + ) + return await model_instance.request( messages, - model_settings, + merged_model_settings, model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), ) @@ -190,9 +196,15 @@ async def main(): A [stream response][pydantic_ai.models.StreamedResponse] async context manager. 
""" model_instance = _prepare_model(model, instrument) + + merged_model_settings = settings.merge_model_settings( + base=model_instance.settings, + overrides=model_settings, + ) + return model_instance.request_stream( messages, - model_settings, + merged_model_settings, model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), ) From 43971f0e98f5b58bd3ddd27e3e2161ea2fc3fc19 Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Sun, 21 Sep 2025 09:03:56 +0000 Subject: [PATCH 2/9] merge in `Model` subclass `.request` --- .../pydantic_ai/models/anthropic.py | 4 +- .../pydantic_ai/models/bedrock.py | 4 +- pydantic_ai_slim/pydantic_ai/models/cohere.py | 3 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 4 +- pydantic_ai_slim/pydantic_ai/models/google.py | 4 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 4 +- .../pydantic_ai/models/huggingface.py | 4 +- .../pydantic_ai/models/instrumented.py | 4 +- .../pydantic_ai/models/mcp_sampling.py | 4 +- .../pydantic_ai/models/mistral.py | 4 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 6 +- tests/test_direct.py | 111 ++++++++++++++++-- 12 files changed, 133 insertions(+), 23 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index f14ccfe51e..f41e74b152 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -39,7 +39,7 @@ from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider from ..providers.anthropic import AsyncAnthropicClient -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent @@ -205,6 +205,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._messages_create( messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) @@ -220,6 +221,7 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._messages_create( messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 40da9d262a..45188b8fea 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -43,7 +43,7 @@ from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item from pydantic_ai.providers import Provider, infer_provider from pydantic_ai.providers.bedrock import BedrockModelProfile -from pydantic_ai.settings import ModelSettings +from pydantic_ai.settings import ModelSettings, merge_model_settings from pydantic_ai.tools import ToolDefinition if TYPE_CHECKING: @@ -264,6 +264,7 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings = merge_model_settings(self.settings, model_settings) settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, False, settings, model_request_parameters) model_response = await self._process_response(response) @@ -277,6 +278,7 @@ async def request_stream( model_request_parameters: 
ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings = merge_model_settings(self.settings, model_settings) settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) yield BedrockStreamedResponse( diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 18f3b5e084..d52bf979b6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -28,7 +28,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, check_allow_model_requests @@ -165,6 +165,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters) model_response = self._process_response(response) return model_response diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 5f17800a78..d5e1483ab9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -38,7 +38,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent @@ -155,6 +155,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) async with self._make_request( messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: @@ -171,6 +172,7 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) async with self._make_request( messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 18ddecca02..869701df70 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -37,7 +37,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . 
import ( Model, @@ -225,6 +225,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, False, model_settings, model_request_parameters) return self._process_response(response) @@ -291,6 +292,7 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, True, model_settings, model_request_parameters) yield await self._process_streamed_response(response, model_request_parameters) # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index b05913d2d6..868c2f138d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -41,7 +41,7 @@ from ..profiles import ModelProfile, ModelProfileSpec from ..profiles.groq import GroqModelProfile from ..providers import Provider, infer_provider -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . 
import ( Model, @@ -182,6 +182,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) try: response = await self._completions_create( messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters @@ -218,6 +219,7 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._completions_create( messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 23473fb7d6..0215166514 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -38,7 +38,7 @@ ) from ..profiles import ModelProfile, ModelProfileSpec from ..providers import Provider, infer_provider -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . 
import ( Model, @@ -166,6 +166,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._completions_create( messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters ) @@ -181,6 +182,7 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._completions_create( messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 57dc235da6..9adff38c7e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -29,7 +29,7 @@ ModelResponse, SystemPromptPart, ) -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from . 
import KnownModelName, Model, ModelRequestParameters, StreamedResponse from .wrapper import WrapperModel @@ -352,6 +352,7 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings = merge_model_settings(self.settings, model_settings) with self._instrument(messages, model_settings, model_request_parameters) as finish: response = await super().request(messages, model_settings, model_request_parameters) finish(response) @@ -365,6 +366,7 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings = merge_model_settings(self.settings, model_settings) with self._instrument(messages, model_settings, model_request_parameters) as finish: response_stream: StreamedResponse | None = None try: diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index d1f3ade1e0..f0c5116c58 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -8,7 +8,7 @@ from .. import _mcp, exceptions from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from . 
import Model, ModelRequestParameters, StreamedResponse if TYPE_CHECKING: @@ -52,6 +52,8 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: system_prompt, sampling_messages = _mcp.map_from_pai_messages(messages) + + model_settings = merge_model_settings(self.settings, model_settings) model_settings = cast(MCPSamplingModelSettings, model_settings or {}) result = await self.session.create_message( diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 0c749f3c60..b320cfc9b0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -38,7 +38,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from ..usage import RequestUsage from . import ( @@ -185,6 +185,7 @@ async def request( ) -> ModelResponse: """Make a non-streaming request to the model from Pydantic AI call.""" check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) @@ -201,6 +202,7 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._stream_completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index fdb2be686f..53aaf39c78 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -45,7 +45,7 @@ from 
..profiles import ModelProfile, ModelProfileSpec from ..profiles.openai import OpenAIModelProfile, OpenAISystemPromptRole from ..providers import Provider, infer_provider -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent @@ -393,6 +393,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._completions_create( messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters ) @@ -408,6 +409,7 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._completions_create( messages, True, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters ) @@ -877,6 +879,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._responses_create( messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) @@ -891,6 +894,7 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings = merge_model_settings(self.settings, model_settings) response = await self._responses_create( messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) diff --git a/tests/test_direct.py b/tests/test_direct.py index b1149e7239..e67cdb2ae8 100644 --- a/tests/test_direct.py +++ 
b/tests/test_direct.py @@ -1,14 +1,25 @@ import asyncio import re -from contextlib import contextmanager +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager, contextmanager from datetime import timezone +from typing import Any from unittest.mock import AsyncMock, patch import pytest from inline_snapshot import snapshot -from pydantic_ai import ( - Agent, +from pydantic_ai import Agent +from pydantic_ai._run_context import RunContext +from pydantic_ai.direct import ( + StreamedResponseSync, + _prepare_model, # pyright: ignore[reportPrivateUsage] + model_request, + model_request_stream, + model_request_stream_sync, + model_request_sync, +) +from pydantic_ai.messages import ( FinalResultEvent, ModelMessage, ModelRequest, @@ -19,17 +30,10 @@ TextPartDelta, ToolCallPart, ) -from pydantic_ai.direct import ( - StreamedResponseSync, - _prepare_model, # pyright: ignore[reportPrivateUsage] - model_request, - model_request_stream, - model_request_stream_sync, - model_request_sync, -) -from pydantic_ai.models import ModelRequestParameters +from pydantic_ai.models import ModelRequestParameters, StreamedResponse from pydantic_ai.models.instrumented import InstrumentedModel from pydantic_ai.models.test import TestModel +from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage @@ -38,6 +42,38 @@ pytestmark = pytest.mark.anyio +class RecordingTestModel(TestModel): + def __init__(self, *, settings: ModelSettings | None = None): + super().__init__(settings=settings) + self.recorded_model_settings: ModelSettings | None = None + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + self.recorded_model_settings = model_settings + return await super().request(messages, model_settings, model_request_parameters) + + @asynccontextmanager + async def 
request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, + ) -> AsyncIterator[StreamedResponse]: + self.recorded_model_settings = model_settings + async with super().request_stream( + messages, + model_settings, + model_request_parameters, + run_context=run_context, + ) as stream: + yield stream + + async def test_model_request(): model_response = await model_request('test', [ModelRequest.user_text_prompt('x')]) assert model_response == snapshot( @@ -50,6 +86,41 @@ async def test_model_request(): ) +async def test_model_request_merges_model_settings_from_instance(): + base_settings: ModelSettings = {'temperature': 0.3, 'seed': 7} + override_settings: ModelSettings = {'temperature': 0.9, 'max_tokens': 20} + model = RecordingTestModel(settings=base_settings) + + await model_request( + model, + [ModelRequest.user_text_prompt('x')], + model_settings=override_settings, + instrument=False, + ) + + assert model.recorded_model_settings == {'temperature': 0.9, 'seed': 7, 'max_tokens': 20} + + +async def test_model_request_string_model_uses_override_settings(monkeypatch: pytest.MonkeyPatch): + override_settings: ModelSettings = {'temperature': 0.2, 'max_tokens': 5} + recording_model = RecordingTestModel() + + def fake_infer_model(model_identifier: object) -> RecordingTestModel: + assert model_identifier == 'test' + return recording_model + + monkeypatch.setattr('pydantic_ai.direct.models.infer_model', fake_infer_model) + + await model_request( + 'test', + [ModelRequest.user_text_prompt('x')], + model_settings=override_settings, + instrument=False, + ) + + assert recording_model.recorded_model_settings == override_settings + + async def test_model_request_tool_call(): model_response = await model_request( 'test', @@ -115,6 +186,22 @@ async def test_model_request_stream(): ) +async def 
test_model_request_stream_merges_model_settings_from_instance(): + base_settings: ModelSettings = {'temperature': 0.1, 'seed': 42} + override_settings: ModelSettings = {'temperature': 0.5, 'max_tokens': 100} + model = RecordingTestModel(settings=base_settings) + + async with model_request_stream( + model, + [ModelRequest.user_text_prompt('x')], + model_settings=override_settings, + instrument=False, + ) as stream: + _ = [chunk async for chunk in stream] + + assert model.recorded_model_settings == {'temperature': 0.5, 'seed': 42, 'max_tokens': 100} + + def test_model_request_stream_sync_without_context_manager(): """Test that accessing properties or iterating without context manager raises RuntimeError.""" messages: list[ModelMessage] = [ModelRequest.user_text_prompt('x')] From db1b6425daa5f0239c1cddb0f1386c5314eb8277 Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Mon, 22 Sep 2025 10:58:47 +0000 Subject: [PATCH 3/9] merge in test model --- pydantic_ai_slim/pydantic_ai/models/test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index bb374c78a3..bdbff73ddf 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -29,7 +29,7 @@ ToolReturnPart, ) from ..profiles import ModelProfileSpec -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from ..usage import RequestUsage from . 
import Model, ModelRequestParameters, StreamedResponse @@ -111,6 +111,7 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: self.last_model_request_parameters = model_request_parameters + model_settings = merge_model_settings(self.settings, model_settings) model_response = self._request(messages, model_settings, model_request_parameters) model_response.usage = _estimate_usage([*messages, model_response]) return model_response @@ -124,6 +125,7 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: self.last_model_request_parameters = model_request_parameters + model_settings = merge_model_settings(self.settings, model_settings) model_response = self._request(messages, model_settings, model_request_parameters) yield TestStreamedResponse( From 0c371f64eb4cd3f4653d1aa55fcc8943f50f19b5 Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Mon, 22 Sep 2025 11:01:15 +0000 Subject: [PATCH 4/9] don't merge in `direct.py` --- pydantic_ai_slim/pydantic_ai/direct.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 313933cd17..1383f58c66 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -78,15 +78,9 @@ async def main(): The model response and token usage associated with the request. """ model_instance = _prepare_model(model, instrument) - - merged_model_settings = settings.merge_model_settings( - base=model_instance.settings, - overrides=model_settings, - ) - return await model_instance.request( messages, - merged_model_settings, + model_settings, model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), ) @@ -196,15 +190,9 @@ async def main(): A [stream response][pydantic_ai.models.StreamedResponse] async context manager. 
""" model_instance = _prepare_model(model, instrument) - - merged_model_settings = settings.merge_model_settings( - base=model_instance.settings, - overrides=model_settings, - ) - return model_instance.request_stream( messages, - merged_model_settings, + model_settings, model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), ) From 7f17245cc07efbd24578ecc85fbe093e81710c9d Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Mon, 22 Sep 2025 11:13:38 +0000 Subject: [PATCH 5/9] rm direct test --- tests/test_direct.py | 91 +------------------------------------------- 1 file changed, 2 insertions(+), 89 deletions(-) diff --git a/tests/test_direct.py b/tests/test_direct.py index e67cdb2ae8..a26c18c0b4 100644 --- a/tests/test_direct.py +++ b/tests/test_direct.py @@ -1,16 +1,13 @@ import asyncio import re -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager, contextmanager +from contextlib import contextmanager from datetime import timezone -from typing import Any from unittest.mock import AsyncMock, patch import pytest from inline_snapshot import snapshot from pydantic_ai import Agent -from pydantic_ai._run_context import RunContext from pydantic_ai.direct import ( StreamedResponseSync, _prepare_model, # pyright: ignore[reportPrivateUsage] @@ -30,10 +27,9 @@ TextPartDelta, ToolCallPart, ) -from pydantic_ai.models import ModelRequestParameters, StreamedResponse +from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.instrumented import InstrumentedModel from pydantic_ai.models.test import TestModel -from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage @@ -42,38 +38,6 @@ pytestmark = pytest.mark.anyio -class RecordingTestModel(TestModel): - def __init__(self, *, settings: ModelSettings | None = None): - super().__init__(settings=settings) - self.recorded_model_settings: ModelSettings | 
None = None - - async def request( - self, - messages: list[ModelMessage], - model_settings: ModelSettings | None, - model_request_parameters: ModelRequestParameters, - ) -> ModelResponse: - self.recorded_model_settings = model_settings - return await super().request(messages, model_settings, model_request_parameters) - - @asynccontextmanager - async def request_stream( - self, - messages: list[ModelMessage], - model_settings: ModelSettings | None, - model_request_parameters: ModelRequestParameters, - run_context: RunContext[Any] | None = None, - ) -> AsyncIterator[StreamedResponse]: - self.recorded_model_settings = model_settings - async with super().request_stream( - messages, - model_settings, - model_request_parameters, - run_context=run_context, - ) as stream: - yield stream - - async def test_model_request(): model_response = await model_request('test', [ModelRequest.user_text_prompt('x')]) assert model_response == snapshot( @@ -86,41 +50,6 @@ async def test_model_request(): ) -async def test_model_request_merges_model_settings_from_instance(): - base_settings: ModelSettings = {'temperature': 0.3, 'seed': 7} - override_settings: ModelSettings = {'temperature': 0.9, 'max_tokens': 20} - model = RecordingTestModel(settings=base_settings) - - await model_request( - model, - [ModelRequest.user_text_prompt('x')], - model_settings=override_settings, - instrument=False, - ) - - assert model.recorded_model_settings == {'temperature': 0.9, 'seed': 7, 'max_tokens': 20} - - -async def test_model_request_string_model_uses_override_settings(monkeypatch: pytest.MonkeyPatch): - override_settings: ModelSettings = {'temperature': 0.2, 'max_tokens': 5} - recording_model = RecordingTestModel() - - def fake_infer_model(model_identifier: object) -> RecordingTestModel: - assert model_identifier == 'test' - return recording_model - - monkeypatch.setattr('pydantic_ai.direct.models.infer_model', fake_infer_model) - - await model_request( - 'test', - 
[ModelRequest.user_text_prompt('x')], - model_settings=override_settings, - instrument=False, - ) - - assert recording_model.recorded_model_settings == override_settings - - async def test_model_request_tool_call(): model_response = await model_request( 'test', @@ -186,22 +115,6 @@ async def test_model_request_stream(): ) -async def test_model_request_stream_merges_model_settings_from_instance(): - base_settings: ModelSettings = {'temperature': 0.1, 'seed': 42} - override_settings: ModelSettings = {'temperature': 0.5, 'max_tokens': 100} - model = RecordingTestModel(settings=base_settings) - - async with model_request_stream( - model, - [ModelRequest.user_text_prompt('x')], - model_settings=override_settings, - instrument=False, - ) as stream: - _ = [chunk async for chunk in stream] - - assert model.recorded_model_settings == {'temperature': 0.5, 'seed': 42, 'max_tokens': 100} - - def test_model_request_stream_sync_without_context_manager(): """Test that accessing properties or iterating without context manager raises RuntimeError.""" messages: list[ModelMessage] = [ModelRequest.user_text_prompt('x')] From fea3f51275b2fc898520cb9000d98ef03971da6e Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Mon, 22 Sep 2025 11:27:10 +0000 Subject: [PATCH 6/9] function model --- pydantic_ai_slim/pydantic_ai/models/function.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index e9cc927735..eb1a8f966f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -32,7 +32,7 @@ UserPromptPart, ) from ..profiles import ModelProfile, ModelProfileSpec -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse @@ -125,6 +125,7 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings = merge_model_settings(self.settings, model_settings) agent_info = AgentInfo( function_tools=model_request_parameters.function_tools, allow_text_output=model_request_parameters.allow_text_output, @@ -154,6 +155,7 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings = merge_model_settings(self.settings, model_settings) agent_info = AgentInfo( function_tools=model_request_parameters.function_tools, allow_text_output=model_request_parameters.allow_text_output, From 8c436afdb6bb9bab15f49b46bc79f1fa26c25321 Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Tue, 30 Sep 2025 09:35:38 +0200 Subject: [PATCH 7/9] add `Model.prepare_request` to merge + customize --- pydantic_ai_slim/pydantic_ai/direct.py | 4 +- .../pydantic_ai/models/__init__.py | 20 ++++++++- .../pydantic_ai/models/anthropic.py | 12 ++++-- .../pydantic_ai/models/bedrock.py | 12 ++++-- pydantic_ai_slim/pydantic_ai/models/cohere.py | 7 ++- .../pydantic_ai/models/fallback.py | 11 +---- .../pydantic_ai/models/function.py | 12 ++++-- pydantic_ai_slim/pydantic_ai/models/gemini.py | 12 ++++-- pydantic_ai_slim/pydantic_ai/models/google.py | 16 +++++-- pydantic_ai_slim/pydantic_ai/models/groq.py | 12 ++++-- .../pydantic_ai/models/huggingface.py | 12 ++++-- .../pydantic_ai/models/instrumented.py | 20 ++++++--- .../pydantic_ai/models/mcp_sampling.py | 4 +- .../pydantic_ai/models/mistral.py | 12 ++++-- pydantic_ai_slim/pydantic_ai/models/openai.py | 22 +++++++--- pydantic_ai_slim/pydantic_ai/models/test.py | 12 ++++-- .../pydantic_ai/models/wrapper.py | 7 +++ tests/models/test_model_settings.py | 43 ++++++++++++++++++- 18 files changed, 194 insertions(+), 56 deletions(-) diff 
--git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 1383f58c66..38c92f71eb 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -81,7 +81,7 @@ async def main(): return await model_instance.request( messages, model_settings, - model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), + model_request_parameters or models.ModelRequestParameters(), ) @@ -193,7 +193,7 @@ async def main(): return model_instance.request_stream( messages, model_settings, - model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), + model_request_parameters or models.ModelRequestParameters(), ) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 05292156a0..07cef3cf8c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -41,7 +41,7 @@ from ..output import OutputMode from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec from ..profiles._json_schema import JsonSchemaTransformer -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from ..usage import RequestUsage @@ -390,6 +390,24 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar return model_request_parameters + def prepare_request( + self, + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters | None, + ) -> tuple[ModelSettings | None, ModelRequestParameters]: + """Prepare request inputs before they are passed to the provider. + + This merges the given ``model_settings`` with the model's own ``settings`` attribute and ensures + ``customize_request_parameters`` is applied to the resolved + [`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters]. 
Subclasses can override this method if + they need to customize the preparation flow further, but most implementations should simply call + ``self.prepare_request(...)`` at the start of their ``request`` (and related) methods. + """ + merged_settings = merge_model_settings(self.settings, model_settings) + resolved_parameters = model_request_parameters or ModelRequestParameters() + customized_parameters = self.customize_request_parameters(resolved_parameters) + return merged_settings, customized_parameters + @property @abstractmethod def model_name(self) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index f41e74b152..55ae6b0f39 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -39,7 +39,7 @@ from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider from ..providers.anthropic import AsyncAnthropicClient -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent @@ -205,7 +205,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._messages_create( messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) @@ -221,7 +224,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._messages_create( messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 45188b8fea..2e4a65c6ba 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -43,7 +43,7 @@ from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item from pydantic_ai.providers import Provider, infer_provider from pydantic_ai.providers.bedrock import BedrockModelProfile -from pydantic_ai.settings import ModelSettings, merge_model_settings +from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition if TYPE_CHECKING: @@ -264,7 +264,10 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = 
self.prepare_request( + model_settings, + model_request_parameters, + ) settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, False, settings, model_request_parameters) model_response = await self._process_response(response) @@ -278,7 +281,10 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) yield BedrockStreamedResponse( diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index d52bf979b6..1a08db4348 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -28,7 +28,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, check_allow_model_requests @@ -165,7 +165,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters) model_response = self._process_response(response) return model_response diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index cf24f7002f..c8430f5775 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -11,7 +11,6 @@ from pydantic_ai.models.instrumented import InstrumentedModel from ..exceptions import FallbackExceptionGroup, ModelHTTPError -from ..settings import merge_model_settings from . 
import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model if TYPE_CHECKING: @@ -78,10 +77,8 @@ async def request( exceptions: list[Exception] = [] for model in self.models: - customized_model_request_parameters = model.customize_request_parameters(model_request_parameters) - merged_settings = merge_model_settings(model.settings, model_settings) try: - response = await model.request(messages, merged_settings, customized_model_request_parameters) + response = await model.request(messages, model_settings, model_request_parameters) except Exception as exc: if self._fallback_on(exc): exceptions.append(exc) @@ -105,14 +102,10 @@ async def request_stream( exceptions: list[Exception] = [] for model in self.models: - customized_model_request_parameters = model.customize_request_parameters(model_request_parameters) - merged_settings = merge_model_settings(model.settings, model_settings) async with AsyncExitStack() as stack: try: response = await stack.enter_async_context( - model.request_stream( - messages, merged_settings, customized_model_request_parameters, run_context - ) + model.request_stream(messages, model_settings, model_request_parameters, run_context) ) except Exception as exc: if self._fallback_on(exc): diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index eb1a8f966f..c0c0d7d62f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -32,7 +32,7 @@ UserPromptPart, ) from ..profiles import ModelProfile, ModelProfileSpec -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse @@ -125,7 +125,10 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) agent_info = AgentInfo( function_tools=model_request_parameters.function_tools, allow_text_output=model_request_parameters.allow_text_output, @@ -155,7 +158,10 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) agent_info = AgentInfo( function_tools=model_request_parameters.function_tools, allow_text_output=model_request_parameters.allow_text_output, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index d5e1483ab9..fc0a2973be 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -38,7 +38,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent @@ -155,7 +155,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) async with self._make_request( messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: @@ -172,7 +175,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) async with self._make_request( messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 869701df70..edaaac6795 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -37,7 +37,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import ( Model, @@ -225,7 +225,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, False, model_settings, model_request_parameters) return self._process_response(response) @@ -237,6 +240,10 @@ async def count_tokens( model_request_parameters: ModelRequestParameters, ) -> usage.RequestUsage: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) model_settings = cast(GoogleModelSettings, model_settings or {}) contents, generation_config = await self._build_content_and_config( messages, model_settings, model_request_parameters @@ -292,7 +299,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, True, model_settings, model_request_parameters) yield await self._process_streamed_response(response, model_request_parameters) # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 868c2f138d..81907d3944 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -41,7 +41,7 @@ from ..profiles import ModelProfile, ModelProfileSpec from ..profiles.groq import GroqModelProfile from ..providers import Provider, 
infer_provider -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( Model, @@ -182,7 +182,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) try: response = await self._completions_create( messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters @@ -219,7 +222,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 0215166514..ed5ecb83be 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -38,7 +38,7 @@ ) from ..profiles import ModelProfile, ModelProfileSpec from ..providers import Provider, infer_provider -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import ( Model, @@ -166,7 +166,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters ) @@ -182,7 +185,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 9adff38c7e..1ebc7e0dba 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -29,7 +29,7 @@ ModelResponse, SystemPromptPart, ) -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from . 
import KnownModelName, Model, ModelRequestParameters, StreamedResponse from .wrapper import WrapperModel @@ -352,9 +352,12 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: - model_settings = merge_model_settings(self.settings, model_settings) - with self._instrument(messages, model_settings, model_request_parameters) as finish: - response = await super().request(messages, model_settings, model_request_parameters) + prepared_settings, prepared_parameters = self.wrapped.prepare_request( + model_settings, + model_request_parameters, + ) + with self._instrument(messages, prepared_settings, prepared_parameters) as finish: + response = await self.wrapped.request(messages, model_settings, model_request_parameters) finish(response) return response @@ -366,11 +369,14 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - model_settings = merge_model_settings(self.settings, model_settings) - with self._instrument(messages, model_settings, model_request_parameters) as finish: + prepared_settings, prepared_parameters = self.wrapped.prepare_request( + model_settings, + model_request_parameters, + ) + with self._instrument(messages, prepared_settings, prepared_parameters) as finish: response_stream: StreamedResponse | None = None try: - async with super().request_stream( + async with self.wrapped.request_stream( messages, model_settings, model_request_parameters, run_context ) as response_stream: yield response_stream diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index f0c5116c58..7b54ba0f6d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -8,7 +8,7 @@ from .. 
import _mcp, exceptions from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from . import Model, ModelRequestParameters, StreamedResponse if TYPE_CHECKING: @@ -53,7 +53,7 @@ async def request( ) -> ModelResponse: system_prompt, sampling_messages = _mcp.map_from_pai_messages(messages) - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, _ = self.prepare_request(model_settings, model_request_parameters) model_settings = cast(MCPSamplingModelSettings, model_settings or {}) result = await self.session.create_message( diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index b320cfc9b0..dec09c0b1c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -38,7 +38,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from ..usage import RequestUsage from . 
import ( @@ -185,7 +185,10 @@ async def request( ) -> ModelResponse: """Make a non-streaming request to the model from Pydantic AI call.""" check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) @@ -202,7 +205,10 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._stream_completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 53aaf39c78..7b364f9afb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -45,7 +45,7 @@ from ..profiles import ModelProfile, ModelProfileSpec from ..profiles.openai import OpenAIModelProfile, OpenAISystemPromptRole from ..providers import Provider, infer_provider -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent @@ -393,7 +393,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters ) @@ -409,7 +412,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, True, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters ) @@ -879,7 +885,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._responses_create( messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) @@ -894,7 +903,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() - model_settings = merge_model_settings(self.settings, model_settings) + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._responses_create( messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), 
model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index bdbff73ddf..7a01e5324c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -29,7 +29,7 @@ ToolReturnPart, ) from ..profiles import ModelProfileSpec -from ..settings import ModelSettings, merge_model_settings +from ..settings import ModelSettings from ..tools import ToolDefinition from ..usage import RequestUsage from . import Model, ModelRequestParameters, StreamedResponse @@ -110,8 +110,11 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) self.last_model_request_parameters = model_request_parameters - model_settings = merge_model_settings(self.settings, model_settings) model_response = self._request(messages, model_settings, model_request_parameters) model_response.usage = _estimate_usage([*messages, model_response]) return model_response @@ -124,8 +127,11 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) self.last_model_request_parameters = model_request_parameters - model_settings = merge_model_settings(self.settings, model_settings) model_response = self._request(messages, model_settings, model_request_parameters) yield TestStreamedResponse( diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index 4c91991cc1..2f64491e74 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -46,6 +46,13 @@ async def request_stream( def 
customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: return self.wrapped.customize_request_parameters(model_request_parameters) + def prepare_request( + self, + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters | None, + ) -> tuple[ModelSettings | None, ModelRequestParameters]: + return self.wrapped.prepare_request(model_settings, model_request_parameters) + @property def model_name(self) -> str: return self.wrapped.model_name diff --git a/tests/models/test_model_settings.py b/tests/models/test_model_settings.py index 4db59e741d..3bcbedbafd 100644 --- a/tests/models/test_model_settings.py +++ b/tests/models/test_model_settings.py @@ -2,7 +2,12 @@ from __future__ import annotations -from pydantic_ai import Agent, ModelMessage, ModelResponse, TextPart +import asyncio + +from pydantic_ai import Agent +from pydantic_ai.direct import model_request as direct_model_request +from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart +from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.instrumented import InstrumentedModel from pydantic_ai.models.test import TestModel @@ -161,3 +166,39 @@ def capture_settings(messages: list[ModelMessage], agent_info: AgentInfo) -> Mod assert captured_settings is not None assert captured_settings.get('temperature') == 0.75 assert len(captured_settings) == 1 # Only one setting should be present + + +def test_direct_model_request_merges_model_settings(): + """Ensure direct requests merge model defaults with provided run settings.""" + + captured_settings = None + + async def capture(messages: list[ModelMessage], agent_info: AgentInfo) -> ModelResponse: + nonlocal captured_settings + captured_settings = agent_info.model_settings + return ModelResponse(parts=[TextPart('ok')]) + + model = FunctionModel( + capture, + 
settings=ModelSettings(max_tokens=50, temperature=0.3), + ) + + messages: list[ModelMessage] = [ModelRequest.user_text_prompt('hi')] + run_settings = ModelSettings(temperature=0.9, top_p=0.2) + + async def _run() -> ModelResponse: + return await direct_model_request( + model, + messages, + model_settings=run_settings, + model_request_parameters=ModelRequestParameters(), + ) + + response = asyncio.run(_run()) + + assert response.parts == [TextPart('ok')] + assert captured_settings == { + 'max_tokens': 50, + 'temperature': 0.9, + 'top_p': 0.2, + } From eb7d52b4a2e1624a1327ae1035c6b9cfdbe2aa5f Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Wed, 1 Oct 2025 09:35:48 +0200 Subject: [PATCH 8/9] rm none option --- pydantic_ai_slim/pydantic_ai/models/__init__.py | 5 ++--- pydantic_ai_slim/pydantic_ai/models/wrapper.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 07cef3cf8c..22f86d0ca5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -393,7 +393,7 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar def prepare_request( self, model_settings: ModelSettings | None, - model_request_parameters: ModelRequestParameters | None, + model_request_parameters: ModelRequestParameters, ) -> tuple[ModelSettings | None, ModelRequestParameters]: """Prepare request inputs before they are passed to the provider. This merges the given ``model_settings`` with the model's own ``settings`` attribute and ensures ``customize_request_parameters`` is applied to the resolved [`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters]. Subclasses can override this method if they need to customize the preparation flow further, but most implementations should simply call ``self.prepare_request(...)`` at the start of their ``request`` (and related) methods.
""" merged_settings = merge_model_settings(self.settings, model_settings) - resolved_parameters = model_request_parameters or ModelRequestParameters() - customized_parameters = self.customize_request_parameters(resolved_parameters) + customized_parameters = self.customize_request_parameters(model_request_parameters) return merged_settings, customized_parameters @property diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index 2f64491e74..3260cc7d65 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -49,7 +49,7 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar def prepare_request( self, model_settings: ModelSettings | None, - model_request_parameters: ModelRequestParameters | None, + model_request_parameters: ModelRequestParameters, ) -> tuple[ModelSettings | None, ModelRequestParameters]: return self.wrapped.prepare_request(model_settings, model_request_parameters) From 25a0cd5ac681e67efa7da7acf252bcbf7daff064 Mon Sep 17 00:00:00 2001 From: Moritz Wilksch Date: Wed, 1 Oct 2025 09:41:37 +0200 Subject: [PATCH 9/9] fix imports --- pydantic_ai_slim/pydantic_ai/models/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 06d20fe9cf..508c88554b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -41,8 +41,7 @@ ) from ..output import OutputMode from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec -from ..profiles._json_schema import JsonSchemaTransformer -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from ..usage import RequestUsage