Skip to content

Commit da5b81c

Browse files
authored
feat: add experimental latest-user filtering for Bedrock (#17282)
* feat: add experimental latest-user filtering for Bedrock * doc: add experimental bedrock latest-message flag
1 parent 860270a commit da5b81c

File tree

5 files changed

+254
-21
lines changed

5 files changed

+254
-21
lines changed

docs/my-website/docs/proxy/guardrails/bedrock.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,28 @@ My email is [EMAIL] and my phone number is [PHONE_NUMBER]
188188

189189
This helps protect sensitive information while still allowing the model to understand the context of the request.
190190

191+
## Experimental: Only Send Latest User Message
192+
193+
When you're chaining long conversations through Bedrock guardrails, you can opt into a lighter, experimental behavior by setting `experimental_use_latest_role_message_only: true` in the guardrail's `litellm_params`. When enabled, LiteLLM sends only the most recent `user` message (or the assistant output during post-call checks) to Bedrock instead of the entire message history, which:
194+
195+
- prevents unintended blocks on older system/dev messages
196+
- keeps Bedrock payloads smaller, reducing latency and cost
197+
- applies to proxy hooks (`pre_call`, `during_call`) and the `/guardrails/apply_guardrail` testing endpoint
198+
199+
```yaml showLineNumbers title="litellm proxy config.yaml"
200+
guardrails:
201+
- guardrail_name: "bedrock-pre-guard"
202+
litellm_params:
203+
guardrail: bedrock
204+
mode: "pre_call"
205+
guardrailIdentifier: wf0hkdb5x07f
206+
guardrailVersion: "DRAFT"
207+
aws_region_name: os.environ/AWS_REGION
208+
experimental_use_latest_role_message_only: true # NEW
209+
```
210+
211+
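> ⚠️ This flag is experimental and defaults to `false`, which preserves the legacy behavior of sending the entire message history. When it is enabled and a request contains no `user` message, the `pre_call` and `during_call` hooks skip the Bedrock request entirely. We'll use user feedback to decide whether this becomes the default or rolls out more broadly.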
212+
191213
## Disabling Exceptions on Bedrock BLOCK
192214

193215
By default, when Bedrock guardrails block content, LiteLLM raises an HTTP 400 exception. However, you can disable this behavior by setting `disable_exception_on_block: true`. This is particularly useful when integrating with **OpenWebUI**, where exceptions can interrupt the chat flow and break the user experience.

litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py

Lines changed: 134 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414
) # Adds the parent directory to the system path
1515
import json
1616
import sys
17-
from typing import Any, AsyncGenerator, List, Literal, Optional, Tuple, Union, cast
17+
from typing import (
18+
Any,
19+
AsyncGenerator,
20+
List,
21+
Literal,
22+
NamedTuple,
23+
Optional,
24+
Tuple,
25+
Union,
26+
cast,
27+
)
1828

1929
import httpx
2030
from fastapi import HTTPException
@@ -54,6 +64,12 @@
5464
GUARDRAIL_NAME = "bedrock"
5565

5666

67+
class GuardrailMessageFilterResult(NamedTuple):
    """Result of narrowing a conversation down to the guardrail payload.

    Bundles the messages that should be sent to the guardrail with the
    bookkeeping needed to splice the guardrail's (possibly rewritten)
    messages back into the full conversation afterwards.
    """

    # Messages to send to the guardrail; None when there is nothing to send.
    payload_messages: Optional[List[AllMessageValues]]
    # Shallow copy of the full conversation; None when no filtering happened.
    original_messages: Optional[List[AllMessageValues]]
    # Positions in ``original_messages`` the payload messages were taken from.
    target_indices: Optional[List[int]]
71+
72+
5773
def _redact_pii_matches(response_json: dict) -> dict:
5874
try:
5975
# Create a deep copy to avoid modifying the original response
@@ -113,6 +129,9 @@ def __init__(
113129
self.guardrailIdentifier = guardrailIdentifier
114130
self.guardrailVersion = guardrailVersion
115131
self.guardrail_provider = "bedrock"
132+
self.experimental_use_latest_role_message_only = bool(
133+
kwargs.get("experimental_use_latest_role_message_only")
134+
)
116135

117136
# store kwargs as optional_params
118137
self.optional_params = kwargs
@@ -213,6 +232,65 @@ def convert_to_bedrock_format(
213232
)
214233
return bedrock_request
215234

235+
def _prepare_guardrail_messages_for_role(
    self,
    messages: Optional[List[AllMessageValues]],
) -> GuardrailMessageFilterResult:
    """Select the messages to send to the guardrail, plus merge metadata.

    Args:
        messages: The conversation to inspect, or None.

    Returns:
        A GuardrailMessageFilterResult where:
          * every field is None when ``messages`` is None, or when the
            experimental flag is on but no ``user`` message exists;
          * ``payload_messages`` is ``messages`` unchanged (no merge
            metadata) when the experimental flag is off;
          * otherwise ``payload_messages`` holds only the latest ``user``
            message, ``original_messages`` is a shallow copy of the whole
            conversation and ``target_indices`` is that message's index.
    """
    # NOTE: This logic probably belongs in CustomGuardrail once other
    # guardrails adopt the feature.
    nothing_to_send = GuardrailMessageFilterResult(None, None, None)

    if messages is None:
        return nothing_to_send

    flag_enabled = self.experimental_use_latest_role_message_only is True
    if not flag_enabled:
        # Legacy behaviour: forward the full history untouched.
        return GuardrailMessageFilterResult(messages, None, None)

    user_index = self._find_latest_message_index(messages, target_role="user")
    if user_index is None:
        return nothing_to_send

    return GuardrailMessageFilterResult(
        payload_messages=[messages[user_index]],
        original_messages=list(messages),
        target_indices=[user_index],
    )
259+
260+
def _find_latest_message_index(
    self, messages: List[AllMessageValues], target_role: str
) -> Optional[int]:
    """Return the index of the last message whose role is ``target_role``.

    Args:
        messages: Conversation to scan.
        target_role: Role value to look for (compared with ``==``).

    Returns:
        The index of the matching message closest to the end of
        ``messages``, or None when no message carries that role.
    """
    last_position = len(messages) - 1
    # Walk from the newest message backwards so the first hit is the latest.
    for steps_from_end, message in enumerate(reversed(messages)):
        if message.get("role", None) == target_role:
            return last_position - steps_from_end
    return None
267+
268+
def _merge_filtered_messages(
    self,
    original_messages: Optional[List[AllMessageValues]],
    updated_target_messages: List[AllMessageValues],
    target_indices: Optional[List[int]],
) -> List[AllMessageValues]:
    """Splice guardrail-updated messages back into the full conversation.

    Args:
        original_messages: The full conversation captured before filtering.
        updated_target_messages: The messages that went through the
            guardrail, in the same order as ``target_indices``.
        target_indices: Positions in ``original_messages`` that the updated
            messages should replace. Empty or None means no filtering took
            place.

    Returns:
        ``updated_target_messages`` itself when ``target_indices`` is empty
        or None. Otherwise a new list: a copy of ``original_messages`` (or
        of ``updated_target_messages`` when there is no original
        conversation) with each in-range target index replaced by its
        updated message. Out-of-range indices are ignored.
    """
    if not target_indices:
        # No filtering happened, so the updated messages are the conversation.
        return updated_target_messages

    # Copy so the caller's list is never mutated; fall back to the updated
    # messages when there is no original conversation to merge into.
    merged_messages = list(original_messages or updated_target_messages)

    for replacement_index, updated_message in zip(
        target_indices, updated_target_messages
    ):
        # The lower bound matters: a negative index would otherwise pass the
        # length check and silently overwrite a message counted from the end.
        if 0 <= replacement_index < len(merged_messages):
            merged_messages[replacement_index] = updated_message

    return merged_messages
290+
291+
# NOTE: Consider moving these helpers to CustomGuardrail when the filtering
292+
# logic becomes shared across providers.
293+
216294
#### CALL HOOKS - proxy only ####
217295
def _load_credentials(
218296
self,
@@ -635,15 +713,24 @@ async def async_pre_call_hook(
635713
)
636714
return data
637715

716+
filter_result = self._prepare_guardrail_messages_for_role(messages=new_messages)
717+
718+
filtered_messages = filter_result.payload_messages
719+
if not filtered_messages:
720+
verbose_proxy_logger.debug(
721+
"No user-role messages available for guardrail payload"
722+
)
723+
return data
724+
638725
#########################################################
639726
########## 1. Make the Bedrock API request ##########
640727
#########################################################
641-
bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = (
642-
None
643-
)
728+
bedrock_guardrail_response: Optional[
729+
Union[BedrockGuardrailResponse, str]
730+
] = None
644731
try:
645732
bedrock_guardrail_response = await self.make_bedrock_api_request(
646-
source="INPUT", messages=new_messages, request_data=data
733+
source="INPUT", messages=filtered_messages, request_data=data
647734
)
648735
except GuardrailInterventionNormalStringError as e:
649736
bedrock_guardrail_response = e.message
@@ -652,11 +739,14 @@ async def async_pre_call_hook(
652739
#########################################################
653740
########## 2. Update the messages with the guardrail response ##########
654741
#########################################################
655-
data["messages"] = (
656-
self._update_messages_with_updated_bedrock_guardrail_response(
657-
messages=new_messages,
658-
bedrock_guardrail_response=bedrock_guardrail_response,
659-
)
742+
updated_subset = self._update_messages_with_updated_bedrock_guardrail_response(
743+
messages=filtered_messages,
744+
bedrock_guardrail_response=bedrock_guardrail_response,
745+
)
746+
data["messages"] = self._merge_filtered_messages(
747+
original_messages=filter_result.original_messages or new_messages,
748+
updated_target_messages=updated_subset,
749+
target_indices=filter_result.target_indices,
660750
)
661751
if isinstance(bedrock_guardrail_response, str):
662752
data["mock_response"] = self.create_guardrail_blocked_response(
@@ -696,15 +786,23 @@ async def async_moderation_hook(
696786
)
697787
return
698788

789+
filter_result = self._prepare_guardrail_messages_for_role(messages=new_messages)
790+
filtered_messages = filter_result.payload_messages
791+
if not filtered_messages:
792+
verbose_proxy_logger.debug(
793+
"Bedrock AI: not running guardrail. No user-role messages"
794+
)
795+
return
796+
699797
#########################################################
700798
########## 1. Make the Bedrock API request ##########
701799
#########################################################
702-
bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = (
703-
None
704-
)
800+
bedrock_guardrail_response: Optional[
801+
Union[BedrockGuardrailResponse, str]
802+
] = None
705803
try:
706804
bedrock_guardrail_response = await self.make_bedrock_api_request(
707-
source="INPUT", messages=new_messages, request_data=data
805+
source="INPUT", messages=filtered_messages, request_data=data
708806
)
709807
except GuardrailInterventionNormalStringError as e:
710808
bedrock_guardrail_response = e.message
@@ -713,11 +811,14 @@ async def async_moderation_hook(
713811
#########################################################
714812
########## 2. Update the messages with the guardrail response ##########
715813
#########################################################
716-
data["messages"] = (
717-
self._update_messages_with_updated_bedrock_guardrail_response(
718-
messages=new_messages,
719-
bedrock_guardrail_response=bedrock_guardrail_response,
720-
)
814+
updated_subset = self._update_messages_with_updated_bedrock_guardrail_response(
815+
messages=filtered_messages,
816+
bedrock_guardrail_response=bedrock_guardrail_response,
817+
)
818+
data["messages"] = self._merge_filtered_messages(
819+
original_messages=filter_result.original_messages or new_messages,
820+
updated_target_messages=updated_subset,
821+
target_indices=filter_result.target_indices,
721822
)
722823
if isinstance(bedrock_guardrail_response, str):
723824
data["mock_response"] = self.create_guardrail_blocked_response(
@@ -887,9 +988,15 @@ async def async_post_call_streaming_iterator_hook(
887988
# Bedrock will raise an exception if this violates the guardrail policy
888989
###################################################################
889990
# Create tasks for parallel execution
991+
input_filter = self._prepare_guardrail_messages_for_role(
992+
messages=request_data.get("messages")
993+
)
994+
input_messages = input_filter.payload_messages or request_data.get(
995+
"messages"
996+
)
890997
input_task = self.make_bedrock_api_request(
891998
source="INPUT",
892-
messages=request_data.get("messages"),
999+
messages=input_messages,
8931000
request_data=request_data,
8941001
) # Only input messages
8951002
output_guardrail_response: Optional[
@@ -1161,9 +1268,15 @@ async def apply_guardrail(
11611268
if request_data is None:
11621269
request_data = {"messages": mock_messages}
11631270

1271+
request_messages = request_data.get("messages") or mock_messages
1272+
filter_result = self._prepare_guardrail_messages_for_role(
1273+
messages=request_messages
1274+
)
1275+
filtered_messages = filter_result.payload_messages or mock_messages
1276+
11641277
bedrock_response = await self.make_bedrock_api_request(
11651278
source="INPUT",
1166-
messages=mock_messages,
1279+
messages=filtered_messages,
11671280
request_data=request_data,
11681281
)
11691282

litellm/proxy/guardrails/guardrail_initializers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def initialize_bedrock(litellm_params: LitellmParams, guardrail: Guardrail):
3030
aws_web_identity_token=litellm_params.aws_web_identity_token,
3131
aws_sts_endpoint=litellm_params.aws_sts_endpoint,
3232
aws_bedrock_runtime_endpoint=litellm_params.aws_bedrock_runtime_endpoint,
33+
experimental_use_latest_role_message_only=litellm_params.experimental_use_latest_role_message_only,
3334
)
3435
litellm.logging_callback_manager.add_litellm_callback(_bedrock_callback)
3536
return _bedrock_callback

litellm/types/guardrails.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,11 @@ class BaseLitellmParams(BaseModel): # works for new and patch update guardrails
525525
default=None, description="Base URL for the guardrail service API"
526526
)
527527

528+
experimental_use_latest_role_message_only: Optional[bool] = Field(
529+
default=False,
530+
description="When True, guardrails only receive the latest message for the relevant role (e.g., newest user input pre-call, newest assistant output post-call)",
531+
)
532+
528533
# Lakera specific params
529534
category_thresholds: Optional[LakeraCategoryThresholds] = Field(
530535
default=None,

tests/enterprise/litellm_enterprise/proxy/guardrails/test_bedrock_apply_guardrail.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,95 @@ async def test_bedrock_apply_guardrail_endpoint_integration():
188188
assert isinstance(response, ApplyGuardrailResponse)
189189
assert response.response_text == "This is a test message with processed content"
190190
mock_api_request.assert_called_once()
191+
192+
193+
@pytest.mark.asyncio
async def test_bedrock_apply_guardrail_filters_request_messages_when_flag_enabled():
    """With the flag on, apply_guardrail forwards only the newest user message."""
    conversation = [
        {"role": "system", "content": "rules"},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "response"},
        {"role": "user", "content": "latest question"},
    ]
    bedrock_guard = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    # Stub out the Bedrock call so no network request is made.
    with patch.object(
        bedrock_guard, "make_bedrock_api_request", new_callable=AsyncMock
    ) as bedrock_call:
        bedrock_call.return_value = {"action": "ALLOWED"}

        returned_text = await bedrock_guard.apply_guardrail(
            text="latest question",
            request_data={"messages": conversation},
        )

    assert bedrock_call.called
    _, forwarded_kwargs = bedrock_call.call_args
    assert forwarded_kwargs["messages"] == [conversation[-1]]
    assert returned_text == "latest question"
223+
224+
225+
@pytest.mark.asyncio
async def test_bedrock_apply_guardrail_filters_request_messages_when_flag_enabled_blocked():
    """A BLOCKED response still raises, and only the newest user message is sent."""
    conversation = [
        {"role": "user", "content": "first"},
        {"role": "user", "content": "blocked"},
    ]
    bedrock_guard = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    # Stub out the Bedrock call and make it report a policy block.
    with patch.object(
        bedrock_guard, "make_bedrock_api_request", new_callable=AsyncMock
    ) as bedrock_call:
        bedrock_call.return_value = {"action": "BLOCKED", "reason": "policy"}

        with pytest.raises(Exception, match="policy") as raised:
            await bedrock_guard.apply_guardrail(
                text="blocked",
                request_data={"messages": conversation},
            )

    assert bedrock_call.called
    _, forwarded_kwargs = bedrock_call.call_args
    assert forwarded_kwargs["messages"] == [conversation[-1]]
    assert "Bedrock guardrail failed" in str(raised.value)
254+
255+
def test_bedrock_guardrail_filters_latest_user_message_when_enabled():
    """The filter picks the last user message and the merge writes it back in place."""
    conversation = [
        {"role": "system", "content": "rules"},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "response"},
        {"role": "user", "content": "latest question"},
    ]
    bedrock_guard = BedrockGuardrail(
        guardrail_name="test-bedrock-guard",
        guardrailIdentifier="test-guard-id",
        guardrailVersion="DRAFT",
        experimental_use_latest_role_message_only=True,
    )

    # Filtering step: only the newest user message becomes the payload.
    selection = bedrock_guard._prepare_guardrail_messages_for_role(
        messages=conversation
    )
    payload = selection.payload_messages

    assert payload is not None
    assert len(payload) == 1
    assert payload[0]["content"] == "latest question"
    assert selection.target_indices == [3]

    # Merge step: the masked replacement lands at the original position.
    replacement = {"role": "user", "content": "[MASKED]"}
    merged = bedrock_guard._merge_filtered_messages(
        original_messages=selection.original_messages,
        updated_target_messages=[replacement],
        target_indices=selection.target_indices,
    )
    assert merged[3]["content"] == "[MASKED]"

0 commit comments

Comments
 (0)