Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/my-website/docs/proxy/guardrails/bedrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,28 @@ My email is [EMAIL] and my phone number is [PHONE_NUMBER]

This helps protect sensitive information while still allowing the model to understand the context of the request.

## Experimental: Only Send Latest User Message

When you're chaining long conversations through Bedrock guardrails, you can opt into a lighter, experimental behavior by setting `experimental_use_latest_role_message_only: true` in the guardrail's `litellm_params`. When enabled, LiteLLM only sends the most recent `user` message (or assistant output during post-call checks) to Bedrock, which:

- prevents unintended blocks triggered by older system/developer messages earlier in the conversation
- keeps Bedrock payloads smaller, reducing latency and cost
- applies to proxy hooks (`pre_call`, `during_call`) and the `/guardrails/apply_guardrail` testing endpoint

```yaml showLineNumbers title="litellm proxy config.yaml"
guardrails:
  - guardrail_name: "bedrock-pre-guard"
    litellm_params:
      guardrail: bedrock
      mode: "pre_call"
      guardrailIdentifier: wf0hkdb5x07f
      guardrailVersion: "DRAFT"
      aws_region_name: os.environ/AWS_REGION
      experimental_use_latest_role_message_only: true # NEW
```

> ⚠️ This flag is currently experimental and defaults to `false` to preserve the legacy behavior (entire message history). We'll be listening to user feedback to decide if this becomes the default or rolls out more broadly.

## Disabling Exceptions on Bedrock BLOCK

By default, when Bedrock guardrails block content, LiteLLM raises an HTTP 400 exception. However, you can disable this behavior by setting `disable_exception_on_block: true`. This is particularly useful when integrating with **OpenWebUI**, where exceptions can interrupt the chat flow and break the user experience.
Expand Down
155 changes: 134 additions & 21 deletions litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
) # Adds the parent directory to the system path
import json
import sys
from typing import Any, AsyncGenerator, List, Literal, Optional, Tuple, Union, cast
from typing import (
Any,
AsyncGenerator,
List,
Literal,
NamedTuple,
Optional,
Tuple,
Union,
cast,
)

import httpx
from fastapi import HTTPException
Expand Down Expand Up @@ -54,6 +64,12 @@
GUARDRAIL_NAME = "bedrock"


class GuardrailMessageFilterResult(NamedTuple):
    """Result of narrowing a conversation down to the guardrail payload.

    Bundles the messages that should be sent to the guardrail together with
    the bookkeeping needed to splice the guardrail's (possibly rewritten)
    messages back into the full conversation afterwards.
    """

    # Messages to send to the guardrail; ``None`` when there is nothing to send.
    payload_messages: Optional[List[AllMessageValues]]
    # Shallow copy of the full conversation, kept for merging; ``None`` when
    # no filtering was applied.
    original_messages: Optional[List[AllMessageValues]]
    # Positions in ``original_messages`` that the payload entries were taken
    # from; ``None`` when no filtering was applied.
    target_indices: Optional[List[int]]


def _redact_pii_matches(response_json: dict) -> dict:
try:
# Create a deep copy to avoid modifying the original response
Expand Down Expand Up @@ -113,6 +129,9 @@ def __init__(
self.guardrailIdentifier = guardrailIdentifier
self.guardrailVersion = guardrailVersion
self.guardrail_provider = "bedrock"
self.experimental_use_latest_role_message_only = bool(
kwargs.get("experimental_use_latest_role_message_only")
)

# store kwargs as optional_params
self.optional_params = kwargs
Expand Down Expand Up @@ -213,6 +232,65 @@ def convert_to_bedrock_format(
)
return bedrock_request

def _prepare_guardrail_messages_for_role(
    self,
    messages: Optional[List[AllMessageValues]],
    target_role: str = "user",
) -> GuardrailMessageFilterResult:
    """Return the guardrail payload plus merge metadata for ``target_role``.

    Args:
        messages: Full conversation to inspect, or ``None``.
        target_role: Role whose most recent message should be sent to the
            guardrail when the experimental flag is enabled. Defaults to
            ``"user"``, which matches the previous hard-coded behavior.

    Returns:
        A ``GuardrailMessageFilterResult`` where:

        * all fields are ``None`` when ``messages`` is ``None`` or when the
          flag is enabled but no message with ``target_role`` exists;
        * ``payload_messages`` is the untouched ``messages`` list and the
          other fields are ``None`` when the experimental flag is disabled;
        * otherwise ``payload_messages`` holds only the latest
          ``target_role`` message, ``original_messages`` is a shallow copy
          of the conversation and ``target_indices`` records where the
          payload message sits in it.
    """
    # NOTE: This logic probably belongs in CustomGuardrail once other guardrails adopt the feature.

    if messages is None:
        return GuardrailMessageFilterResult(None, None, None)

    # Flag off: legacy behavior, the whole history is the payload and there
    # is nothing to merge back afterwards.
    if self.experimental_use_latest_role_message_only is not True:
        return GuardrailMessageFilterResult(messages, None, None)

    latest_index = self._find_latest_message_index(
        messages, target_role=target_role
    )
    if latest_index is None:
        return GuardrailMessageFilterResult(None, None, None)

    return GuardrailMessageFilterResult(
        payload_messages=[messages[latest_index]],
        # Copy so later merging never mutates the caller's list in place.
        original_messages=list(messages),
        target_indices=[latest_index],
    )

def _find_latest_message_index(
self, messages: List[AllMessageValues], target_role: str
) -> Optional[int]:
for index in range(len(messages) - 1, -1, -1):
if messages[index].get("role", None) == target_role:
return index
return None

def _merge_filtered_messages(
self,
original_messages: Optional[List[AllMessageValues]],
updated_target_messages: List[AllMessageValues],
target_indices: Optional[List[int]],
) -> List[AllMessageValues]:
if not target_indices:
return updated_target_messages

if not original_messages:
original_messages = []

merged_messages = list(original_messages)
if not merged_messages:
merged_messages = list(updated_target_messages)
for replacement_index, updated_message in zip(
target_indices, updated_target_messages
):
if replacement_index < len(merged_messages):
merged_messages[replacement_index] = updated_message

return merged_messages

# NOTE: Consider moving these helpers to CustomGuardrail when the filtering
# logic becomes shared across providers.

#### CALL HOOKS - proxy only ####
def _load_credentials(
self,
Expand Down Expand Up @@ -635,15 +713,24 @@ async def async_pre_call_hook(
)
return data

filter_result = self._prepare_guardrail_messages_for_role(messages=new_messages)

filtered_messages = filter_result.payload_messages
if not filtered_messages:
verbose_proxy_logger.debug(
"No user-role messages available for guardrail payload"
)
return data

#########################################################
########## 1. Make the Bedrock API request ##########
#########################################################
bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = (
None
)
bedrock_guardrail_response: Optional[
Union[BedrockGuardrailResponse, str]
] = None
try:
bedrock_guardrail_response = await self.make_bedrock_api_request(
source="INPUT", messages=new_messages, request_data=data
source="INPUT", messages=filtered_messages, request_data=data
)
except GuardrailInterventionNormalStringError as e:
bedrock_guardrail_response = e.message
Expand All @@ -652,11 +739,14 @@ async def async_pre_call_hook(
#########################################################
########## 2. Update the messages with the guardrail response ##########
#########################################################
data["messages"] = (
self._update_messages_with_updated_bedrock_guardrail_response(
messages=new_messages,
bedrock_guardrail_response=bedrock_guardrail_response,
)
updated_subset = self._update_messages_with_updated_bedrock_guardrail_response(
messages=filtered_messages,
bedrock_guardrail_response=bedrock_guardrail_response,
)
data["messages"] = self._merge_filtered_messages(
original_messages=filter_result.original_messages or new_messages,
updated_target_messages=updated_subset,
target_indices=filter_result.target_indices,
)
if isinstance(bedrock_guardrail_response, str):
data["mock_response"] = self.create_guardrail_blocked_response(
Expand Down Expand Up @@ -696,15 +786,23 @@ async def async_moderation_hook(
)
return

filter_result = self._prepare_guardrail_messages_for_role(messages=new_messages)
filtered_messages = filter_result.payload_messages
if not filtered_messages:
verbose_proxy_logger.debug(
"Bedrock AI: not running guardrail. No user-role messages"
)
return

#########################################################
########## 1. Make the Bedrock API request ##########
#########################################################
bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = (
None
)
bedrock_guardrail_response: Optional[
Union[BedrockGuardrailResponse, str]
] = None
try:
bedrock_guardrail_response = await self.make_bedrock_api_request(
source="INPUT", messages=new_messages, request_data=data
source="INPUT", messages=filtered_messages, request_data=data
)
except GuardrailInterventionNormalStringError as e:
bedrock_guardrail_response = e.message
Expand All @@ -713,11 +811,14 @@ async def async_moderation_hook(
#########################################################
########## 2. Update the messages with the guardrail response ##########
#########################################################
data["messages"] = (
self._update_messages_with_updated_bedrock_guardrail_response(
messages=new_messages,
bedrock_guardrail_response=bedrock_guardrail_response,
)
updated_subset = self._update_messages_with_updated_bedrock_guardrail_response(
messages=filtered_messages,
bedrock_guardrail_response=bedrock_guardrail_response,
)
data["messages"] = self._merge_filtered_messages(
original_messages=filter_result.original_messages or new_messages,
updated_target_messages=updated_subset,
target_indices=filter_result.target_indices,
)
if isinstance(bedrock_guardrail_response, str):
data["mock_response"] = self.create_guardrail_blocked_response(
Expand Down Expand Up @@ -887,9 +988,15 @@ async def async_post_call_streaming_iterator_hook(
# Bedrock will raise an exception if this violates the guardrail policy
###################################################################
# Create tasks for parallel execution
input_filter = self._prepare_guardrail_messages_for_role(
messages=request_data.get("messages")
)
input_messages = input_filter.payload_messages or request_data.get(
"messages"
)
input_task = self.make_bedrock_api_request(
source="INPUT",
messages=request_data.get("messages"),
messages=input_messages,
request_data=request_data,
) # Only input messages
output_guardrail_response: Optional[
Expand Down Expand Up @@ -1161,9 +1268,15 @@ async def apply_guardrail(
if request_data is None:
request_data = {"messages": mock_messages}

request_messages = request_data.get("messages") or mock_messages
filter_result = self._prepare_guardrail_messages_for_role(
messages=request_messages
)
filtered_messages = filter_result.payload_messages or mock_messages

bedrock_response = await self.make_bedrock_api_request(
source="INPUT",
messages=mock_messages,
messages=filtered_messages,
request_data=request_data,
)

Expand Down
1 change: 1 addition & 0 deletions litellm/proxy/guardrails/guardrail_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def initialize_bedrock(litellm_params: LitellmParams, guardrail: Guardrail):
aws_web_identity_token=litellm_params.aws_web_identity_token,
aws_sts_endpoint=litellm_params.aws_sts_endpoint,
aws_bedrock_runtime_endpoint=litellm_params.aws_bedrock_runtime_endpoint,
experimental_use_latest_role_message_only=litellm_params.experimental_use_latest_role_message_only,
)
litellm.logging_callback_manager.add_litellm_callback(_bedrock_callback)
return _bedrock_callback
Expand Down
5 changes: 5 additions & 0 deletions litellm/types/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,11 @@ class BaseLitellmParams(BaseModel): # works for new and patch update guardrails
default=None, description="Base URL for the guardrail service API"
)

experimental_use_latest_role_message_only: Optional[bool] = Field(
default=False,
description="When True, guardrails only receive the latest message for the relevant role (e.g., newest user input pre-call, newest assistant output post-call)",
)

# Lakera specific params
category_thresholds: Optional[LakeraCategoryThresholds] = Field(
default=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,95 @@ async def test_bedrock_apply_guardrail_endpoint_integration():
assert isinstance(response, ApplyGuardrailResponse)
assert response.response_text == "This is a test message with processed content"
mock_api_request.assert_called_once()


@pytest.mark.asyncio
async def test_bedrock_apply_guardrail_filters_request_messages_when_flag_enabled():
    """With the experimental flag on, only the newest user message is sent."""
    conversation = [
        {"role": "system", "content": "rules"},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "response"},
        {"role": "user", "content": "latest question"},
    ]
    guardrail = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    with patch.object(
        guardrail, "make_bedrock_api_request", new_callable=AsyncMock
    ) as mock_api:
        mock_api.return_value = {"action": "ALLOWED"}
        result = await guardrail.apply_guardrail(
            text="latest question",
            request_data={"messages": conversation},
        )

    assert mock_api.called
    # Only the trailing user message should have been forwarded to Bedrock.
    sent_messages = mock_api.call_args.kwargs["messages"]
    assert sent_messages == [conversation[-1]]
    assert result == "latest question"


@pytest.mark.asyncio
async def test_bedrock_apply_guardrail_filters_request_messages_when_flag_enabled_blocked():
    """A BLOCKED response still raises, and only the newest user message is sent."""
    conversation = [
        {"role": "user", "content": "first"},
        {"role": "user", "content": "blocked"},
    ]
    guardrail = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    with patch.object(
        guardrail, "make_bedrock_api_request", new_callable=AsyncMock
    ) as mock_api:
        mock_api.return_value = {"action": "BLOCKED", "reason": "policy"}

        with pytest.raises(Exception, match="policy") as exc_info:
            await guardrail.apply_guardrail(
                text="blocked",
                request_data={"messages": conversation},
            )

    assert mock_api.called
    # Only the trailing user message should have been forwarded to Bedrock.
    sent_messages = mock_api.call_args.kwargs["messages"]
    assert sent_messages == [conversation[-1]]
    assert "Bedrock guardrail failed" in str(exc_info.value)

def test_bedrock_guardrail_filters_latest_user_message_when_enabled():
    """Filtering picks the newest user message and merging restores full history.

    Besides checking the replaced slot, this also verifies that every other
    message survives the merge untouched and that the caller's list is not
    mutated, which is the whole point of keeping ``original_messages``.
    """
    guardrail = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    messages = [
        {"role": "system", "content": "rules"},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "response"},
        {"role": "user", "content": "latest question"},
    ]

    filter_result = guardrail._prepare_guardrail_messages_for_role(messages=messages)

    assert filter_result.payload_messages is not None
    assert len(filter_result.payload_messages) == 1
    assert filter_result.payload_messages[0]["content"] == "latest question"
    assert filter_result.target_indices == [3]

    masked_messages = guardrail._merge_filtered_messages(
        original_messages=filter_result.original_messages,
        updated_target_messages=[{"role": "user", "content": "[MASKED]"}],
        target_indices=filter_result.target_indices,
    )
    assert masked_messages[3]["content"] == "[MASKED]"
    # The rest of the conversation must come back unchanged and in order.
    assert len(masked_messages) == len(messages)
    assert masked_messages[:3] == messages[:3]
    # Merging works on a copy, so the caller's list keeps the original text.
    assert messages[3]["content"] == "latest question"
Loading