4 changes: 3 additions & 1 deletion libs/aws/README.md
@@ -22,7 +22,9 @@ Alternatively, set the `AWS_BEARER_TOKEN_BEDROCK` environment variable locally f
 ```python
 from langchain_aws import ChatBedrockConverse
 
-model = ChatBedrockConverse()
+model = ChatBedrockConverse(
+    model="us.anthropic.claude-sonnet-4-20250514-v1:0",
+)
 model.invoke("Sing a ballad of LangChain.")
 ```
 
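The hunk header above shows the surrounding README context: authenticating via the `AWS_BEARER_TOKEN_BEDROCK` environment variable for local development. A minimal end-to-end sketch of that flow combined with the updated example; the env-var setup and placeholder token are assumptions for illustration, not part of this diff:

```python
import os

from langchain_aws import ChatBedrockConverse

# Assumption: a Bedrock API key exported for local development, per the
# README section this hunk modifies; the value here is a placeholder.
os.environ["AWS_BEARER_TOKEN_BEDROCK"] = "<your-bedrock-api-key>"

model = ChatBedrockConverse(
    model="us.anthropic.claude-sonnet-4-20250514-v1:0",
)
model.invoke("Sing a ballad of LangChain.")
```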
93 changes: 16 additions & 77 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -51,7 +51,6 @@
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 from langchain_aws.chat_models._compat import _convert_from_v1_to_anthropic
-from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
 from langchain_aws.function_calling import (
     AnthropicTool,
     ToolsOutputParser,
@@ -349,7 +348,7 @@ def _format_image(image_url: str) -> Dict:
 
 
 def _format_data_content_block(block: dict) -> dict:
-    """Format standard data content block to format expected by Converse API."""
+    """Format standard data content block to format expected by Bedrock."""
     if block["type"] == "image":
         if "base64" in block or block.get("source_type") == "base64":
             if "mime_type" not in block:
@@ -760,9 +759,6 @@ class ChatBedrock(BaseChatModel, BedrockBase):
     """A chat model that uses the Bedrock API."""
 
     system_prompt_with_tools: str = ""
-    beta_use_converse_api: bool = False
-    """Use the new Bedrock `converse` API which provides a standardized interface to
-    all Bedrock models. Support still in beta. See ChatBedrockConverse docs for more."""
 
     stop_sequences: Optional[List[str]] = Field(default=None, alias="stop")
     """Stop sequence inference parameter from new Bedrock `converse` API providing
@@ -789,18 +785,24 @@ def get_lc_namespace(cls) -> List[str]:
 
     @model_validator(mode="before")
     @classmethod
-    def set_beta_use_converse_api(cls, values: Dict) -> Any:
+    def check_unsupported_model(cls, values: Dict) -> Any:
         model_id = values.get("model_id", values.get("model"))
         base_model_id = values.get("base_model_id", values.get("base_model", ""))
 
-        if not model_id or "beta_use_converse_api" in values:
+        if not model_id:
             return values
 
-        nova_id = "amazon.nova"
-        values["beta_use_converse_api"] = False
+        # Add new Bedrock models here as needed
+        unsupported_models = ["amazon.nova"]
 
-        if nova_id in model_id or nova_id in base_model_id:
-            values["beta_use_converse_api"] = True
+        bad_model_err = (
+            "Provided model is unsupported on ChatBedrock with langchain-aws>=1.0.0."
+            " Please use ChatBedrockConverse instead."
+        )
+        if any(
+            model in model_id or model in base_model_id for model in unsupported_models
+        ):
+            raise ValueError(bad_model_err)
         elif not base_model_id and "application-inference-profile" in model_id:
             bedrock_client = values.get("bedrock_client")
             if not bedrock_client:
@@ -820,7 +822,9 @@ def set_beta_use_converse_api(cls, values: Dict) -> Any:
             if "models" in response and len(response["models"]) > 0:
                 model_arn = response["models"][0]["modelArn"]
                 resolved_base_model = model_arn.split("/")[-1]
-                values["beta_use_converse_api"] = "nova" in resolved_base_model
+                values["base_model_id"] = resolved_base_model
+                if any(model in resolved_base_model for model in unsupported_models):
+                    raise ValueError(bad_model_err)
         return values
 
     @model_validator(mode="before")
@@ -879,11 +883,6 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        if self.beta_use_converse_api:
-            yield from self._as_converse._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return
         provider = self._get_provider()
         prompt: Optional[str] = None
         system: Optional[str] = None
@@ -994,16 +993,6 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if self.beta_use_converse_api:
-            if not self.streaming:
-                return self._as_converse._generate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-            else:
-                stream_iter = self._as_converse._stream(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-                return generate_from_stream(stream_iter)
         completion = ""
         llm_output: Dict[str, Any] = {}
         tool_calls: List[ToolCall] = []
@@ -1219,12 +1208,6 @@ def bind_tools(
             [Runnable][langchain_core.runnables.Runnable] constructor.
 
         """
-        if self.beta_use_converse_api:
-            if isinstance(tool_choice, bool):
-                tool_choice = "any" if tool_choice else None
-            return self._as_converse.bind_tools(
-                tools, tool_choice=tool_choice, **kwargs
-            )
         if self._get_provider() == "anthropic":
             formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
 
@@ -1397,10 +1380,6 @@ class AnswerWithJustification(BaseModel):
         ```
 
         """  # noqa: E501
-        if self.beta_use_converse_api:
-            return self._as_converse.with_structured_output(
-                schema, include_raw=include_raw, **kwargs
-            )
         if "claude-" not in self._get_base_model():
            raise ValueError(
                f"Structured output is not supported for model {self._get_base_model()}"
@@ -1443,43 +1422,3 @@ class AnswerWithJustification(BaseModel):
             return RunnableMap(raw=llm) | parser_with_fallback
         else:
             return llm | output_parser
-
-    @property
-    def _as_converse(self) -> ChatBedrockConverse:
-        kwargs = {
-            k: v
-            for k, v in (self.model_kwargs or {}).items()
-            if k
-            in (
-                "stop",
-                "stop_sequences",
-                "max_tokens",
-                "temperature",
-                "top_p",
-                "additional_model_request_fields",
-                "additional_model_response_field_paths",
-                "performance_config",
-                "request_metadata",
-            )
-        }
-        if self.max_tokens:
-            kwargs["max_tokens"] = self.max_tokens
-        if self.temperature is not None:
-            kwargs["temperature"] = self.temperature
-        if self.stop_sequences:
-            kwargs["stop_sequences"] = self.stop_sequences
-
-        return ChatBedrockConverse(
-            client=self.client,
-            model=self.model_id,
-            region_name=self.region_name,
-            credentials_profile_name=self.credentials_profile_name,
-            aws_access_key_id=self.aws_access_key_id,
-            aws_secret_access_key=self.aws_secret_access_key,
-            aws_session_token=self.aws_session_token,
-            config=self.config,
-            provider=self.provider or "",
-            base_url=self.endpoint_url,
-            guardrail_config=(self.guardrails if self._guardrails_enabled else None),  # type: ignore[call-arg]
-            **kwargs,
-        )
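Taken together, the hunks above remove `ChatBedrock`'s Converse fallback (the `beta_use_converse_api` flag, the delegating branches in `_stream`, `_generate`, `bind_tools`, and `with_structured_output`, and the `_as_converse` property) and replace silent rerouting with a hard error from `check_unsupported_model`. A minimal sketch of the resulting behavior, assuming AWS credentials and region are already configured; the Nova model ID is illustrative:

```python
from langchain_aws import ChatBedrock, ChatBedrockConverse

# Nova-family models now fail fast at construction time instead of being
# silently routed through the Converse code path.
try:
    ChatBedrock(model="amazon.nova-pro-v1:0")  # illustrative model ID
except ValueError as err:
    # pydantic surfaces the validator's ValueError; its message points at
    # the migration path: "... Please use ChatBedrockConverse instead."
    print(err)

# The migration the error message suggests:
model = ChatBedrockConverse(model="amazon.nova-pro-v1:0")
```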
7 changes: 2 additions & 5 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -752,7 +752,6 @@ def test_guardrails() -> None:
             "guardrailVersion": "1",
             "trace": "enabled",
         },
-        "beta_use_converse_api": True,
     }
     chat_model = ChatBedrock(**params)  # type: ignore[arg-type]
     messages = [
@@ -841,8 +840,7 @@ def test_guardrails_streaming_trace() -> None:
         model_kwargs={"temperature": 0},
         guardrails=guardrail_config,
         callbacks=[guardrail_callback],
-        region="us-west-2",
-        beta_use_converse_api=False,  # Use legacy API for this test
+        region_name="us-west-2",
     )  # type: ignore[call-arg]
 
     # Test message that should trigger guardrail intervention
@@ -855,8 +853,7 @@
         model_kwargs={"temperature": 0},
         guardrails=guardrail_config,
         callbacks=[invoke_callback],
-        region="us-west-2",
-        beta_use_converse_api=False,
+        region_name="us-west-2",
     )  # type: ignore[call-arg]
 
     try:
28 changes: 0 additions & 28 deletions libs/aws/tests/integration_tests/chat_models/test_standard.py
@@ -29,31 +29,3 @@ def supports_image_inputs(self) -> bool:
     @pytest.mark.xfail(reason="Not implemented.")
     def test_double_messages_conversation(self, model: BaseChatModel) -> None:
         super().test_double_messages_conversation(model)
-
-
-class TestBedrockUseConverseStandard(ChatModelIntegrationTests):
-    @property
-    def chat_model_class(self) -> Type[BaseChatModel]:
-        return ChatBedrock
-
-    @property
-    def chat_model_params(self) -> dict:
-        return {
-            "model_id": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
-            "beta_use_converse_api": True,
-        }
-
-    @property
-    def standard_chat_model_params(self) -> dict:
-        return {
-            "temperature": 0,
-            "max_tokens": 100,
-            "stop_sequences": [],
-            "model_kwargs": {
-                "stop": [],
-            },
-        }
-
-    @property
-    def supports_image_inputs(self) -> bool:
-        return True
43 changes: 0 additions & 43 deletions libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -1,47 +1,4 @@
 # serializer version: 1
-# name: TestBedrockAsConverseStandard.test_serdes[serialized]
-  dict({
-    'id': list([
-      'langchain',
-      'chat_models',
-      'bedrock',
-      'ChatBedrock',
-    ]),
-    'kwargs': dict({
-      'beta_use_converse_api': True,
-      'guardrails': dict({
-        'guardrailIdentifier': None,
-        'guardrailVersion': None,
-        'trace': None,
-      }),
-      'max_tokens': 100,
-      'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0',
-      'model_kwargs': dict({
-        'stop': list([
-        ]),
-      }),
-      'provider_stop_reason_key_map': dict({
-        'ai21': 'finishReason',
-        'amazon': 'completionReason',
-        'anthropic': 'stop_reason',
-        'cohere': 'finish_reason',
-        'mistral': 'stop_reason',
-      }),
-      'provider_stop_sequence_key_name_map': dict({
-        'ai21': 'stop_sequences',
-        'amazon': 'stopSequences',
-        'anthropic': 'stop_sequences',
-        'cohere': 'stop_sequences',
-        'mistral': 'stop_sequences',
-      }),
-      'region_name': 'us-east-1',
-      'temperature': 0,
-    }),
-    'lc': 1,
-    'name': 'ChatBedrock',
-    'type': 'constructor',
-  })
-# ---
 # name: TestBedrockStandard.test_serdes[serialized]
   dict({
     'id': list([