diff --git a/ddtrace/appsec/_ai_guard/_langchain.py b/ddtrace/appsec/_ai_guard/_langchain.py
index 331128fcb9f..292ce0ed011 100644
--- a/ddtrace/appsec/_ai_guard/_langchain.py
+++ b/ddtrace/appsec/_ai_guard/_langchain.py
@@ -13,6 +13,7 @@
 from ddtrace.contrib.internal.trace_utils import unwrap
 from ddtrace.contrib.internal.trace_utils import wrap
 import ddtrace.internal.logger as ddlogger
+from ddtrace.internal.utils import get_argument_value


 logger = ddlogger.get_logger(__name__)
@@ -164,28 +165,57 @@ def _handle_agent_action_result(client: AIGuardClient, result, kwargs):
 def _langchain_chatmodel_generate_before(client: AIGuardClient, message_lists):
-    from langchain_core.messages import HumanMessage
-
     for messages in message_lists:
-        # only call evaluator when the last message is an actual user prompt
-        if len(messages) > 0 and isinstance(messages[-1], HumanMessage):
-            history = _convert_messages(messages)
-            prompt = history.pop(-1)
-            try:
-                if not client.evaluate_prompt(prompt["role"], prompt["content"], history=history):  # type: ignore[typeddict-item]
-                    return AIGuardAbortError()
-            except AIGuardAbortError as e:
-                return e
-            except Exception:
-                logger.debug("Failed to evaluate chat model prompt", exc_info=True)
+        result = _evaluate_langchain_messages(client, messages)
+        if result:
+            return result
     return None


 def _langchain_llm_generate_before(client: AIGuardClient, prompts):
     for prompt in prompts:
+        result = _evaluate_langchain_prompt(client, prompt)
+        if result:
+            return result
+    return None
+
+
+def _langchain_chatmodel_stream_before(client: AIGuardClient, instance, args, kwargs):
+    input_arg = get_argument_value(args, kwargs, 0, "input")
+    messages = instance._convert_input(input_arg).to_messages()
+    return _evaluate_langchain_messages(client, messages)
+
+
+def _langchain_llm_stream_before(client: AIGuardClient, instance, args, kwargs):
+    input_arg = get_argument_value(args, kwargs, 0, "input")
+    prompt = instance._convert_input(input_arg).to_string()
+    return _evaluate_langchain_prompt(client, prompt)
+
+
+def _evaluate_langchain_messages(client: AIGuardClient, messages):
+    from langchain_core.messages import HumanMessage
+
+    # only call evaluator when the last message is an actual user prompt
+    if len(messages) > 0 and isinstance(messages[-1], HumanMessage):
+        history = _convert_messages(messages)
+        prompt = history.pop(-1)
         try:
-            if not client.evaluate_prompt("user", prompt):
+            role, content = (prompt["role"], prompt["content"])  # type: ignore[typeddict-item]
+            if not client.evaluate_prompt(role, content, history=history):
                 return AIGuardAbortError()
         except AIGuardAbortError as e:
             return e
         except Exception:
-            logger.debug("Failed to evaluate llm prompt", exc_info=True)
+            logger.debug("Failed to evaluate chat model prompt", exc_info=True)
+    return None
+
+
+def _evaluate_langchain_prompt(client: AIGuardClient, prompt):
+    try:
+        if not client.evaluate_prompt("user", prompt):
+            return AIGuardAbortError()
+    except AIGuardAbortError as e:
+        return e
+    except Exception:
+        logger.debug("Failed to evaluate llm prompt", exc_info=True)
+    return None
diff --git a/ddtrace/appsec/_ai_guard/_listener.py b/ddtrace/appsec/_ai_guard/_listener.py
index 3052ce60a2a..b383516d147 100644
--- a/ddtrace/appsec/_ai_guard/_listener.py
+++ b/ddtrace/appsec/_ai_guard/_listener.py
@@ -1,7 +1,9 @@
 from functools import partial

 from ddtrace.appsec._ai_guard._langchain import _langchain_chatmodel_generate_before
+from ddtrace.appsec._ai_guard._langchain import _langchain_chatmodel_stream_before
 from ddtrace.appsec._ai_guard._langchain import _langchain_llm_generate_before
+from ddtrace.appsec._ai_guard._langchain import _langchain_llm_stream_before
 from ddtrace.appsec._ai_guard._langchain import _langchain_patch
 from ddtrace.appsec._ai_guard._langchain import _langchain_unpatch
 from ddtrace.appsec.ai_guard import AIGuardClient
@@ -20,6 +22,8 @@ def _langchain_listen(client: AIGuardClient):
     core.on("langchain.chatmodel.generate.before", partial(_langchain_chatmodel_generate_before, client))
     core.on("langchain.chatmodel.agenerate.before", partial(_langchain_chatmodel_generate_before, client))
+    core.on("langchain.chatmodel.stream.before", partial(_langchain_chatmodel_stream_before, client))
     core.on("langchain.llm.generate.before", partial(_langchain_llm_generate_before, client))
     core.on("langchain.llm.agenerate.before", partial(_langchain_llm_generate_before, client))
+    core.on("langchain.llm.stream.before", partial(_langchain_llm_stream_before, client))
diff --git a/ddtrace/contrib/internal/langchain/patch.py b/ddtrace/contrib/internal/langchain/patch.py
index 73c869688c9..facc9a5659b 100644
--- a/ddtrace/contrib/internal/langchain/patch.py
+++ b/ddtrace/contrib/internal/langchain/patch.py
@@ -412,6 +412,8 @@ def traced_chat_stream(langchain, pin, func, instance, args, kwargs):
     llm_provider = instance._llm_type
     model = _extract_model_name(instance)

+    _raising_dispatch("langchain.chatmodel.stream.before", (instance, args, kwargs))
+
     def _on_span_started(span: Span):
         integration.record_instance(instance, span)

@@ -443,6 +445,15 @@ def traced_llm_stream(langchain, pin, func, instance, args, kwargs):
     llm_provider = instance._llm_type
     model = _extract_model_name(instance)

+    _raising_dispatch(
+        "langchain.llm.stream.before",
+        (
+            instance,
+            args,
+            kwargs,
+        ),
+    )
+
     def _on_span_start(span: Span):
         integration.record_instance(instance, span)
diff --git a/releasenotes/notes/Add-streaming-support-to-AI-Guard-evaluations-in-LangChain-c1286b0c00cfa354.yaml b/releasenotes/notes/Add-streaming-support-to-AI-Guard-evaluations-in-LangChain-c1286b0c00cfa354.yaml
new file mode 100644
index 00000000000..a28c188d340
--- /dev/null
+++ b/releasenotes/notes/Add-streaming-support-to-AI-Guard-evaluations-in-LangChain-c1286b0c00cfa354.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    ai_guard: add evaluation support to streaming LangChain APIs
diff --git a/tests/appsec/ai_guard/langchain/test_langchain.py b/tests/appsec/ai_guard/langchain/test_langchain.py
index b9d2d689554..015cd2699d9 100644
--- a/tests/appsec/ai_guard/langchain/test_langchain.py
+++ b/tests/appsec/ai_guard/langchain/test_langchain.py
@@ -251,3 +251,111 @@ def test_message_conversion():
     assert result[5]["role"] == "assistant"
     assert result[5]["content"] == "One plus one is two"
+
+
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+def test_streamed_chat_sync_allow(mock_execute_request, langchain_openai, openai_url):
+    mock_execute_request.return_value = mock_evaluate_response("ALLOW")
+
+    model = langchain_openai.ChatOpenAI(base_url=openai_url)
+
+    for _ in model.stream(input="how can langsmith help with testing?"):
+        pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+def test_streamed_chat_sync_block(mock_execute_request, langchain_openai, openai_url, decision):
+    mock_execute_request.return_value = mock_evaluate_response(decision)
+
+    model = langchain_openai.ChatOpenAI(base_url=openai_url)
+
+    with pytest.raises(AIGuardAbortError):
+        for _ in model.stream(input="how can langsmith help with testing?"):
+            pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+async def test_streamed_chat_async_allow(mock_execute_request, langchain_openai, openai_url):
+    mock_execute_request.return_value = mock_evaluate_response("ALLOW")
+
+    model = langchain_openai.ChatOpenAI(base_url=openai_url)
+
+    async for _ in model.astream(input="how can langsmith help with testing?"):
+        pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+async def test_streamed_chat_async_block(mock_execute_request, langchain_openai, openai_url, decision):
+    mock_execute_request.return_value = mock_evaluate_response(decision)
+
+    model = langchain_openai.ChatOpenAI(base_url=openai_url)
+
+    with pytest.raises(AIGuardAbortError):
+        async for _ in model.astream(input="how can langsmith help with testing?"):
+            pass
+
+    mock_execute_request.assert_called_once()
+
+
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+def test_streamed_llm_sync_allow(mock_execute_request, langchain_openai, openai_url):
+    mock_execute_request.return_value = mock_evaluate_response("ALLOW")
+
+    llm = langchain_openai.OpenAI(base_url=openai_url)
+
+    for _ in llm.stream(input="How do I write technical documentation?"):
+        pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+def test_streamed_llm_sync_block(mock_execute_request, langchain_openai, openai_url, decision):
+    mock_execute_request.return_value = mock_evaluate_response(decision)
+
+    llm = langchain_openai.OpenAI(base_url=openai_url)
+
+    with pytest.raises(AIGuardAbortError):
+        for _ in llm.stream(input="How do I write technical documentation?"):
+            pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+async def test_streamed_llm_async_allow(mock_execute_request, langchain_openai, openai_url):
+    mock_execute_request.return_value = mock_evaluate_response("ALLOW")
+
+    llm = langchain_openai.OpenAI(base_url=openai_url)
+
+    async for _ in llm.astream(input="How do I write technical documentation?"):
+        pass
+
+    mock_execute_request.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
+@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
+async def test_streamed_llm_async_block(mock_execute_request, langchain_openai, openai_url, decision):
+    mock_execute_request.return_value = mock_evaluate_response(decision)
+
+    llm = langchain_openai.OpenAI(base_url=openai_url)
+
+    with pytest.raises(AIGuardAbortError):
+        async for _ in llm.astream(input="How do I write technical documentation?"):
+            pass
+
+    mock_execute_request.assert_called_once()