diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index ac1503db1b..0459af9376 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -37,12 +37,13 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider +from ..providers.anthropic import AsyncAnthropicClient from ..settings import ModelSettings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent try: - from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream + from anthropic import NOT_GIVEN, APIStatusError, AsyncStream from anthropic.types.beta import ( BetaBase64PDFBlockParam, BetaBase64PDFSourceParam, @@ -134,16 +135,16 @@ class AnthropicModel(Model): Apart from `__init__`, all methods are private or match those of the base class. """ - client: AsyncAnthropic = field(repr=False) + client: AsyncAnthropicClient = field(repr=False) _model_name: AnthropicModelName = field(repr=False) - _provider: Provider[AsyncAnthropic] = field(repr=False) + _provider: Provider[AsyncAnthropicClient] = field(repr=False) def __init__( self, model_name: AnthropicModelName, *, - provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic', + provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic', profile: ModelProfileSpec | None = None, settings: ModelSettings | None = None, ): @@ -153,7 +154,7 @@ def __init__( model_name: The name of the Anthropic model to use. List of model names available [here](https://docs.anthropic.com/en/docs/about-claude/models). provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an - instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used. + instance of `Provider[AsyncAnthropicClient]`. If not provided, the other parameters will be used. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. settings: Default model settings for this model instance. """ diff --git a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py index 20bc3255ee..b596c4d7eb 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py @@ -1,9 +1,10 @@ from __future__ import annotations as _annotations import os -from typing import overload +from typing import Union, overload import httpx +from typing_extensions import TypeAlias from pydantic_ai.exceptions import UserError from pydantic_ai.models import cached_async_http_client @@ -12,15 +13,18 @@ from pydantic_ai.providers import Provider try: - from anthropic import AsyncAnthropic -except ImportError as _import_error: # pragma: no cover + from anthropic import AsyncAnthropic, AsyncAnthropicBedrock +except ImportError as _import_error: raise ImportError( 'Please install the `anthropic` package to use the Anthropic provider, ' 'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`' ) from _import_error -class AnthropicProvider(Provider[AsyncAnthropic]): +AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock] + + +class AnthropicProvider(Provider[AsyncAnthropicClient]): """Provider for Anthropic API.""" @property @@ -32,14 +36,14 @@ def base_url(self) -> str: return str(self._client.base_url) @property - def client(self) -> AsyncAnthropic: + def client(self) -> AsyncAnthropicClient: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: return anthropic_model_profile(model_name) @overload - def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ... + def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ... @overload def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ... @@ -48,7 +52,7 @@ def __init__( self, *, api_key: str | None = None, - anthropic_client: AsyncAnthropic | None = None, + anthropic_client: AsyncAnthropicClient | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new Anthropic provider. @@ -71,7 +75,6 @@ def __init__( 'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`' 'to use the Anthropic provider.' ) - if http_client is not None: self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) else: diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index a32d08d5d3..38f8adb1b4 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -97,6 +97,7 @@ def test_init(): m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar')) + assert isinstance(m.client, AsyncAnthropic) assert m.client.api_key == 'foobar' assert m.model_name == 'claude-3-5-haiku-latest' assert m.system == 'anthropic' diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index 44f47554bd..79871a36c5 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -1,14 +1,11 @@ from __future__ import annotations as _annotations -import httpx import pytest -from pydantic_ai.exceptions import UserError - -from ..conftest import TestEnv, try_import +from ..conftest import try_import with try_import() as imports_successful: - from anthropic import AsyncAnthropic + from anthropic import AsyncAnthropic, AsyncAnthropicBedrock from pydantic_ai.providers.anthropic import AnthropicProvider @@ -24,24 +21,19 @@ def test_anthropic_provider(): assert provider.client.api_key == 'api-key' -def test_anthropic_provider_need_api_key(env: TestEnv) -> None: - env.remove('ANTHROPIC_API_KEY') - with pytest.raises(UserError, match=r'.*ANTHROPIC_API_KEY.*'): - AnthropicProvider() - - -def test_anthropic_provider_pass_http_client() -> None: - http_client = httpx.AsyncClient() - provider = AnthropicProvider(http_client=http_client, api_key='api-key') - assert isinstance(provider.client, AsyncAnthropic) - # Verify the http_client is being used by the AsyncAnthropic client - assert provider.client._client == http_client # type: ignore[reportPrivateUsage] - - def test_anthropic_provider_pass_anthropic_client() -> None: anthropic_client = AsyncAnthropic(api_key='api-key') provider = AnthropicProvider(anthropic_client=anthropic_client) assert provider.client == anthropic_client + bedrock_client = AsyncAnthropicBedrock( + aws_secret_key='aws-secret-key', + aws_access_key='aws-access-key', + aws_region='us-west-2', + aws_profile='default', + aws_session_token='aws-session-token', + ) + provider = AnthropicProvider(anthropic_client=bedrock_client) + assert provider.client == bedrock_client def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> None: