From 33c07a28bfd599fb79efd689a04204d0732f7ce5 Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Fri, 19 Sep 2025 21:27:13 -0700 Subject: [PATCH 1/2] fix(aws): (v1) Remove beta_use_converse_api from ChatBedrock --- README.md | 7 +- libs/aws/README.md | 8 +- libs/aws/langchain_aws/chat_models/bedrock.py | 93 +++--------------- .../chat_models/test_bedrock.py | 3 - .../chat_models/test_standard.py | 28 ------ .../__snapshots__/test_standard.ambr | 43 --------- .../unit_tests/chat_models/test_bedrock.py | 96 +++++++------------ libs/aws/tests/unit_tests/test_standard.py | 28 ------ 8 files changed, 56 insertions(+), 250 deletions(-) diff --git a/README.md b/README.md index b2805200e..8ae6d9e88 100644 --- a/README.md +++ b/README.md @@ -44,12 +44,11 @@ pip install langgraph-checkpoint-aws Here's a simple example of how to use the `langchain-aws` package. ```python -from langchain_aws import ChatBedrock +from langchain_aws import ChatBedrockConverse # Initialize the Bedrock chat model -llm = ChatBedrock( - model="anthropic.claude-3-sonnet-20240229-v1:0", - beta_use_converse_api=True +llm = ChatBedrockConverse( + model="us.anthropic.claude-sonnet-4-20250514-v1:0", ) # Invoke the llm diff --git a/libs/aws/README.md b/libs/aws/README.md index ef52fd379..bf9356225 100644 --- a/libs/aws/README.md +++ b/libs/aws/README.md @@ -16,12 +16,14 @@ Alternatively, set the `AWS_BEARER_TOKEN_BEDROCK` environment variable locally f ## Chat Models -`ChatBedrock` class exposes chat models from Bedrock. +`ChatBedrockConverse` class exposes chat models from Bedrock. ```python -from langchain_aws import ChatBedrock +from langchain_aws import ChatBedrockConverse -llm = ChatBedrock() +llm = ChatBedrockConverse( + model="us.anthropic.claude-sonnet-4-20250514-v1:0", +) llm.invoke("Sing a ballad of LangChain.") ``` diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index e809472cc..ca409453c 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -45,7 +45,6 @@ from langchain_core.utils.utils import _build_model_kwargs from pydantic import BaseModel, ConfigDict, Field, model_validator -from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse from langchain_aws.function_calling import ( ToolsOutputParser, _lc_tool_calls_to_anthropic_tool_use_blocks, @@ -338,7 +337,7 @@ def _format_image(image_url: str) -> Dict: def _format_data_content_block(block: dict) -> dict: - """Format standard data content block to format expected by Converse API.""" + """Format standard data content block to format expected by Bedrock.""" if block["type"] == "image": if block["source_type"] == "base64": if "mime_type" not in block: @@ -717,13 +716,10 @@ class ChatBedrock(BaseChatModel, BedrockBase): """A chat model that uses the Bedrock API.""" system_prompt_with_tools: str = "" - beta_use_converse_api: bool = False - """Use the new Bedrock ``converse`` API which provides a standardized interface to - all Bedrock models. Support still in beta. See ChatBedrockConverse docs for more.""" stop_sequences: Optional[List[str]] = Field(default=None, alias="stop") - """Stop sequence inference parameter from new Bedrock ``converse`` API providing - a sequence of characters that causes a model to stop generating a response. See + """Stop sequence inference parameter providing a sequence of + characters that causes a model to stop generating a response. See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent_InferenceConfiguration.html for more.""" @@ -744,18 +740,20 @@ def get_lc_namespace(cls) -> List[str]: @model_validator(mode="before") @classmethod - def set_beta_use_converse_api(cls, values: Dict) -> Any: + def check_unsupported_model(cls, values: Dict) -> Any: model_id = values.get("model_id", values.get("model")) base_model_id = values.get("base_model_id", values.get("base_model", "")) - if not model_id or "beta_use_converse_api" in values: + if not model_id: return values - nova_id = "amazon.nova" - values["beta_use_converse_api"] = False + # Add new Bedrock models here as needed + unsupported_models = ["amazon.nova"] - if nova_id in model_id or nova_id in base_model_id: - values["beta_use_converse_api"] = True + bad_model_err = ("Provided model is unsupported on ChatBedrock with langchain-aws>=1.0.0." + " Please use ChatBedrockConverse instead.") + if any(model in model_id or model in base_model_id for model in unsupported_models): + raise ValueError(bad_model_err) elif not base_model_id and "application-inference-profile" in model_id: bedrock_client = values.get("bedrock_client") if not bedrock_client: @@ -775,7 +773,9 @@ def set_beta_use_converse_api(cls, values: Dict) -> Any: if "models" in response and len(response["models"]) > 0: model_arn = response["models"][0]["modelArn"] resolved_base_model = model_arn.split("/")[-1] - values["beta_use_converse_api"] = "nova" in resolved_base_model + values["base_model_id"] = resolved_base_model + if any(model in resolved_base_model for model in unsupported_models): + raise ValueError(bad_model_err) return values @model_validator(mode="before") @@ -834,11 +834,6 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - if self.beta_use_converse_api: - yield from self._as_converse._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return provider = self._get_provider() prompt, system, formatted_messages = None, None, None @@ -925,16 +920,6 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if self.beta_use_converse_api: - if not self.streaming: - return self._as_converse._generate( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - else: - stream_iter = self._as_converse._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) completion = "" llm_output: Dict[str, Any] = {} tool_calls: List[ToolCall] = [] @@ -1091,12 +1076,6 @@ def bind_tools( **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. """ - if self.beta_use_converse_api: - if isinstance(tool_choice, bool): - tool_choice = "any" if tool_choice else None - return self._as_converse.bind_tools( - tools, tool_choice=tool_choice, **kwargs - ) if self._get_provider() == "anthropic": formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools] @@ -1262,10 +1241,6 @@ class AnswerWithJustification(BaseModel): # } """ # noqa: E501 - if self.beta_use_converse_api: - return self._as_converse.with_structured_output( - schema, include_raw=include_raw, **kwargs - ) if "claude-" not in self._get_base_model(): raise ValueError( f"Structured output is not supported for model {self._get_base_model()}" @@ -1298,43 +1273,3 @@ class AnswerWithJustification(BaseModel): return RunnableMap(raw=llm) | parser_with_fallback else: return llm | output_parser - - @property - def _as_converse(self) -> ChatBedrockConverse: - kwargs = { - k: v - for k, v in (self.model_kwargs or {}).items() - if k - in ( - "stop", - "stop_sequences", - "max_tokens", - "temperature", - "top_p", - "additional_model_request_fields", - "additional_model_response_field_paths", - "performance_config", - "request_metadata", - ) - } - if self.max_tokens: - kwargs["max_tokens"] = self.max_tokens - if self.temperature is not None: - kwargs["temperature"] = self.temperature - if self.stop_sequences: - kwargs["stop_sequences"] = self.stop_sequences - - return ChatBedrockConverse( - client=self.client, - model=self.model_id, - region_name=self.region_name, - credentials_profile_name=self.credentials_profile_name, - aws_access_key_id=self.aws_access_key_id, - aws_secret_access_key=self.aws_secret_access_key, - aws_session_token=self.aws_session_token, - config=self.config, - provider=self.provider or "", - base_url=self.endpoint_url, - guardrail_config=(self.guardrails if self._guardrails_enabled else None), # type: ignore[call-arg] - **kwargs, - ) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py index 6129541ca..0afe29509 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py @@ -594,7 +594,6 @@ def test_guardrails() -> None: "guardrailVersion": "1", "trace": "enabled", }, - "beta_use_converse_api": True, } chat_model = ChatBedrock(**params) # type: ignore[arg-type] messages = [ @@ -684,7 +683,6 @@ def test_guardrails_streaming_trace() -> None: guardrails=guardrail_config, callbacks=[guardrail_callback], region_name="us-west-2", - beta_use_converse_api=False, # Use legacy API for this test ) # type: ignore[call-arg] # Test message that should trigger guardrail intervention @@ -698,7 +696,6 @@ def test_guardrails_streaming_trace() -> None: guardrails=guardrail_config, callbacks=[invoke_callback], region_name="us-west-2", - beta_use_converse_api=False, ) # type: ignore[call-arg] try: diff --git a/libs/aws/tests/integration_tests/chat_models/test_standard.py b/libs/aws/tests/integration_tests/chat_models/test_standard.py index cc3cc2202..f1067f060 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_standard.py +++ b/libs/aws/tests/integration_tests/chat_models/test_standard.py @@ -29,31 +29,3 @@ def supports_image_inputs(self) -> bool: @pytest.mark.xfail(reason="Not implemented.") def test_double_messages_conversation(self, model: BaseChatModel) -> None: super().test_double_messages_conversation(model) - - -class TestBedrockUseConverseStandard(ChatModelIntegrationTests): - @property - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatBedrock - - @property - def chat_model_params(self) -> dict: - return { - "model_id": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", - "beta_use_converse_api": True, - } - - @property - def standard_chat_model_params(self) -> dict: - return { - "temperature": 0, - "max_tokens": 100, - "stop_sequences": [], - "model_kwargs": { - "stop": [], - }, - } - - @property - def supports_image_inputs(self) -> bool: - return True diff --git a/libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr index 88098c8fa..4f5732335 100644 --- a/libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr +++ b/libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr @@ -1,47 +1,4 @@ # serializer version: 1 -# name: TestBedrockAsConverseStandard.test_serdes[serialized] - dict({ - 'id': list([ - 'langchain', - 'chat_models', - 'bedrock', - 'ChatBedrock', - ]), - 'kwargs': dict({ - 'beta_use_converse_api': True, - 'guardrails': dict({ - 'guardrailIdentifier': None, - 'guardrailVersion': None, - 'trace': None, - }), - 'max_tokens': 100, - 'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0', - 'model_kwargs': dict({ - 'stop': list([ - ]), - }), - 'provider_stop_reason_key_map': dict({ - 'ai21': 'finishReason', - 'amazon': 'completionReason', - 'anthropic': 'stop_reason', - 'cohere': 'finish_reason', - 'mistral': 'stop_reason', - }), - 'provider_stop_sequence_key_name_map': dict({ - 'ai21': 'stop_sequences', - 'amazon': 'stopSequences', - 'anthropic': 'stop_sequences', - 'cohere': 'stop_sequences', - 'mistral': 'stop_sequences', - }), - 'region_name': 'us-east-1', - 'temperature': 0, - }), - 'lc': 1, - 'name': 'ChatBedrock', - 'type': 'constructor', - }) -# --- # name: TestBedrockStandard.test_serdes[serialized] dict({ 'id': list([ diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py index 9a70b68b5..01db94068 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py @@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableBinding from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError from langchain_aws import ChatBedrock from langchain_aws.chat_models.bedrock import ( @@ -639,106 +639,78 @@ def test_standard_tracing_params() -> None: } -def test_beta_use_converse_api() -> None: - llm = ChatBedrock(model_id="amazon.nova.foo", region_name="us-west-2") # type: ignore[call-arg] - assert llm.beta_use_converse_api +def test_check_unsupported_model() -> None: - llm = ChatBedrock( - model="foobar", base_model="amazon.nova.foo", region_name="us-west-2" - ) # type: ignore[call-arg] - assert llm.beta_use_converse_api + with pytest.raises(ValidationError): + ChatBedrock(model="amazon.nova.foo", + region_name="us-west-2") # type: ignore[call-arg] - llm = ChatBedrock( - model="arn:aws:bedrock:::application-inference-profile/my-profile", - base_model="claude.foo", - region_name="us-west-2", - ) # type: ignore[call-arg] - assert not llm.beta_use_converse_api + with pytest.raises(ValidationError): + ChatBedrock(model="foobar", + base_model="amazon.nova.foo", + region_name="us-west-2") # type: ignore[call-arg] - llm = ChatBedrock( - model="nova.foo", region_name="us-west-2", beta_use_converse_api=False - ) - assert not llm.beta_use_converse_api + try: + ChatBedrock(model="foobar", + base_model="anthropic.claude-3-7", + region_name="us-west-2") # type: ignore[call-arg] - llm = ChatBedrock( - model="foobar", - base_model="nova.foo", - region_name="us-west-2", - beta_use_converse_api=False, - ) - assert not llm.beta_use_converse_api - - llm = ChatBedrock(model="foo", region_name="us-west-2", beta_use_converse_api=True) - assert llm.beta_use_converse_api - - llm = ChatBedrock(model="foo", region_name="us-west-2", beta_use_converse_api=False) - assert not llm.beta_use_converse_api + ChatBedrock(model="anthropic.claude-3-7", + region_name="us-west-2") # type: ignore[call-arg] + except Exception as e: + pytest.fail(e) @mock.patch("langchain_aws.chat_models.bedrock.create_aws_client") -def test_beta_use_converse_api_with_inference_profile(mock_create_aws_client): +def test_check_unsupported_model_with_inference_profile_valid_model(mock_create_aws_client): mock_bedrock_client = mock.MagicMock() mock_bedrock_client.get_inference_profile.return_value = { "models": [ { - "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0" + "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-7-sonnet-20250219-v1:0" # noqa: E501 } ] } mock_create_aws_client.return_value = mock_bedrock_client - aip_model_id = "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/my-profile" - chat = ChatBedrock( - model_id=aip_model_id, - region_name="us-west-2", + aip_model_id = "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/my-profile" # noqa: E501 + + ChatBedrock( + model=aip_model_id, + region="us-west-2", bedrock_client=mock_bedrock_client, ) # type: ignore[call-arg] - mock_bedrock_client.get_inference_profile.assert_called_with( + mock_bedrock_client.get_inference_profile.assert_called_once_with( inferenceProfileIdentifier=aip_model_id ) - assert chat.beta_use_converse_api is False - @mock.patch("langchain_aws.chat_models.bedrock.create_aws_client") -def test_beta_use_converse_api_with_inference_profile_as_nova_model( - mock_create_aws_client, -): +def test_check_unsupported_model_with_inference_profile_invalid_model(mock_create_aws_client): mock_bedrock_client = mock.MagicMock() mock_bedrock_client.get_inference_profile.return_value = { "models": [ { - "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/amazon.nova-micro-v1:0" + "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/amazon.nova-micro-v1:0" # noqa: E501 } ] } mock_create_aws_client.return_value = mock_bedrock_client - aip_model_id = "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/my-profile" - chat = ChatBedrock( - model_id=aip_model_id, - region_name="us-west-2", - bedrock_client=mock_bedrock_client, - ) # type: ignore[call-arg] - - mock_bedrock_client.get_inference_profile.assert_called_with( - inferenceProfileIdentifier=aip_model_id - ) + aip_model_id = "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/my-profile" # noqa: E501 - assert chat.beta_use_converse_api is True + with pytest.raises(ValidationError): + ChatBedrock( + model=aip_model_id, + region="us-west-2", + bedrock_client=mock_bedrock_client, + ) # type: ignore[call-arg] @pytest.mark.parametrize( "model_id, provider, expected_provider, expectation, region_name", [ - ( - "amer.amazon.nova-pro-v1:0", - None, - "amazon", - nullcontext(), - "us-west-2", - ), ( "global.anthropic.claude-sonnet-4-20250514-v1:0", None, diff --git a/libs/aws/tests/unit_tests/test_standard.py b/libs/aws/tests/unit_tests/test_standard.py index 77a829715..22166134b 100644 --- a/libs/aws/tests/unit_tests/test_standard.py +++ b/libs/aws/tests/unit_tests/test_standard.py @@ -28,31 +28,3 @@ def standard_chat_model_params(self) -> dict: @pytest.mark.xfail(reason="Not implemented.") def test_standard_params(self, model: BaseChatModel) -> None: super().test_standard_params(model) - - -class TestBedrockAsConverseStandard(ChatModelUnitTests): - @property - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatBedrock - - @property - def chat_model_params(self) -> dict: - return { - "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", - "region_name": "us-east-1", - "beta_use_converse_api": True, - } - - @property - def standard_chat_model_params(self) -> dict: - return { - "model_kwargs": { - "temperature": 0, - "max_tokens": 100, - "stop": [], - } - } - - @pytest.mark.xfail(reason="Not implemented.") - def test_standard_params(self, model: BaseChatModel) -> None: - super().test_standard_params(model) From bcfc53851c8127918cb6bc432fd9c0464f0c7dcc Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Wed, 22 Oct 2025 22:10:51 -0700 Subject: [PATCH 2/2] fmt/test fixes --- libs/aws/langchain_aws/chat_models/bedrock.py | 10 ++- .../unit_tests/chat_models/test_bedrock.py | 81 +++++++++---------- 2 files changed, 44 insertions(+), 47 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index a94e1a59a..204e66cd8 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -795,9 +795,13 @@ def check_unsupported_model(cls, values: Dict) -> Any: # Add new Bedrock models here as needed unsupported_models = ["amazon.nova"] - bad_model_err = ("Provided model is unsupported on ChatBedrock with langchain-aws>=1.0.0." - " Please use ChatBedrockConverse instead.") - if any(model in model_id or model in base_model_id for model in unsupported_models): + bad_model_err = ( + "Provided model is unsupported on ChatBedrock with langchain-aws>=1.0.0." + " Please use ChatBedrockConverse instead." + ) + if any( + model in model_id or model in base_model_id for model in unsupported_models + ): raise ValueError(bad_model_err) elif not base_model_id and "application-inference-profile" in model_id: bedrock_client = values.get("bedrock_client") diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py index bbadf4885..6d0040db7 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py @@ -642,43 +642,42 @@ def test_standard_tracing_params() -> None: def test_check_unsupported_model() -> None: - with pytest.raises(ValidationError): - ChatBedrock(model="amazon.nova.foo", - region_name="us-west-2") # type: ignore[call-arg] + ChatBedrock(model="amazon.nova.foo", region_name="us-west-2") # type: ignore[call-arg] with pytest.raises(ValidationError): - ChatBedrock(model="foobar", - base_model="amazon.nova.foo", - region_name="us-west-2") # type: ignore[call-arg] + ChatBedrock( + model="foobar", base_model="amazon.nova.foo", region_name="us-west-2" + ) # type: ignore[call-arg] try: - ChatBedrock(model="foobar", - base_model="anthropic.claude-3-7", - region_name="us-west-2") # type: ignore[call-arg] + ChatBedrock( + model="foobar", base_model="anthropic.claude-3-7", region_name="us-west-2" + ) # type: ignore[call-arg] - ChatBedrock(model="anthropic.claude-3-7", - region_name="us-west-2") # type: ignore[call-arg] + ChatBedrock(model="anthropic.claude-3-7", region_name="us-west-2") # type: ignore[call-arg] except Exception as e: pytest.fail(e) @mock.patch("langchain_aws.chat_models.bedrock.create_aws_client") -def test_check_unsupported_model_with_inference_profile_valid_model(mock_create_aws_client): +def test_check_unsupported_model_with_inference_profile_valid_model( + mock_create_aws_client, +): mock_bedrock_client = mock.MagicMock() mock_bedrock_client.get_inference_profile.return_value = { "models": [ { - "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-7-sonnet-20250219-v1:0" # noqa: E501 + "modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-7-sonnet-20250219-v1:0" # noqa: E501 } ] } mock_create_aws_client.return_value = mock_bedrock_client aip_model_id = "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/my-profile" # noqa: E501 - chat = ChatBedrock( - model_id=aip_model_id, - region_name="us-west-2", + ChatBedrock( + model=aip_model_id, + region="us-west-2", bedrock_client=mock_bedrock_client, ) # type: ignore[call-arg] @@ -688,7 +687,9 @@ def test_check_unsupported_model_with_inference_profile_valid_model(mock_create_ @mock.patch("langchain_aws.chat_models.bedrock.create_aws_client") -def test_check_unsupported_model_with_inference_profile_invalid_model(mock_create_aws_client): +def test_check_unsupported_model_with_inference_profile_invalid_model( + mock_create_aws_client, +): mock_bedrock_client = mock.MagicMock() mock_bedrock_client.get_inference_profile.return_value = { "models": [ @@ -700,16 +701,6 @@ def test_check_unsupported_model_with_inference_profile_invalid_model(mock_creat mock_create_aws_client.return_value = mock_bedrock_client aip_model_id = "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/my-profile" # noqa: E501 - chat = ChatBedrock( - model_id=aip_model_id, - region_name="us-west-2", - bedrock_client=mock_bedrock_client, - ) # type: ignore[call-arg] - - mock_bedrock_client.get_inference_profile.assert_called_with( - inferenceProfileIdentifier=aip_model_id - ) - with pytest.raises(ValidationError): ChatBedrock( model=aip_model_id, @@ -717,6 +708,10 @@ def test_check_unsupported_model_with_inference_profile_invalid_model(mock_creat bedrock_client=mock_bedrock_client, ) # type: ignore[call-arg] + mock_bedrock_client.get_inference_profile.assert_called_with( + inferenceProfileIdentifier=aip_model_id + ) + @pytest.mark.parametrize( "model_id, provider, expected_provider, expectation, region_name", @@ -783,7 +778,7 @@ def test_check_unsupported_model_with_inference_profile_invalid_model(mock_creat def test__get_provider( model_id, provider, expected_provider, expectation, region_name ) -> None: - llm = ChatBedrock(model_id=model_id, provider=provider, region_name=region_name) + llm = ChatBedrock(model=model_id, provider=provider, region=region_name) with expectation: assert llm._get_provider() == expected_provider @@ -791,9 +786,7 @@ def test__get_provider( @mock.patch.dict(os.environ, {"AWS_REGION": "us-west-1"}) def test_chat_bedrock_different_regions() -> None: region = "ap-south-2" - llm = ChatBedrock( - model_id="anthropic.claude-3-sonnet-20240229-v1:0", region_name=region - ) + llm = ChatBedrock(model="anthropic.claude-3-sonnet-20240229-v1:0", region=region) assert llm.region_name == region @@ -1188,10 +1181,10 @@ def test_chat_prompt_adapter_with_model_detection( ] chat = ChatBedrock( - model_id=model_id, - base_model_id=base_model_id, + model=model_id, + base_model=base_model_id, provider=provider, - region_name="us-west-2", + region="us-west-2", ) model_name = chat._get_base_model() @@ -1320,8 +1313,8 @@ def test__format_anthropic_messages_mixed_type_blocks_and_empty_content() -> Non def test_model_kwargs() -> None: """Test we can transfer unknown params to model_kwargs.""" llm = ChatBedrock( - model_id="my-model", - region_name="us-west-2", + model="my-model", + region="us-west-2", model_kwargs={"foo": "bar"}, ) assert llm.model_id == "my-model" @@ -1330,9 +1323,9 @@ def test_model_kwargs() -> None: with pytest.warns(match="transferred to model_kwargs"): llm = ChatBedrock( - model_id="my-model", - region_name="us-west-2", - foo="bar", + model="my-model", + region="us-west-2", + foo="bar", # type: ignore[call-arg] ) assert llm.model_id == "my-model" assert llm.region_name == "us-west-2" @@ -1340,9 +1333,9 @@ def test_model_kwargs() -> None: with pytest.warns(match="transferred to model_kwargs"): llm = ChatBedrock( - model_id="my-model", - region_name="us-west-2", - foo="bar", + model="my-model", + region="us-west-2", + foo="bar", # type: ignore[call-arg] model_kwargs={"baz": "qux"}, ) assert llm.model_id == "my-model" @@ -1352,8 +1345,8 @@ def test_model_kwargs() -> None: # For backward compatibility, test that we don't transfer known parameters out # of model_kwargs llm = ChatBedrock( - model_id="my-model", - region_name="us-west-2", + model="my-model", + region="us-west-2", model_kwargs={"stop_sequences": ["test"]}, ) assert llm.model_kwargs == {"stop_sequences": ["test"]}