Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@
)
from ..profiles import ModelProfileSpec
from ..providers import Provider, infer_provider
from ..providers.anthropic import AsyncAnthropicClient
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

try:
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
from anthropic.types.beta import (
BetaBase64PDFBlockParam,
BetaBase64PDFSourceParam,
Expand Down Expand Up @@ -134,16 +135,16 @@ class AnthropicModel(Model):
Apart from `__init__`, all methods are private or match those of the base class.
"""

client: AsyncAnthropic = field(repr=False)
client: AsyncAnthropicClient = field(repr=False)

_model_name: AnthropicModelName = field(repr=False)
_provider: Provider[AsyncAnthropic] = field(repr=False)
_provider: Provider[AsyncAnthropicClient] = field(repr=False)

def __init__(
self,
model_name: AnthropicModelName,
*,
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic',
profile: ModelProfileSpec | None = None,
settings: ModelSettings | None = None,
):
Expand All @@ -153,7 +154,7 @@ def __init__(
model_name: The name of the Anthropic model to use. List of model names available
[here](https://docs.anthropic.com/en/docs/about-claude/models).
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
instance of `Provider[AsyncAnthropicClient]`. If not provided, the other parameters will be used.
profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
settings: Default model settings for this model instance.
"""
Expand Down
19 changes: 11 additions & 8 deletions pydantic_ai_slim/pydantic_ai/providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations as _annotations

import os
from typing import overload
from typing import Union, overload

import httpx
from typing_extensions import TypeAlias

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
Expand All @@ -12,15 +13,18 @@
from pydantic_ai.providers import Provider

try:
from anthropic import AsyncAnthropic
except ImportError as _import_error: # pragma: no cover
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
except ImportError as _import_error:
raise ImportError(
'Please install the `anthropic` package to use the Anthropic provider, '
'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
) from _import_error


class AnthropicProvider(Provider[AsyncAnthropic]):
AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]


class AnthropicProvider(Provider[AsyncAnthropicClient]):
"""Provider for Anthropic API."""

@property
Expand All @@ -32,14 +36,14 @@ def base_url(self) -> str:
return str(self._client.base_url)

@property
def client(self) -> AsyncAnthropic:
def client(self) -> AsyncAnthropicClient:
return self._client

def model_profile(self, model_name: str) -> ModelProfile | None:
return anthropic_model_profile(model_name)

@overload
def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ...

@overload
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...
Expand All @@ -48,7 +52,7 @@ def __init__(
self,
*,
api_key: str | None = None,
anthropic_client: AsyncAnthropic | None = None,
anthropic_client: AsyncAnthropicClient | None = None,
http_client: httpx.AsyncClient | None = None,
) -> None:
"""Create a new Anthropic provider.
Expand All @@ -71,7 +75,6 @@ def __init__(
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
'to use the Anthropic provider.'
)

if http_client is not None:
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
else:
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@

def test_init():
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar'))
assert isinstance(m.client, AsyncAnthropic)
assert m.client.api_key == 'foobar'
assert m.model_name == 'claude-3-5-haiku-latest'
assert m.system == 'anthropic'
Expand Down
30 changes: 11 additions & 19 deletions tests/providers/test_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from __future__ import annotations as _annotations

import httpx
import pytest

from pydantic_ai.exceptions import UserError

from ..conftest import TestEnv, try_import
from ..conftest import try_import

with try_import() as imports_successful:
from anthropic import AsyncAnthropic
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock

from pydantic_ai.providers.anthropic import AnthropicProvider

Expand All @@ -24,24 +21,19 @@ def test_anthropic_provider():
assert provider.client.api_key == 'api-key'


def test_anthropic_provider_need_api_key(env: TestEnv) -> None:
env.remove('ANTHROPIC_API_KEY')
with pytest.raises(UserError, match=r'.*ANTHROPIC_API_KEY.*'):
AnthropicProvider()


def test_anthropic_provider_pass_http_client() -> None:
http_client = httpx.AsyncClient()
provider = AnthropicProvider(http_client=http_client, api_key='api-key')
assert isinstance(provider.client, AsyncAnthropic)
# Verify the http_client is being used by the AsyncAnthropic client
assert provider.client._client == http_client # type: ignore[reportPrivateUsage]


def test_anthropic_provider_pass_anthropic_client() -> None:
anthropic_client = AsyncAnthropic(api_key='api-key')
provider = AnthropicProvider(anthropic_client=anthropic_client)
assert provider.client == anthropic_client
bedrock_client = AsyncAnthropicBedrock(
aws_secret_key='aws-secret-key',
aws_access_key='aws-access-key',
aws_region='us-west-2',
aws_profile='default',
aws_session_token='aws-session-token',
)
provider = AnthropicProvider(anthropic_client=bedrock_client)
assert provider.client == bedrock_client


def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down