
Commit a426d55

Correctly merge Model.settings with model_settings in direct mode (#2980)
1 parent: 788938d

19 files changed, +191 -27 lines changed

pydantic_ai_slim/pydantic_ai/direct.py

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ async def main():
     return await model_instance.request(
         messages,
         model_settings,
-        model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()),
+        model_request_parameters or models.ModelRequestParameters(),
     )

@@ -193,7 +193,7 @@ async def main():
     return model_instance.request_stream(
         messages,
         model_settings,
-        model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()),
+        model_request_parameters or models.ModelRequestParameters(),
     )
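
In direct mode the raw `model_request_parameters` are now passed straight to the model; `Model.request` and `Model.request_stream` call `prepare_request` themselves, so a model instance's own `settings` are no longer ignored by `model_request`/`model_request_stream`. A minimal sketch of the resulting behavior (the model name and setting values are illustrative, and the `settings` constructor argument comes from the `Model` base class, not from this diff):

    import asyncio

    from pydantic_ai.direct import model_request
    from pydantic_ai.messages import ModelRequest
    from pydantic_ai.models.anthropic import AnthropicModel


    async def main():
        # Base settings attached to the model instance itself (illustrative values).
        model = AnthropicModel('claude-sonnet-4-0', settings={'temperature': 0.0})

        # Per-call settings are merged with the model's own settings by
        # Model.prepare_request inside model.request; per-call values take precedence.
        response = await model_request(
            model,
            [ModelRequest.user_text_prompt('What is the capital of France?')],
            model_settings={'max_tokens': 256},
        )
        print(response.parts)


    asyncio.run(main())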

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 18 additions & 1 deletion
@@ -41,7 +41,7 @@
 )
 from ..output import OutputMode
 from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
-from ..settings import ModelSettings
+from ..settings import ModelSettings, merge_model_settings
 from ..tools import ToolDefinition
 from ..usage import RequestUsage

@@ -390,6 +390,23 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar
 
         return model_request_parameters
 
+    def prepare_request(
+        self,
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
+        """Prepare request inputs before they are passed to the provider.
+
+        This merges the given ``model_settings`` with the model's own ``settings`` attribute and ensures
+        ``customize_request_parameters`` is applied to the resolved
+        [`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters]. Subclasses can override this method if
+        they need to customize the preparation flow further, but most implementations should simply call
+        ``self.prepare_request(...)`` at the start of their ``request`` (and related) methods.
+        """
+        merged_settings = merge_model_settings(self.settings, model_settings)
+        customized_parameters = self.customize_request_parameters(model_request_parameters)
+        return merged_settings, customized_parameters
+
     @property
     @abstractmethod
     def model_name(self) -> str:
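
The docstring above spells out the intended pattern for model implementations: call `prepare_request` at the top of `request` (and `request_stream`) so that `self.settings` is merged with the per-call settings and `customize_request_parameters` is applied once. A minimal sketch of a custom subclass following that pattern (the `EchoModel` class and its behavior are hypothetical, not part of this commit):

    from __future__ import annotations

    from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
    from pydantic_ai.models import Model, ModelRequestParameters
    from pydantic_ai.settings import ModelSettings


    class EchoModel(Model):
        """Toy model returning a fixed reply; used only to show the prepare_request call."""

        async def request(
            self,
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> ModelResponse:
            # Merge self.settings with the per-call settings and apply
            # customize_request_parameters in one step.
            model_settings, model_request_parameters = self.prepare_request(
                model_settings, model_request_parameters
            )
            # A real implementation would call a provider API with the merged settings here.
            return ModelResponse(parts=[TextPart('echo')])

        @property
        def model_name(self) -> str:
            return 'echo'

        @property
        def system(self) -> str:
            return 'echo'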

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 8 additions & 0 deletions
@@ -205,6 +205,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._messages_create(
             messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )

@@ -220,6 +224,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._messages_create(
             messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 8 additions & 0 deletions
@@ -264,6 +264,10 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, False, settings, model_request_parameters)
         model_response = await self._process_response(response)

@@ -277,6 +281,10 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, True, settings, model_request_parameters)
         yield BedrockStreamedResponse(

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 4 additions & 0 deletions
@@ -165,6 +165,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         model_response = self._process_response(response)
         return model_response

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 2 additions & 9 deletions
@@ -11,7 +11,6 @@
 from pydantic_ai.models.instrumented import InstrumentedModel
 
 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
-from ..settings import merge_model_settings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
 
 if TYPE_CHECKING:

@@ -78,10 +77,8 @@ async def request(
         exceptions: list[Exception] = []
 
         for model in self.models:
-            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
-            merged_settings = merge_model_settings(model.settings, model_settings)
             try:
-                response = await model.request(messages, merged_settings, customized_model_request_parameters)
+                response = await model.request(messages, model_settings, model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
                     exceptions.append(exc)

@@ -105,14 +102,10 @@ async def request_stream(
         exceptions: list[Exception] = []
 
         for model in self.models:
-            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
-            merged_settings = merge_model_settings(model.settings, model_settings)
             async with AsyncExitStack() as stack:
                 try:
                     response = await stack.enter_async_context(
-                        model.request_stream(
-                            messages, merged_settings, customized_model_request_parameters, run_context
-                        )
+                        model.request_stream(messages, model_settings, model_request_parameters, run_context)
                     )
                 except Exception as exc:
                     if self._fallback_on(exc):
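
`FallbackModel` no longer merges each candidate's settings or customizes its request parameters itself; each wrapped model now does this in its own `request`/`request_stream` via `prepare_request`, so per-model settings are still honored. A short usage sketch (model names and setting values are illustrative; the `settings` constructor argument comes from the `Model` base class):

    from pydantic_ai.models.anthropic import AnthropicModel
    from pydantic_ai.models.fallback import FallbackModel
    from pydantic_ai.models.groq import GroqModel

    # Each candidate carries its own base settings (illustrative values).
    primary = AnthropicModel('claude-sonnet-4-0', settings={'temperature': 0.1})
    backup = GroqModel('llama-3.3-70b-versatile', settings={'temperature': 0.2})

    # Per-request settings are merged inside each candidate's request()/request_stream(),
    # not by FallbackModel, so whichever model handles the request applies its own defaults.
    fallback = FallbackModel(primary, backup)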

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 8 additions & 0 deletions
@@ -125,6 +125,10 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         agent_info = AgentInfo(
             function_tools=model_request_parameters.function_tools,
             allow_text_output=model_request_parameters.allow_text_output,

@@ -154,6 +158,10 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         agent_info = AgentInfo(
             function_tools=model_request_parameters.function_tools,
             allow_text_output=model_request_parameters.allow_text_output,

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 8 additions & 0 deletions
@@ -155,6 +155,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         async with self._make_request(
             messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:

@@ -171,6 +175,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 12 additions & 0 deletions
@@ -225,6 +225,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)

@@ -236,6 +240,10 @@ async def count_tokens(
         model_request_parameters: ModelRequestParameters,
     ) -> usage.RequestUsage:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
             messages, model_settings, model_request_parameters

@@ -291,6 +299,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 8 additions & 0 deletions
@@ -182,6 +182,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         try:
             response = await self._completions_create(
                 messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters

@@ -218,6 +222,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
