Skip to content

Commit 8c436af

Browse files
committed
add Model.prepare_request to merge + customize
1 parent fea3f51 commit 8c436af

File tree

18 files changed

+194
-56
lines changed

18 files changed

+194
-56
lines changed

pydantic_ai_slim/pydantic_ai/direct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async def main():
8181
return await model_instance.request(
8282
messages,
8383
model_settings,
84-
model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()),
84+
model_request_parameters or models.ModelRequestParameters(),
8585
)
8686

8787

@@ -193,7 +193,7 @@ async def main():
193193
return model_instance.request_stream(
194194
messages,
195195
model_settings,
196-
model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()),
196+
model_request_parameters or models.ModelRequestParameters(),
197197
)
198198

199199

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ..output import OutputMode
4242
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
4343
from ..profiles._json_schema import JsonSchemaTransformer
44-
from ..settings import ModelSettings
44+
from ..settings import ModelSettings, merge_model_settings
4545
from ..tools import ToolDefinition
4646
from ..usage import RequestUsage
4747

@@ -390,6 +390,24 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar
390390

391391
return model_request_parameters
392392

393+
def prepare_request(
394+
self,
395+
model_settings: ModelSettings | None,
396+
model_request_parameters: ModelRequestParameters | None,
397+
) -> tuple[ModelSettings | None, ModelRequestParameters]:
398+
"""Prepare request inputs before they are passed to the provider.
399+
400+
This merges the given ``model_settings`` with the model's own ``settings`` attribute and ensures
401+
``customize_request_parameters`` is applied to the resolved
402+
[`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters]. Subclasses can override this method if
403+
they need to customize the preparation flow further, but most implementations should simply call
404+
``self.prepare_request(...)`` at the start of their ``request`` (and related) methods.
405+
"""
406+
merged_settings = merge_model_settings(self.settings, model_settings)
407+
resolved_parameters = model_request_parameters or ModelRequestParameters()
408+
customized_parameters = self.customize_request_parameters(resolved_parameters)
409+
return merged_settings, customized_parameters
410+
393411
@property
394412
@abstractmethod
395413
def model_name(self) -> str:

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ..profiles import ModelProfileSpec
4040
from ..providers import Provider, infer_provider
4141
from ..providers.anthropic import AsyncAnthropicClient
42-
from ..settings import ModelSettings, merge_model_settings
42+
from ..settings import ModelSettings
4343
from ..tools import ToolDefinition
4444
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
4545

@@ -205,7 +205,10 @@ async def request(
205205
model_request_parameters: ModelRequestParameters,
206206
) -> ModelResponse:
207207
check_allow_model_requests()
208-
model_settings = merge_model_settings(self.settings, model_settings)
208+
model_settings, model_request_parameters = self.prepare_request(
209+
model_settings,
210+
model_request_parameters,
211+
)
209212
response = await self._messages_create(
210213
messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
211214
)
@@ -221,7 +224,10 @@ async def request_stream(
221224
run_context: RunContext[Any] | None = None,
222225
) -> AsyncIterator[StreamedResponse]:
223226
check_allow_model_requests()
224-
model_settings = merge_model_settings(self.settings, model_settings)
227+
model_settings, model_request_parameters = self.prepare_request(
228+
model_settings,
229+
model_request_parameters,
230+
)
225231
response = await self._messages_create(
226232
messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
227233
)

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
4444
from pydantic_ai.providers import Provider, infer_provider
4545
from pydantic_ai.providers.bedrock import BedrockModelProfile
46-
from pydantic_ai.settings import ModelSettings, merge_model_settings
46+
from pydantic_ai.settings import ModelSettings
4747
from pydantic_ai.tools import ToolDefinition
4848

4949
if TYPE_CHECKING:
@@ -264,7 +264,10 @@ async def request(
264264
model_settings: ModelSettings | None,
265265
model_request_parameters: ModelRequestParameters,
266266
) -> ModelResponse:
267-
model_settings = merge_model_settings(self.settings, model_settings)
267+
model_settings, model_request_parameters = self.prepare_request(
268+
model_settings,
269+
model_request_parameters,
270+
)
268271
settings = cast(BedrockModelSettings, model_settings or {})
269272
response = await self._messages_create(messages, False, settings, model_request_parameters)
270273
model_response = await self._process_response(response)
@@ -278,7 +281,10 @@ async def request_stream(
278281
model_request_parameters: ModelRequestParameters,
279282
run_context: RunContext[Any] | None = None,
280283
) -> AsyncIterator[StreamedResponse]:
281-
model_settings = merge_model_settings(self.settings, model_settings)
284+
model_settings, model_request_parameters = self.prepare_request(
285+
model_settings,
286+
model_request_parameters,
287+
)
282288
settings = cast(BedrockModelSettings, model_settings or {})
283289
response = await self._messages_create(messages, True, settings, model_request_parameters)
284290
yield BedrockStreamedResponse(

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from ..profiles import ModelProfileSpec
3030
from ..providers import Provider, infer_provider
31-
from ..settings import ModelSettings, merge_model_settings
31+
from ..settings import ModelSettings
3232
from ..tools import ToolDefinition
3333
from . import Model, ModelRequestParameters, check_allow_model_requests
3434

@@ -165,7 +165,10 @@ async def request(
165165
model_request_parameters: ModelRequestParameters,
166166
) -> ModelResponse:
167167
check_allow_model_requests()
168-
model_settings = merge_model_settings(self.settings, model_settings)
168+
model_settings, model_request_parameters = self.prepare_request(
169+
model_settings,
170+
model_request_parameters,
171+
)
169172
response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
170173
model_response = self._process_response(response)
171174
return model_response

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from pydantic_ai.models.instrumented import InstrumentedModel
1212

1313
from ..exceptions import FallbackExceptionGroup, ModelHTTPError
14-
from ..settings import merge_model_settings
1514
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
1615

1716
if TYPE_CHECKING:
@@ -78,10 +77,8 @@ async def request(
7877
exceptions: list[Exception] = []
7978

8079
for model in self.models:
81-
customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
82-
merged_settings = merge_model_settings(model.settings, model_settings)
8380
try:
84-
response = await model.request(messages, merged_settings, customized_model_request_parameters)
81+
response = await model.request(messages, model_settings, model_request_parameters)
8582
except Exception as exc:
8683
if self._fallback_on(exc):
8784
exceptions.append(exc)
@@ -105,14 +102,10 @@ async def request_stream(
105102
exceptions: list[Exception] = []
106103

107104
for model in self.models:
108-
customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
109-
merged_settings = merge_model_settings(model.settings, model_settings)
110105
async with AsyncExitStack() as stack:
111106
try:
112107
response = await stack.enter_async_context(
113-
model.request_stream(
114-
messages, merged_settings, customized_model_request_parameters, run_context
115-
)
108+
model.request_stream(messages, model_settings, model_request_parameters, run_context)
116109
)
117110
except Exception as exc:
118111
if self._fallback_on(exc):

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
UserPromptPart,
3333
)
3434
from ..profiles import ModelProfile, ModelProfileSpec
35-
from ..settings import ModelSettings, merge_model_settings
35+
from ..settings import ModelSettings
3636
from ..tools import ToolDefinition
3737
from . import Model, ModelRequestParameters, StreamedResponse
3838

@@ -125,7 +125,10 @@ async def request(
125125
model_settings: ModelSettings | None,
126126
model_request_parameters: ModelRequestParameters,
127127
) -> ModelResponse:
128-
model_settings = merge_model_settings(self.settings, model_settings)
128+
model_settings, model_request_parameters = self.prepare_request(
129+
model_settings,
130+
model_request_parameters,
131+
)
129132
agent_info = AgentInfo(
130133
function_tools=model_request_parameters.function_tools,
131134
allow_text_output=model_request_parameters.allow_text_output,
@@ -155,7 +158,10 @@ async def request_stream(
155158
model_request_parameters: ModelRequestParameters,
156159
run_context: RunContext[Any] | None = None,
157160
) -> AsyncIterator[StreamedResponse]:
158-
model_settings = merge_model_settings(self.settings, model_settings)
161+
model_settings, model_request_parameters = self.prepare_request(
162+
model_settings,
163+
model_request_parameters,
164+
)
159165
agent_info = AgentInfo(
160166
function_tools=model_request_parameters.function_tools,
161167
allow_text_output=model_request_parameters.allow_text_output,

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939
from ..profiles import ModelProfileSpec
4040
from ..providers import Provider, infer_provider
41-
from ..settings import ModelSettings, merge_model_settings
41+
from ..settings import ModelSettings
4242
from ..tools import ToolDefinition
4343
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
4444

@@ -155,7 +155,10 @@ async def request(
155155
model_request_parameters: ModelRequestParameters,
156156
) -> ModelResponse:
157157
check_allow_model_requests()
158-
model_settings = merge_model_settings(self.settings, model_settings)
158+
model_settings, model_request_parameters = self.prepare_request(
159+
model_settings,
160+
model_request_parameters,
161+
)
159162
async with self._make_request(
160163
messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
161164
) as http_response:
@@ -172,7 +175,10 @@ async def request_stream(
172175
run_context: RunContext[Any] | None = None,
173176
) -> AsyncIterator[StreamedResponse]:
174177
check_allow_model_requests()
175-
model_settings = merge_model_settings(self.settings, model_settings)
178+
model_settings, model_request_parameters = self.prepare_request(
179+
model_settings,
180+
model_request_parameters,
181+
)
176182
async with self._make_request(
177183
messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
178184
) as http_response:

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
)
3838
from ..profiles import ModelProfileSpec
3939
from ..providers import Provider
40-
from ..settings import ModelSettings, merge_model_settings
40+
from ..settings import ModelSettings
4141
from ..tools import ToolDefinition
4242
from . import (
4343
Model,
@@ -225,7 +225,10 @@ async def request(
225225
model_request_parameters: ModelRequestParameters,
226226
) -> ModelResponse:
227227
check_allow_model_requests()
228-
model_settings = merge_model_settings(self.settings, model_settings)
228+
model_settings, model_request_parameters = self.prepare_request(
229+
model_settings,
230+
model_request_parameters,
231+
)
229232
model_settings = cast(GoogleModelSettings, model_settings or {})
230233
response = await self._generate_content(messages, False, model_settings, model_request_parameters)
231234
return self._process_response(response)
@@ -237,6 +240,10 @@ async def count_tokens(
237240
model_request_parameters: ModelRequestParameters,
238241
) -> usage.RequestUsage:
239242
check_allow_model_requests()
243+
model_settings, model_request_parameters = self.prepare_request(
244+
model_settings,
245+
model_request_parameters,
246+
)
240247
model_settings = cast(GoogleModelSettings, model_settings or {})
241248
contents, generation_config = await self._build_content_and_config(
242249
messages, model_settings, model_request_parameters
@@ -292,7 +299,10 @@ async def request_stream(
292299
run_context: RunContext[Any] | None = None,
293300
) -> AsyncIterator[StreamedResponse]:
294301
check_allow_model_requests()
295-
model_settings = merge_model_settings(self.settings, model_settings)
302+
model_settings, model_request_parameters = self.prepare_request(
303+
model_settings,
304+
model_request_parameters,
305+
)
296306
model_settings = cast(GoogleModelSettings, model_settings or {})
297307
response = await self._generate_content(messages, True, model_settings, model_request_parameters)
298308
yield await self._process_streamed_response(response, model_request_parameters) # type: ignore

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ..profiles import ModelProfile, ModelProfileSpec
4242
from ..profiles.groq import GroqModelProfile
4343
from ..providers import Provider, infer_provider
44-
from ..settings import ModelSettings, merge_model_settings
44+
from ..settings import ModelSettings
4545
from ..tools import ToolDefinition
4646
from . import (
4747
Model,
@@ -182,7 +182,10 @@ async def request(
182182
model_request_parameters: ModelRequestParameters,
183183
) -> ModelResponse:
184184
check_allow_model_requests()
185-
model_settings = merge_model_settings(self.settings, model_settings)
185+
model_settings, model_request_parameters = self.prepare_request(
186+
model_settings,
187+
model_request_parameters,
188+
)
186189
try:
187190
response = await self._completions_create(
188191
messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
@@ -219,7 +222,10 @@ async def request_stream(
219222
run_context: RunContext[Any] | None = None,
220223
) -> AsyncIterator[StreamedResponse]:
221224
check_allow_model_requests()
222-
model_settings = merge_model_settings(self.settings, model_settings)
225+
model_settings, model_request_parameters = self.prepare_request(
226+
model_settings,
227+
model_request_parameters,
228+
)
223229
response = await self._completions_create(
224230
messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
225231
)

0 commit comments

Comments
 (0)