60 changes: 45 additions & 15 deletions ddtrace/appsec/_ai_guard/_langchain.py
@@ -13,6 +13,7 @@
from ddtrace.contrib.internal.trace_utils import unwrap
from ddtrace.contrib.internal.trace_utils import wrap
import ddtrace.internal.logger as ddlogger
from ddtrace.internal.utils import get_argument_value


logger = ddlogger.get_logger(__name__)
@@ -164,28 +165,57 @@ def _handle_agent_action_result(client: AIGuardClient, result, kwargs):


def _langchain_chatmodel_generate_before(client: AIGuardClient, message_lists):
-    from langchain_core.messages import HumanMessage
-
    for messages in message_lists:
-        # only call evaluator when the last message is an actual user prompt
-        if len(messages) > 0 and isinstance(messages[-1], HumanMessage):
-            history = _convert_messages(messages)
-            prompt = history.pop(-1)
-            try:
-                if not client.evaluate_prompt(prompt["role"], prompt["content"], history=history):  # type: ignore[typeddict-item]
-                    return AIGuardAbortError()
-            except AIGuardAbortError as e:
-                return e
-            except Exception:
-                logger.debug("Failed to evaluate chat model prompt", exc_info=True)
+        result = _evaluate_langchain_messages(client, messages)
+        if result:
+            return result
return None


def _langchain_llm_generate_before(client: AIGuardClient, prompts):
for prompt in prompts:
result = _evaluate_langchain_prompt(client, prompt)
if result:
return result
return None


def _langchain_chatmodel_stream_before(client: AIGuardClient, instance, args, kwargs):
input_arg = get_argument_value(args, kwargs, 0, "input")
messages = instance._convert_input(input_arg).to_messages()
return _evaluate_langchain_messages(client, messages)


def _langchain_llm_stream_before(client: AIGuardClient, instance, args, kwargs):
input_arg = get_argument_value(args, kwargs, 0, "input")
prompt = instance._convert_input(input_arg).to_string()
return _evaluate_langchain_prompt(client, prompt)


def _evaluate_langchain_messages(client: AIGuardClient, messages):
from langchain_core.messages import HumanMessage

# only call evaluator when the last message is an actual user prompt
if len(messages) > 0 and isinstance(messages[-1], HumanMessage):
history = _convert_messages(messages)
prompt = history.pop(-1)
try:
-            if not client.evaluate_prompt("user", prompt):
+            role, content = (prompt["role"], prompt["content"])  # type: ignore[typeddict-item]
+            if not client.evaluate_prompt(role, content, history=history):
return AIGuardAbortError()
except AIGuardAbortError as e:
return e
except Exception:
-            logger.debug("Failed to evaluate llm prompt", exc_info=True)
+            logger.debug("Failed to evaluate chat model prompt", exc_info=True)
return None


def _evaluate_langchain_prompt(client: AIGuardClient, prompt):
try:
if not client.evaluate_prompt("user", prompt):
return AIGuardAbortError()
except AIGuardAbortError as e:
return e
except Exception:
logger.debug("Failed to evaluate llm prompt", exc_info=True)
return None
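For orientation, here is a rough sketch of what the new helpers end up evaluating for a single streamed chat prompt. It is illustrative only and not part of the PR: it assumes langchain_core is installed and that _convert_messages maps a HumanMessage to a {"role": "user", "content": ...} dict, which the code above implies but does not show.

# Illustrative sketch only; assumes _convert_messages maps HumanMessage to {"role": "user", ...}.
from langchain_core.messages import HumanMessage

messages = [HumanMessage(content="how can langsmith help with testing?")]
history = []  # nothing left once the trailing user prompt is popped off
role, content = "user", messages[-1].content

# _evaluate_langchain_messages(client, messages) then reduces to roughly:
#   if not client.evaluate_prompt(role, content, history=history):
#       return AIGuardAbortError()  # returned to the hook, raised later by the integration's _raising_dispatch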
4 changes: 4 additions & 0 deletions ddtrace/appsec/_ai_guard/_listener.py
@@ -1,7 +1,9 @@
from functools import partial

from ddtrace.appsec._ai_guard._langchain import _langchain_chatmodel_generate_before
from ddtrace.appsec._ai_guard._langchain import _langchain_chatmodel_stream_before
from ddtrace.appsec._ai_guard._langchain import _langchain_llm_generate_before
from ddtrace.appsec._ai_guard._langchain import _langchain_llm_stream_before
from ddtrace.appsec._ai_guard._langchain import _langchain_patch
from ddtrace.appsec._ai_guard._langchain import _langchain_unpatch
from ddtrace.appsec.ai_guard import AIGuardClient
@@ -20,6 +22,8 @@ def _langchain_listen(client: AIGuardClient):

core.on("langchain.chatmodel.generate.before", partial(_langchain_chatmodel_generate_before, client))
core.on("langchain.chatmodel.agenerate.before", partial(_langchain_chatmodel_generate_before, client))
core.on("langchain.chatmodel.stream.before", partial(_langchain_chatmodel_stream_before, client))

core.on("langchain.llm.generate.before", partial(_langchain_llm_generate_before, client))
core.on("langchain.llm.agenerate.before", partial(_langchain_llm_generate_before, client))
core.on("langchain.llm.stream.before", partial(_langchain_llm_stream_before, client))
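The stream handlers are wired up exactly like the existing generate handlers: functools.partial pre-binds the AIGuardClient, so the dispatched event only needs to carry the intercepted call's (instance, args, kwargs). A self-contained sketch of what that binding means (the names here are stand-ins, not the real handlers):

from functools import partial


def _stream_before(client, instance, args, kwargs):  # stand-in for _langchain_chatmodel_stream_before
    print(client, instance, args, kwargs)


client = object()  # stand-in for the AIGuardClient instance
handler = partial(_stream_before, client)

# The integration dispatches the event with (instance, args, kwargs); the listener then runs as
# handler(instance, args, kwargs), i.e. _stream_before(client, instance, args, kwargs).
handler("chat model instance", (), {"input": "hi"})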
11 changes: 11 additions & 0 deletions ddtrace/contrib/internal/langchain/patch.py
@@ -412,6 +412,8 @@ def traced_chat_stream(langchain, pin, func, instance, args, kwargs):
llm_provider = instance._llm_type
model = _extract_model_name(instance)

_raising_dispatch("langchain.chatmodel.stream.before", (instance, args, kwargs))

def _on_span_started(span: Span):
integration.record_instance(instance, span)

@@ -443,6 +445,15 @@ def traced_llm_stream(langchain, pin, func, instance, args, kwargs):
llm_provider = instance._llm_type
model = _extract_model_name(instance)

_raising_dispatch(
"langchain.llm.stream.before",
(
instance,
args,
kwargs,
),
)

def _on_span_start(span: Span):
integration.record_instance(instance, span)

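Both dispatches run before the provider stream is opened, so a listener can veto the call. The sketch below captures only the contract implied by the handlers above and the tests below (a returned exception gets raised); it is not the actual _raising_dispatch implementation:

# Contract sketch only; the real _raising_dispatch is defined elsewhere in patch.py.
def raising_dispatch_sketch(listeners, args):
    for listener in listeners:
        result = listener(*args)
        if isinstance(result, BaseException):
            # e.g. the AIGuardAbortError returned by _langchain_chatmodel_stream_before
            raise result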
@@ -0,0 +1,4 @@
---
features:
- |
ai_guard: add evaluation support to streaming LangChain APIs
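A minimal end-to-end sketch of the new behaviour, mirroring the tests added below. ChatOpenAI is simply the model those tests use, the import path for AIGuardAbortError is an assumption based on this PR's modules, and enabling AI Guard itself is not shown:

from langchain_openai import ChatOpenAI

from ddtrace.appsec.ai_guard import AIGuardAbortError  # import path assumed

model = ChatOpenAI()
try:
    # The prompt is evaluated before the first chunk is produced.
    for chunk in model.stream(input="how can langsmith help with testing?"):
        print(chunk.content, end="")
except AIGuardAbortError:
    # A DENY or ABORT decision aborts the stream instead of yielding chunks.
    print("blocked by AI Guard")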
108 changes: 108 additions & 0 deletions tests/appsec/ai_guard/langchain/test_langchain.py
@@ -251,3 +251,111 @@ def test_message_conversion():

assert result[5]["role"] == "assistant"
assert result[5]["content"] == "One plus one is two"


@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
def test_streamed_chat_sync_allow(mock_execute_request, langchain_openai, openai_url):
mock_execute_request.return_value = mock_evaluate_response("ALLOW")

model = langchain_openai.ChatOpenAI(base_url=openai_url)

for _ in model.stream(input="how can langsmith help with testing?"):
pass

mock_execute_request.assert_called_once()


@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
def test_streamed_chat_sync_block(mock_execute_request, langchain_openai, openai_url, decision):
mock_execute_request.return_value = mock_evaluate_response(decision)

model = langchain_openai.ChatOpenAI(base_url=openai_url)

with pytest.raises(AIGuardAbortError):
for _ in model.stream(input="how can langsmith help with testing?"):
pass

mock_execute_request.assert_called_once()


@pytest.mark.asyncio
@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
async def test_streamed_chat_async_allow(mock_execute_request, langchain_openai, openai_url):
mock_execute_request.return_value = mock_evaluate_response("ALLOW")

model = langchain_openai.ChatOpenAI(base_url=openai_url)

async for _ in model.astream(input="how can langsmith help with testing?"):
pass

mock_execute_request.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
async def test_streamed_chat_async_block(mock_execute_request, langchain_openai, openai_url, decision):
mock_execute_request.return_value = mock_evaluate_response(decision)

model = langchain_openai.ChatOpenAI(base_url=openai_url)

with pytest.raises(AIGuardAbortError):
async for _ in model.astream(input="how can langsmith help with testing?"):
pass

mock_execute_request.assert_called_once()


@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
def test_streamed_llm_sync_allow(mock_execute_request, langchain_openai, openai_url):
mock_execute_request.return_value = mock_evaluate_response("ALLOW")

llm = langchain_openai.OpenAI(base_url=openai_url)

for _ in llm.stream(input="How do I write technical documentation?"):
pass

mock_execute_request.assert_called_once()


@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
def test_streamed_llm_sync_block(mock_execute_request, langchain_openai, openai_url, decision):
mock_execute_request.return_value = mock_evaluate_response(decision)

llm = langchain_openai.OpenAI(base_url=openai_url)

with pytest.raises(AIGuardAbortError):
for _ in llm.stream(input="How do I write technical documentation?"):
pass

mock_execute_request.assert_called_once()


@pytest.mark.asyncio
@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
async def test_streamed_llm_async_allow(mock_execute_request, langchain_openai, openai_url):
mock_execute_request.return_value = mock_evaluate_response("ALLOW")

llm = langchain_openai.OpenAI(base_url=openai_url)

async for _ in llm.astream(input="How do I write technical documentation?"):
pass

mock_execute_request.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.parametrize("decision", ["DENY", "ABORT"], ids=["deny", "abort"])
@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
async def test_streamed_llm_async_block(mock_execute_request, langchain_openai, openai_url, decision):
mock_execute_request.return_value = mock_evaluate_response(decision)

llm = langchain_openai.OpenAI(base_url=openai_url)

with pytest.raises(AIGuardAbortError):
async for _ in llm.astream(input="How do I write technical documentation?"):
pass

mock_execute_request.assert_called_once()