Merged · Changes from 7 commits
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/direct.py
@@ -81,7 +81,7 @@ async def main():
     return await model_instance.request(
         messages,
         model_settings,
-        model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()),
+        model_request_parameters or models.ModelRequestParameters(),
     )


@@ -193,7 +193,7 @@ async def main():
     return model_instance.request_stream(
         messages,
         model_settings,
-        model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()),
+        model_request_parameters or models.ModelRequestParameters(),
     )


20 changes: 19 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -41,7 +41,7 @@
 from ..output import OutputMode
 from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
 from ..profiles._json_schema import JsonSchemaTransformer
-from ..settings import ModelSettings
+from ..settings import ModelSettings, merge_model_settings
 from ..tools import ToolDefinition
 from ..usage import RequestUsage

@@ -390,6 +390,24 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:

         return model_request_parameters

+    def prepare_request(
+        self,
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters | None,
+    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
+        """Prepare request inputs before they are passed to the provider.
+
+        This merges the given `model_settings` with the model's own `settings` attribute, and ensures
+        `customize_request_parameters` is applied to the resolved
+        [`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters]. Subclasses can override this method
+        to customize the preparation flow further, but most implementations should simply call
+        `self.prepare_request(...)` at the start of their `request` (and related) methods.
+        """
+        merged_settings = merge_model_settings(self.settings, model_settings)
+        resolved_parameters = model_request_parameters or ModelRequestParameters()
+        customized_parameters = self.customize_request_parameters(resolved_parameters)
+        return merged_settings, customized_parameters
+
     @property
     @abstractmethod
     def model_name(self) -> str:
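For context on how providers consume the new hook, here is a minimal sketch of the intended call pattern. The provider class below is hypothetical (abstract `model_name`/`system` properties omitted for brevity); its body mirrors the changes applied file-by-file in the rest of this PR.

from pydantic_ai.messages import ModelMessage, ModelResponse
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings


class MyProviderModel(Model):
    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        # Merge self.settings into model_settings and apply
        # customize_request_parameters, all in one place.
        model_settings, model_request_parameters = self.prepare_request(
            model_settings,
            model_request_parameters,
        )
        ...  # provider-specific request logic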
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -205,6 +205,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._messages_create(
             messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
@@ -220,6 +224,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._messages_create(
             messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -264,6 +264,10 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, False, settings, model_request_parameters)
         model_response = await self._process_response(response)
@@ -277,6 +281,10 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, True, settings, model_request_parameters)
         yield BedrockStreamedResponse(
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -165,6 +165,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         model_response = self._process_response(response)
         return model_response
11 changes: 2 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/fallback.py
@@ -11,7 +11,6 @@
 from pydantic_ai.models.instrumented import InstrumentedModel

 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
-from ..settings import merge_model_settings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

 if TYPE_CHECKING:
@@ -78,10 +77,8 @@ async def request(
         exceptions: list[Exception] = []

         for model in self.models:
-            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
-            merged_settings = merge_model_settings(model.settings, model_settings)
             try:
-                response = await model.request(messages, merged_settings, customized_model_request_parameters)
+                response = await model.request(messages, model_settings, model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
                     exceptions.append(exc)
@@ -105,14 +102,10 @@ async def request_stream(
         exceptions: list[Exception] = []

         for model in self.models:
-            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
-            merged_settings = merge_model_settings(model.settings, model_settings)
             async with AsyncExitStack() as stack:
                 try:
                     response = await stack.enter_async_context(
-                        model.request_stream(
-                            messages, merged_settings, customized_model_request_parameters, run_context
-                        )
+                        model.request_stream(messages, model_settings, model_request_parameters, run_context)
                     )
                 except Exception as exc:
                     if self._fallback_on(exc):
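Note for reviewers: dropping the merge here is safe because each wrapped model now merges its own `settings` inside `prepare_request` when its `request` method runs. A rough sketch of the behavior this preserves — model names and settings values are illustrative only:

from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIChatModel

# The primary model's own settings are still honored: they are merged with
# per-request settings inside OpenAIChatModel.request via prepare_request,
# rather than by FallbackModel itself.
primary = OpenAIChatModel('gpt-4o', settings={'temperature': 0.1})
model = FallbackModel(primary, OpenAIChatModel('gpt-4o-mini'))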
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/function.py
@@ -125,6 +125,10 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         agent_info = AgentInfo(
             function_tools=model_request_parameters.function_tools,
             allow_text_output=model_request_parameters.allow_text_output,
@@ -154,6 +158,10 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         agent_info = AgentInfo(
             function_tools=model_request_parameters.function_tools,
             allow_text_output=model_request_parameters.allow_text_output,
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -155,6 +155,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         async with self._make_request(
             messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
@@ -171,6 +175,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
12 changes: 12 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -225,6 +225,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)
@@ -236,6 +240,10 @@ async def count_tokens(
         model_request_parameters: ModelRequestParameters,
     ) -> usage.RequestUsage:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
             messages, model_settings, model_request_parameters
@@ -291,6 +299,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -182,6 +182,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         try:
             response = await self._completions_create(
                 messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
@@ -218,6 +222,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -166,6 +166,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
@@ -181,6 +185,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
16 changes: 12 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
@@ -352,8 +352,12 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
-        with self._instrument(messages, model_settings, model_request_parameters) as finish:
-            response = await super().request(messages, model_settings, model_request_parameters)
+        prepared_settings, prepared_parameters = self.wrapped.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
+        with self._instrument(messages, prepared_settings, prepared_parameters) as finish:
+            response = await self.wrapped.request(messages, model_settings, model_request_parameters)
             finish(response)
             return response

@@ -365,10 +369,14 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
-        with self._instrument(messages, model_settings, model_request_parameters) as finish:
+        prepared_settings, prepared_parameters = self.wrapped.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
+        with self._instrument(messages, prepared_settings, prepared_parameters) as finish:
             response_stream: StreamedResponse | None = None
             try:
-                async with super().request_stream(
+                async with self.wrapped.request_stream(
                     messages, model_settings, model_request_parameters, run_context
                 ) as response_stream:
                     yield response_stream
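Worth calling out: `InstrumentedModel` runs `wrapped.prepare_request(...)` only to get accurate settings and parameters for the telemetry span, while the wrapped model still receives the raw inputs and performs the same preparation itself, so the span data and the actual request stay consistent. A minimal sketch of the merge semantics this relies on — values are illustrative, and it assumes `merge_model_settings` does a dict-style merge where the overrides win:

from pydantic_ai.settings import ModelSettings, merge_model_settings

model_level: ModelSettings = {'temperature': 0.2, 'max_tokens': 256}
per_request: ModelSettings = {'temperature': 0.7}

merged = merge_model_settings(model_level, per_request)
# Per-request values win; model-level values fill the gaps.
assert merged == {'temperature': 0.7, 'max_tokens': 256}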
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py
@@ -52,6 +52,8 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         system_prompt, sampling_messages = _mcp.map_from_pai_messages(messages)
+
+        model_settings, _ = self.prepare_request(model_settings, model_request_parameters)
         model_settings = cast(MCPSamplingModelSettings, model_settings or {})

         result = await self.session.create_message(
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -185,6 +185,10 @@ async def request(
     ) -> ModelResponse:
         """Make a non-streaming request to the model from Pydantic AI call."""
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
@@ -201,6 +205,10 @@ async def request_stream(
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._stream_completions_create(
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
16 changes: 16 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -393,6 +393,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
@@ -408,6 +412,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._completions_create(
             messages, True, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
@@ -877,6 +885,10 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._responses_create(
             messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
         )
@@ -891,6 +903,10 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         response = await self._responses_create(
             messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
         )
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/test.py
@@ -110,6 +110,10 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         self.last_model_request_parameters = model_request_parameters
         model_response = self._request(messages, model_settings, model_request_parameters)
         model_response.usage = _estimate_usage([*messages, model_response])
@@ -123,6 +127,10 @@ async def request_stream(
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         self.last_model_request_parameters = model_request_parameters

         model_response = self._request(messages, model_settings, model_request_parameters)