Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/anthropic/lib/bedrock/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions:
return options


def _infer_region() -> str:
def _infer_region(aws_profile: str | None = None) -> str:
"""
Infer the AWS region from the environment variables or
from the boto3 session if available.
Expand All @@ -76,10 +76,19 @@ def _infer_region() -> str:
if aws_region is None:
try:
import boto3
import botocore

session = boto3.Session()
session = boto3.Session(profile_name=aws_profile)
if session.region_name:
aws_region = session.region_name
else:
# If the region is not in the session, it might be in the config file
# but not loaded because AWS_SDK_LOAD_CONFIG is not set.
# creating a client might trigger the loading.
try:
aws_region = session.client("bedrock").meta.region_name
except botocore.exceptions.NoRegionError:
pass
except ImportError:
pass

Expand Down Expand Up @@ -158,11 +167,10 @@ def __init__(
_strict_response_validation: bool = False,
) -> None:
self.aws_secret_key = aws_secret_key

self.aws_access_key = aws_access_key

self.aws_region = _infer_region() if aws_region is None else aws_region
self.aws_profile = aws_profile
self.aws_region = _infer_region(aws_profile) if aws_region is None else aws_region

self.aws_session_token = aws_session_token

Expand Down Expand Up @@ -300,11 +308,10 @@ def __init__(
_strict_response_validation: bool = False,
) -> None:
self.aws_secret_key = aws_secret_key

self.aws_access_key = aws_access_key

self.aws_region = _infer_region() if aws_region is None else aws_region
self.aws_profile = aws_profile
self.aws_region = _infer_region(aws_profile) if aws_region is None else aws_region

self.aws_session_token = aws_session_token

Expand Down
90 changes: 77 additions & 13 deletions src/anthropic/lib/tools/_beta_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(
self._func_with_validate = pydantic.validate_call(func)
self.name = name or func.__name__
self._defer_loading = defer_loading
self._is_classmethod = False
self._is_staticmethod = False

self.description = description or self._get_description_from_docstring()

Expand All @@ -98,9 +100,25 @@ def __init__(
else:
self.input_schema = self._create_schema_from_function()

@property
def __call__(self) -> CallableT:
return self.func
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)

def __get__(self, instance: Any, owner: Any) -> Any:
if instance is None:
if owner is not None and self._is_classmethod:
instance = owner
else:
return self

tool = self.__class__(
func=self.func.__get__(instance, owner),
name=self.name,
description=self.description,
defer_loading=self._defer_loading,
)
tool._is_classmethod = self._is_classmethod
tool._is_staticmethod = self._is_staticmethod
return tool

def to_dict(self) -> BetaToolParam:
defn: BetaToolParam = {
Expand Down Expand Up @@ -147,6 +165,9 @@ def kw_arguments_schema(
arguments: "list[ArgumentsParameter]",
var_kwargs_schema: CoreSchema | None,
) -> JsonSchemaValue:
# Remove self and cls from the schema as they are handled by the SDK
arguments = [arg for arg in arguments if arg["name"] not in {"self", "cls"}]

schema = super().kw_arguments_schema(arguments, var_kwargs_schema)
if schema.get("type") != "object":
return schema
Expand Down Expand Up @@ -250,17 +271,40 @@ def my_func(x: int) -> str: ...
if _compat.PYDANTIC_V1:
raise RuntimeError("Tool functions are only supported with Pydantic v2")

if func is not None:
is_classmethod = isinstance(func, classmethod)
is_staticmethod = isinstance(func, staticmethod)
if is_classmethod or is_staticmethod:
# We unwrap the classmethod/staticmethod so that we can wrap it
# with our own tool class which implements the descriptor protocol
# correctly.
real_func = func.__func__ # type: ignore
else:
real_func = cast(FunctionT, func)

if real_func is not None:
# @beta_tool called without parentheses
return BetaFunctionTool(
func=func, name=name, description=description, input_schema=input_schema, defer_loading=defer_loading
tool = BetaFunctionTool(
func=real_func, name=name, description=description, input_schema=input_schema, defer_loading=defer_loading
)
tool._is_classmethod = is_classmethod
tool._is_staticmethod = is_staticmethod
return tool # type: ignore

# @beta_tool()
def decorator(func: FunctionT) -> BetaFunctionTool[FunctionT]:
return BetaFunctionTool(
func=func, name=name, description=description, input_schema=input_schema, defer_loading=defer_loading
is_classmethod = isinstance(func, classmethod)
is_staticmethod = isinstance(func, staticmethod)
if is_classmethod or is_staticmethod:
real_func = func.__func__ # type: ignore
else:
real_func = func

tool = BetaFunctionTool(
func=real_func, name=name, description=description, input_schema=input_schema, defer_loading=defer_loading
)
tool._is_classmethod = is_classmethod
tool._is_staticmethod = is_staticmethod
return tool # type: ignore

return decorator

Expand Down Expand Up @@ -314,25 +358,45 @@ async def my_func(x: int) -> str: ...
if _compat.PYDANTIC_V1:
raise RuntimeError("Tool functions are only supported with Pydantic v2")

if func is not None:
is_classmethod = isinstance(func, classmethod)
is_staticmethod = isinstance(func, staticmethod)
if is_classmethod or is_staticmethod:
real_func = func.__func__ # type: ignore
else:
real_func = cast(AsyncFunctionT, func)

if real_func is not None:
# @beta_async_tool called without parentheses
return BetaAsyncFunctionTool(
func=func,
tool = BetaAsyncFunctionTool(
func=real_func,
name=name,
description=description,
input_schema=input_schema,
defer_loading=defer_loading,
)
tool._is_classmethod = is_classmethod
tool._is_staticmethod = is_staticmethod
return tool # type: ignore

# @beta_async_tool()
def decorator(func: AsyncFunctionT) -> BetaAsyncFunctionTool[AsyncFunctionT]:
return BetaAsyncFunctionTool(
func=func,
is_classmethod = isinstance(func, classmethod)
is_staticmethod = isinstance(func, staticmethod)
if is_classmethod or is_staticmethod:
real_func = func.__func__ # type: ignore
else:
real_func = func

tool = BetaAsyncFunctionTool(
func=real_func,
name=name,
description=description,
input_schema=input_schema,
defer_loading=defer_loading,
)
tool._is_classmethod = is_classmethod
tool._is_staticmethod = is_staticmethod
return tool # type: ignore

return decorator

Expand Down