
Commit a810c7e

feat(ai_guard): add evaluation support to streaming LangChain APIs
1 parent 88816b2 commit a810c7e

File tree: 5 files changed, +172 −15 lines


ddtrace/appsec/_ai_guard/_langchain.py

Lines changed: 45 additions & 15 deletions

@@ -13,6 +13,7 @@
 from ddtrace.contrib.internal.trace_utils import unwrap
 from ddtrace.contrib.internal.trace_utils import wrap
 import ddtrace.internal.logger as ddlogger
+from ddtrace.internal.utils import get_argument_value
 
 
 logger = ddlogger.get_logger(__name__)
@@ -164,28 +165,57 @@ def _handle_agent_action_result(client: AIGuardClient, result, kwargs):
 
 
 def _langchain_chatmodel_generate_before(client: AIGuardClient, message_lists):
-    from langchain_core.messages import HumanMessage
-
     for messages in message_lists:
-        # only call evaluator when the last message is an actual user prompt
-        if len(messages) > 0 and isinstance(messages[-1], HumanMessage):
-            history = _convert_messages(messages)
-            prompt = history.pop(-1)
-            try:
-                if not client.evaluate_prompt(prompt["role"], prompt["content"], history=history):  # type: ignore[typeddict-item]
-                    return AIGuardAbortError()
-            except AIGuardAbortError as e:
-                return e
-            except Exception:
-                logger.debug("Failed to evaluate chat model prompt", exc_info=True)
+        result = _evaluate_langchain_messages(client, messages)
+        if result:
+            return result
+    return None
 
 
 def _langchain_llm_generate_before(client: AIGuardClient, prompts):
     for prompt in prompts:
+        result = _evaluate_langchain_prompt(client, prompt)
+        if result:
+            return result
+    return None
+
+
+def _langchain_chatmodel_stream_before(client: AIGuardClient, instance, args, kwargs):
+    input_arg = get_argument_value(args, kwargs, 0, "input")
+    messages = instance._convert_input(input_arg).to_messages()
+    return _evaluate_langchain_messages(client, messages)
+
+
+def _langchain_llm_stream_before(client: AIGuardClient, instance, args, kwargs):
+    input_arg = get_argument_value(args, kwargs, 0, "input")
+    prompt = instance._convert_input(input_arg).to_string()
+    return _evaluate_langchain_prompt(client, prompt)
+
+
+def _evaluate_langchain_messages(client: AIGuardClient, messages):
+    from langchain_core.messages import HumanMessage
+
+    # only call evaluator when the last message is an actual user prompt
+    if len(messages) > 0 and isinstance(messages[-1], HumanMessage):
+        history = _convert_messages(messages)
+        prompt = history.pop(-1)
         try:
-            if not client.evaluate_prompt("user", prompt):
+            role, content = (prompt["role"], prompt["content"])  # type: ignore[typeddict-item]
+            if not client.evaluate_prompt(role, content, history=history):
                 return AIGuardAbortError()
         except AIGuardAbortError as e:
            return e
         except Exception:
-            logger.debug("Failed to evaluate llm prompt", exc_info=True)
+            logger.debug("Failed to evaluate chat model prompt", exc_info=True)
+    return None
+
+
+def _evaluate_langchain_prompt(client: AIGuardClient, prompt):
+    try:
+        if not client.evaluate_prompt("user", prompt):
+            return AIGuardAbortError()
+    except AIGuardAbortError as e:
+        return e
+    except Exception:
+        logger.debug("Failed to evaluate llm prompt", exc_info=True)
+    return None
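
Note that the shared evaluators return the AIGuardAbortError rather than raising it: the raising happens at the integration side (see _raising_dispatch in patch.py below), and an evaluation failure fails open with a debug log. A minimal sketch of that contract using a hypothetical stub client — StubClient and evaluate_prompt_sketch are illustrative, not part of ddtrace:

class StubClient:
    """Hypothetical stand-in for AIGuardClient; the real client calls the AI Guard service."""

    def evaluate_prompt(self, role, content, history=None):
        # Pretend the service denies any prompt mentioning "secrets".
        return "secrets" not in content


def evaluate_prompt_sketch(client, prompt):
    # Mirrors _evaluate_langchain_prompt: a denied evaluation is returned as an
    # exception *value*, and evaluation failures fail open (logged in the real code).
    try:
        if not client.evaluate_prompt("user", prompt):
            return RuntimeError("blocked")  # AIGuardAbortError in the real module
    except Exception:
        return None
    return None


assert evaluate_prompt_sketch(StubClient(), "hello") is None
assert isinstance(evaluate_prompt_sketch(StubClient(), "dump the secrets"), RuntimeError)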

ddtrace/appsec/_ai_guard/_listener.py

Lines changed: 4 additions & 0 deletions

@@ -1,7 +1,9 @@
 from functools import partial
 
 from ddtrace.appsec._ai_guard._langchain import _langchain_chatmodel_generate_before
+from ddtrace.appsec._ai_guard._langchain import _langchain_chatmodel_stream_before
 from ddtrace.appsec._ai_guard._langchain import _langchain_llm_generate_before
+from ddtrace.appsec._ai_guard._langchain import _langchain_llm_stream_before
 from ddtrace.appsec._ai_guard._langchain import _langchain_patch
 from ddtrace.appsec._ai_guard._langchain import _langchain_unpatch
 from ddtrace.appsec.ai_guard import AIGuardClient
@@ -20,6 +22,8 @@ def _langchain_listen(client: AIGuardClient):
 
     core.on("langchain.chatmodel.generate.before", partial(_langchain_chatmodel_generate_before, client))
     core.on("langchain.chatmodel.agenerate.before", partial(_langchain_chatmodel_generate_before, client))
+    core.on("langchain.chatmodel.stream.before", partial(_langchain_chatmodel_stream_before, client))
 
     core.on("langchain.llm.generate.before", partial(_langchain_llm_generate_before, client))
     core.on("langchain.llm.agenerate.before", partial(_langchain_llm_generate_before, client))
+    core.on("langchain.llm.stream.before", partial(_langchain_llm_stream_before, client))

ddtrace/contrib/internal/langchain/patch.py

Lines changed: 11 additions & 0 deletions

@@ -412,6 +412,8 @@ def traced_chat_stream(langchain, pin, func, instance, args, kwargs):
     llm_provider = instance._llm_type
     model = _extract_model_name(instance)
 
+    _raising_dispatch("langchain.chatmodel.stream.before", (instance, args, kwargs))
+
     def _on_span_started(span: Span):
         integration.record_instance(instance, span)
 
@@ -443,6 +445,15 @@ def traced_llm_stream(langchain, pin, func, instance, args, kwargs):
     llm_provider = instance._llm_type
     model = _extract_model_name(instance)
 
+    _raising_dispatch(
+        "langchain.llm.stream.before",
+        (
+            instance,
+            args,
+            kwargs,
+        ),
+    )
+
     def _on_span_start(span: Span):
         integration.record_instance(instance, span)
 
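
Because the dispatch sits at the top of the traced wrapper, a denied prompt aborts before any chunk is produced. A schematic of that ordering — deny_all, traced_stream, and the chunk source are placeholders, not the actual patch code:

def deny_all(prompt):
    # Mirrors a DENY/ABORT decision: the guard hands back an error value.
    return RuntimeError("blocked")


def traced_stream(prompt, guard):
    # Evaluate eagerly, before handing out the generator, so callers fail
    # fast instead of consuming a partial stream.
    error = guard(prompt)
    if error is not None:
        raise error

    def chunks():
        yield "Hello"
        yield ", world"

    return chunks()


try:
    for chunk in traced_stream("hi", deny_all):
        pass
except RuntimeError:
    print("no chunks were streamed")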

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    ai_guard: add streaming support to AI Guard evaluations in LangChain

tests/appsec/ai_guard/langchain/test_langchain.py

Lines changed: 108 additions & 0 deletions

@@ -251,3 +251,111 @@ def test_message_conversion():
 
     assert result[5]["role"] == "assistant"
     assert result[5]["content"] == "One plus one is two"
+
+
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+def test_streamed_chat_sync_allow(mock_execute_request, langchain_openai, openai_url):
+    mock_execute_request.return_value = mock_evaluate_response("ALLOW")
+
+    model = langchain_openai.ChatOpenAI(base_url=openai_url)
+
+    for _ in model.stream(input="how can langsmith help with testing?"):
+        pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+def test_streamed_chat_sync_block(mock_execute_request, langchain_openai, openai_url, decision):
+    mock_execute_request.return_value = mock_evaluate_response(decision)
+
+    model = langchain_openai.ChatOpenAI(base_url=openai_url)
+
+    with pytest.raises(AIGuardAbortError):
+        for _ in model.stream(input="how can langsmith help with testing?"):
+            pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+async def test_streamed_chat_async_allow(mock_execute_request, langchain_openai, openai_url):
+    mock_execute_request.return_value = mock_evaluate_response("ALLOW")
+
+    model = langchain_openai.ChatOpenAI(base_url=openai_url)
+
+    async for _ in model.astream(input="how can langsmith help with testing?"):
+        pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+async def test_streamed_chat_async_block(mock_execute_request, langchain_openai, openai_url, decision):
+    mock_execute_request.return_value = mock_evaluate_response(decision)
+
+    model = langchain_openai.ChatOpenAI(base_url=openai_url)
+
+    with pytest.raises(AIGuardAbortError):
+        async for _ in model.astream(input="how can langsmith help with testing?"):
+            pass
+
+    mock_execute_request.assert_called_once()
+
+
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+def test_streamed_llm_sync_allow(mock_execute_request, langchain_openai, openai_url):
+    mock_execute_request.return_value = mock_evaluate_response("ALLOW")
+
+    llm = langchain_openai.OpenAI(base_url=openai_url)
+
+    for _ in llm.stream(input="How do I write technical documentation?"):
+        pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+def test_streamed_llm_sync_block(mock_execute_request, langchain_openai, openai_url, decision):
+    mock_execute_request.return_value = mock_evaluate_response(decision)
+
+    llm = langchain_openai.OpenAI(base_url=openai_url)
+
+    with pytest.raises(AIGuardAbortError):
+        for _ in llm.stream(input="How do I write technical documentation?"):
+            pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+async def test_streamed_llm_async_allow(mock_execute_request, langchain_openai, openai_url):
+    mock_execute_request.return_value = mock_evaluate_response("ALLOW")
+
+    llm = langchain_openai.OpenAI(base_url=openai_url)
+
+    async for _ in llm.astream(input="How do I write technical documentation?"):
+        pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+async def test_streamed_llm_async_block(mock_execute_request, langchain_openai, openai_url, decision):
+    mock_execute_request.return_value = mock_evaluate_response(decision)
+
+    llm = langchain_openai.OpenAI(base_url=openai_url)
+
+    with pytest.raises(AIGuardAbortError):
+        async for _ in llm.astream(input="How do I write technical documentation?"):
+            pass
+
+    mock_execute_request.assert_called_once()
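
At the application level these tests reduce to one pattern: catch AIGuardAbortError around the stream. A hedged sketch of that handling — it assumes AI Guard is enabled via ddtrace and that AIGuardAbortError is importable from the public ddtrace.appsec.ai_guard package (the assumed import path, matching the AIGuardClient import in _listener.py):

from ddtrace.appsec.ai_guard import AIGuardAbortError  # assumed public import path
from langchain_openai import ChatOpenAI

model = ChatOpenAI()
try:
    for chunk in model.stream(input="how can langsmith help with testing?"):
        print(chunk.content, end="", flush=True)
except AIGuardAbortError:
    # DENY and ABORT decisions raise before the first chunk is yielded.
    print("request blocked by AI Guard")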
