diff --git a/docs/my-website/docs/proxy/guardrails/bedrock.md b/docs/my-website/docs/proxy/guardrails/bedrock.md index 4a1a0a246f82..8c71508fd235 100644 --- a/docs/my-website/docs/proxy/guardrails/bedrock.md +++ b/docs/my-website/docs/proxy/guardrails/bedrock.md @@ -188,6 +188,28 @@ My email is [EMAIL] and my phone number is [PHONE_NUMBER] This helps protect sensitive information while still allowing the model to understand the context of the request. +## Experimental: Only Send Latest User Message + +When you're chaining long conversations through Bedrock guardrails, you can opt into a lighter, experimental behavior by setting `experimental_use_latest_role_message_only: true` in the guardrail's `litellm_params`. When enabled, LiteLLM only sends the most recent `user` message (or assistant output during post-call checks) to Bedrock, which: + +- prevents unintended blocks on older system/dev messages +- keeps Bedrock payloads smaller, reducing latency and cost +- applies to proxy hooks (`pre_call`, `during_call`) and the `/guardrails/apply_guardrail` testing endpoint + +```yaml showLineNumbers title="litellm proxy config.yaml" +guardrails: + - guardrail_name: "bedrock-pre-guard" + litellm_params: + guardrail: bedrock + mode: "pre_call" + guardrailIdentifier: wf0hkdb5x07f + guardrailVersion: "DRAFT" + aws_region_name: os.environ/AWS_REGION + experimental_use_latest_role_message_only: true # NEW +``` + +> ⚠️ This flag is currently experimental and defaults to `false` to preserve the legacy behavior (entire message history). We'll be listening to user feedback to decide if this becomes the default or rolls out more broadly. + ## Disabling Exceptions on Bedrock BLOCK By default, when Bedrock guardrails block content, LiteLLM raises an HTTP 400 exception. However, you can disable this behavior by setting `disable_exception_on_block: true`. 
This is particularly useful when integrating with **OpenWebUI**, where exceptions can interrupt the chat flow and break the user experience. diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index e8460528de52..1a67da632041 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -14,7 +14,17 @@ ) # Adds the parent directory to the system path import json import sys -from typing import Any, AsyncGenerator, List, Literal, Optional, Tuple, Union, cast +from typing import ( + Any, + AsyncGenerator, + List, + Literal, + NamedTuple, + Optional, + Tuple, + Union, + cast, +) import httpx from fastapi import HTTPException @@ -54,6 +64,12 @@ GUARDRAIL_NAME = "bedrock" +class GuardrailMessageFilterResult(NamedTuple): + payload_messages: Optional[List[AllMessageValues]] + original_messages: Optional[List[AllMessageValues]] + target_indices: Optional[List[int]] + + def _redact_pii_matches(response_json: dict) -> dict: try: # Create a deep copy to avoid modifying the original response @@ -113,6 +129,9 @@ def __init__( self.guardrailIdentifier = guardrailIdentifier self.guardrailVersion = guardrailVersion self.guardrail_provider = "bedrock" + self.experimental_use_latest_role_message_only = bool( + kwargs.get("experimental_use_latest_role_message_only") + ) # store kwargs as optional_params self.optional_params = kwargs @@ -213,6 +232,65 @@ def convert_to_bedrock_format( ) return bedrock_request + def _prepare_guardrail_messages_for_role( + self, + messages: Optional[List[AllMessageValues]], + ) -> GuardrailMessageFilterResult: + """Return payload + merge metadata for the latest user message.""" + # NOTE: This logic probably belongs in CustomGuardrail once other guardrails adopt the feature. 
+ + if messages is None: + return GuardrailMessageFilterResult(None, None, None) + + if self.experimental_use_latest_role_message_only is not True: + return GuardrailMessageFilterResult(messages, None, None) + + latest_index = self._find_latest_message_index(messages, target_role="user") + if latest_index is None: + return GuardrailMessageFilterResult(None, None, None) + + original_messages = list(messages) + payload_messages = [messages[latest_index]] + return GuardrailMessageFilterResult( + payload_messages=payload_messages, + original_messages=original_messages, + target_indices=[latest_index], + ) + + def _find_latest_message_index( + self, messages: List[AllMessageValues], target_role: str + ) -> Optional[int]: + for index in range(len(messages) - 1, -1, -1): + if messages[index].get("role", None) == target_role: + return index + return None + + def _merge_filtered_messages( + self, + original_messages: Optional[List[AllMessageValues]], + updated_target_messages: List[AllMessageValues], + target_indices: Optional[List[int]], + ) -> List[AllMessageValues]: + if not target_indices: + return updated_target_messages + + if not original_messages: + original_messages = [] + + merged_messages = list(original_messages) + if not merged_messages: + merged_messages = list(updated_target_messages) + for replacement_index, updated_message in zip( + target_indices, updated_target_messages + ): + if replacement_index < len(merged_messages): + merged_messages[replacement_index] = updated_message + + return merged_messages + + # NOTE: Consider moving these helpers to CustomGuardrail when the filtering + # logic becomes shared across providers. 
+ #### CALL HOOKS - proxy only #### def _load_credentials( self, @@ -635,15 +713,24 @@ async def async_pre_call_hook( ) return data + filter_result = self._prepare_guardrail_messages_for_role(messages=new_messages) + + filtered_messages = filter_result.payload_messages + if not filtered_messages: + verbose_proxy_logger.debug( + "No user-role messages available for guardrail payload" + ) + return data + ######################################################### ########## 1. Make the Bedrock API request ########## ######################################################### - bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = ( - None - ) + bedrock_guardrail_response: Optional[ + Union[BedrockGuardrailResponse, str] + ] = None try: bedrock_guardrail_response = await self.make_bedrock_api_request( - source="INPUT", messages=new_messages, request_data=data + source="INPUT", messages=filtered_messages, request_data=data ) except GuardrailInterventionNormalStringError as e: bedrock_guardrail_response = e.message @@ -652,11 +739,14 @@ async def async_pre_call_hook( ######################################################### ########## 2. 
Update the messages with the guardrail response ########## ######################################################### - data["messages"] = ( - self._update_messages_with_updated_bedrock_guardrail_response( - messages=new_messages, - bedrock_guardrail_response=bedrock_guardrail_response, - ) + updated_subset = self._update_messages_with_updated_bedrock_guardrail_response( + messages=filtered_messages, + bedrock_guardrail_response=bedrock_guardrail_response, + ) + data["messages"] = self._merge_filtered_messages( + original_messages=filter_result.original_messages or new_messages, + updated_target_messages=updated_subset, + target_indices=filter_result.target_indices, ) if isinstance(bedrock_guardrail_response, str): data["mock_response"] = self.create_guardrail_blocked_response( @@ -696,15 +786,23 @@ async def async_moderation_hook( ) return + filter_result = self._prepare_guardrail_messages_for_role(messages=new_messages) + filtered_messages = filter_result.payload_messages + if not filtered_messages: + verbose_proxy_logger.debug( + "Bedrock AI: not running guardrail. No user-role messages" + ) + return + ######################################################### ########## 1. Make the Bedrock API request ########## ######################################################### - bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = ( - None - ) + bedrock_guardrail_response: Optional[ + Union[BedrockGuardrailResponse, str] + ] = None try: bedrock_guardrail_response = await self.make_bedrock_api_request( - source="INPUT", messages=new_messages, request_data=data + source="INPUT", messages=filtered_messages, request_data=data ) except GuardrailInterventionNormalStringError as e: bedrock_guardrail_response = e.message @@ -713,11 +811,14 @@ async def async_moderation_hook( ######################################################### ########## 2. 
Update the messages with the guardrail response ########## ######################################################### - data["messages"] = ( - self._update_messages_with_updated_bedrock_guardrail_response( - messages=new_messages, - bedrock_guardrail_response=bedrock_guardrail_response, - ) + updated_subset = self._update_messages_with_updated_bedrock_guardrail_response( + messages=filtered_messages, + bedrock_guardrail_response=bedrock_guardrail_response, + ) + data["messages"] = self._merge_filtered_messages( + original_messages=filter_result.original_messages or new_messages, + updated_target_messages=updated_subset, + target_indices=filter_result.target_indices, ) if isinstance(bedrock_guardrail_response, str): data["mock_response"] = self.create_guardrail_blocked_response( @@ -887,9 +988,15 @@ async def async_post_call_streaming_iterator_hook( # Bedrock will raise an exception if this violates the guardrail policy ################################################################### # Create tasks for parallel execution + input_filter = self._prepare_guardrail_messages_for_role( + messages=request_data.get("messages") + ) + input_messages = input_filter.payload_messages or request_data.get( + "messages" + ) input_task = self.make_bedrock_api_request( source="INPUT", - messages=request_data.get("messages"), + messages=input_messages, request_data=request_data, ) # Only input messages output_guardrail_response: Optional[ @@ -1161,9 +1268,15 @@ async def apply_guardrail( if request_data is None: request_data = {"messages": mock_messages} + request_messages = request_data.get("messages") or mock_messages + filter_result = self._prepare_guardrail_messages_for_role( + messages=request_messages + ) + filtered_messages = filter_result.payload_messages or mock_messages + bedrock_response = await self.make_bedrock_api_request( source="INPUT", - messages=mock_messages, + messages=filtered_messages, request_data=request_data, ) diff --git 
a/litellm/proxy/guardrails/guardrail_initializers.py b/litellm/proxy/guardrails/guardrail_initializers.py index 9bb965ef14e5..ea2434f5e724 100644 --- a/litellm/proxy/guardrails/guardrail_initializers.py +++ b/litellm/proxy/guardrails/guardrail_initializers.py @@ -30,6 +30,7 @@ def initialize_bedrock(litellm_params: LitellmParams, guardrail: Guardrail): aws_web_identity_token=litellm_params.aws_web_identity_token, aws_sts_endpoint=litellm_params.aws_sts_endpoint, aws_bedrock_runtime_endpoint=litellm_params.aws_bedrock_runtime_endpoint, + experimental_use_latest_role_message_only=litellm_params.experimental_use_latest_role_message_only, ) litellm.logging_callback_manager.add_litellm_callback(_bedrock_callback) return _bedrock_callback diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 24a235def597..c612ff667d94 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -522,6 +522,11 @@ class BaseLitellmParams(BaseModel): # works for new and patch update guardrails default=None, description="Base URL for the guardrail service API" ) + experimental_use_latest_role_message_only: Optional[bool] = Field( + default=False, + description="When True, guardrails only receive the latest message for the relevant role (e.g., newest user input pre-call, newest assistant output post-call)", + ) + # Lakera specific params category_thresholds: Optional[LakeraCategoryThresholds] = Field( default=None, diff --git a/tests/enterprise/litellm_enterprise/proxy/guardrails/test_bedrock_apply_guardrail.py b/tests/enterprise/litellm_enterprise/proxy/guardrails/test_bedrock_apply_guardrail.py index 2cf679966246..e65d01e41e45 100644 --- a/tests/enterprise/litellm_enterprise/proxy/guardrails/test_bedrock_apply_guardrail.py +++ b/tests/enterprise/litellm_enterprise/proxy/guardrails/test_bedrock_apply_guardrail.py @@ -188,3 +188,95 @@ async def test_bedrock_apply_guardrail_endpoint_integration(): assert isinstance(response, ApplyGuardrailResponse) 
@pytest.mark.asyncio
async def test_bedrock_apply_guardrail_filters_request_messages_when_flag_enabled():
    """With the experimental flag on, only the newest user message is sent to Bedrock."""
    conversation = [
        {"role": "system", "content": "rules"},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "response"},
        {"role": "user", "content": "latest question"},
    ]
    bedrock_guardrail = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    with patch.object(
        bedrock_guardrail, "make_bedrock_api_request", new_callable=AsyncMock
    ) as bedrock_call:
        bedrock_call.return_value = {"action": "ALLOWED"}

        outcome = await bedrock_guardrail.apply_guardrail(
            text="latest question",
            request_data={"messages": conversation},
        )

        assert bedrock_call.called
        sent_kwargs = bedrock_call.call_args[1]
        assert sent_kwargs["messages"] == [conversation[-1]]
        assert outcome == "latest question"


@pytest.mark.asyncio
async def test_bedrock_apply_guardrail_filters_request_messages_when_flag_enabled_blocked():
    """A BLOCKED verdict raises, and only the newest user message was sent."""
    conversation = [
        {"role": "user", "content": "first"},
        {"role": "user", "content": "blocked"},
    ]
    bedrock_guardrail = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    with patch.object(
        bedrock_guardrail, "make_bedrock_api_request", new_callable=AsyncMock
    ) as bedrock_call:
        bedrock_call.return_value = {"action": "BLOCKED", "reason": "policy"}

        with pytest.raises(Exception, match="policy") as raised:
            await bedrock_guardrail.apply_guardrail(
                text="blocked",
                request_data={"messages": conversation},
            )

        assert bedrock_call.called
        sent_kwargs = bedrock_call.call_args[1]
        assert sent_kwargs["messages"] == [conversation[-1]]
        assert "Bedrock guardrail failed" in str(raised.value)


def test_bedrock_guardrail_filters_latest_user_message_when_enabled():
    """The filter keeps only the newest user message and merges a masked copy back."""
    conversation = [
        {"role": "system", "content": "rules"},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "response"},
        {"role": "user", "content": "latest question"},
    ]
    bedrock_guardrail = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    narrowed = bedrock_guardrail._prepare_guardrail_messages_for_role(
        messages=conversation
    )

    payload = narrowed.payload_messages
    assert payload is not None
    assert len(payload) == 1
    assert payload[0]["content"] == "latest question"
    assert narrowed.target_indices == [3]

    merged = bedrock_guardrail._merge_filtered_messages(
        original_messages=narrowed.original_messages,
        updated_target_messages=[{"role": "user", "content": "[MASKED]"}],
        target_indices=narrowed.target_indices,
    )
    assert merged[3]["content"] == "[MASKED]"