Skip to content

Commit a12dc8a

Browse files
committed
test(guardrails-ai): add integration and e2e test suite
1 parent 09c83ce commit a12dc8a

File tree

5 files changed

+1249
-2
lines changed

5 files changed

+1249
-2
lines changed

nemoguardrails/library/guardrails_ai/errors.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,16 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from guardrails.errors import ValidationError
16+
try:
17+
from guardrails.errors import ValidationError
1718

18-
GuardrailsAIValidationError = ValidationError
19+
GuardrailsAIValidationError = ValidationError
20+
except ImportError:
21+
# create a fallback error class when guardrails is not installed
22+
class GuardrailsAIValidationError(Exception):
23+
"""Fallback validation error when guardrails package is not available."""
24+
25+
pass
1926

2027

2128
class GuardrailsAIError(Exception):

tests/test_guardrails_ai_actions.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for Guardrails AI integration - updated to match current implementation."""
17+
18+
import inspect
19+
from typing import Any, Dict
20+
from unittest.mock import Mock, patch
21+
22+
import pytest
23+
24+
25+
class TestGuardrailsAIIntegration:
26+
"""Test suite for Guardrails AI integration with current implementation."""
27+
28+
def test_module_imports_without_guardrails(self):
29+
"""Test that modules can be imported even without guardrails package."""
30+
from nemoguardrails.library.guardrails_ai.actions import (
31+
_get_guard,
32+
guardrails_ai_validation_mapping,
33+
validate_guardrails_ai,
34+
)
35+
from nemoguardrails.library.guardrails_ai.registry import VALIDATOR_REGISTRY
36+
37+
assert callable(validate_guardrails_ai)
38+
assert callable(guardrails_ai_validation_mapping)
39+
assert isinstance(VALIDATOR_REGISTRY, dict)
40+
41+
def test_validator_registry_structure(self):
42+
"""Test that the validator registry has the expected structure."""
43+
from nemoguardrails.library.guardrails_ai.registry import VALIDATOR_REGISTRY
44+
45+
assert isinstance(VALIDATOR_REGISTRY, dict)
46+
assert len(VALIDATOR_REGISTRY) >= 6
47+
48+
expected_validators = [
49+
"toxic_language",
50+
"detect_jailbreak",
51+
"guardrails_pii",
52+
"competitor_check",
53+
"restrict_to_topic",
54+
"provenance_llm",
55+
]
56+
57+
for validator in expected_validators:
58+
assert validator in VALIDATOR_REGISTRY
59+
validator_info = VALIDATOR_REGISTRY[validator]
60+
assert "module" in validator_info
61+
assert "class" in validator_info
62+
assert "hub_path" in validator_info
63+
assert "default_params" in validator_info
64+
assert isinstance(validator_info["default_params"], dict)
65+
66+
def test_validation_mapping_function(self):
67+
"""Test the validation mapping function with current interface."""
68+
from nemoguardrails.library.guardrails_ai.actions import (
69+
guardrails_ai_validation_mapping,
70+
)
71+
72+
mock_result = Mock()
73+
mock_result.validation_passed = True
74+
result1 = {"validation_result": mock_result}
75+
mapped1 = guardrails_ai_validation_mapping(result1)
76+
assert mapped1["valid"] is True
77+
assert mapped1["validation_result"] == mock_result
78+
79+
mock_result2 = Mock()
80+
mock_result2.validation_passed = False
81+
result2 = {"validation_result": mock_result2}
82+
mapped2 = guardrails_ai_validation_mapping(result2)
83+
assert mapped2["valid"] is False
84+
85+
result3 = {"validation_result": {"validation_passed": True}}
86+
mapped3 = guardrails_ai_validation_mapping(result3)
87+
assert mapped3["valid"] is True
88+
89+
@patch("nemoguardrails.library.guardrails_ai.actions._get_guard")
90+
def test_validate_guardrails_ai_success(self, mock_get_guard):
91+
"""Test successful validation with current interface."""
92+
from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai
93+
94+
mock_guard = Mock()
95+
mock_validation_result = Mock()
96+
mock_validation_result.validation_passed = True
97+
mock_guard.validate.return_value = mock_validation_result
98+
mock_get_guard.return_value = mock_guard
99+
100+
result = validate_guardrails_ai(
101+
validator_name="toxic_language",
102+
text="Hello, this is a safe message",
103+
threshold=0.5,
104+
)
105+
106+
assert "validation_result" in result
107+
assert result["validation_result"] == mock_validation_result
108+
mock_guard.validate.assert_called_once_with(
109+
"Hello, this is a safe message", metadata={}
110+
)
111+
mock_get_guard.assert_called_once_with("toxic_language", threshold=0.5)
112+
113+
@patch("nemoguardrails.library.guardrails_ai.actions._get_guard")
114+
def test_validate_guardrails_ai_with_metadata(self, mock_get_guard):
115+
"""Test validation with metadata parameter."""
116+
from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai
117+
118+
mock_guard = Mock()
119+
mock_validation_result = Mock()
120+
mock_validation_result.validation_passed = False
121+
mock_guard.validate.return_value = mock_validation_result
122+
mock_get_guard.return_value = mock_guard
123+
124+
metadata = {"source": "user_input"}
125+
result = validate_guardrails_ai(
126+
validator_name="detect_jailbreak",
127+
text="Some text",
128+
metadata=metadata,
129+
threshold=0.8,
130+
)
131+
132+
assert "validation_result" in result
133+
assert result["validation_result"] == mock_validation_result
134+
mock_guard.validate.assert_called_once_with("Some text", metadata=metadata)
135+
mock_get_guard.assert_called_once_with("detect_jailbreak", threshold=0.8)
136+
137+
@patch("nemoguardrails.library.guardrails_ai.actions._get_guard")
138+
def test_validate_guardrails_ai_error_handling(self, mock_get_guard):
139+
"""Test error handling in validation."""
140+
from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai
141+
from nemoguardrails.library.guardrails_ai.errors import (
142+
GuardrailsAIValidationError,
143+
)
144+
145+
mock_guard = Mock()
146+
mock_guard.validate.side_effect = Exception("Validation service error")
147+
mock_get_guard.return_value = mock_guard
148+
149+
with pytest.raises(GuardrailsAIValidationError) as exc_info:
150+
validate_guardrails_ai(validator_name="toxic_language", text="Any text")
151+
152+
assert "Validation failed" in str(exc_info.value)
153+
assert "Validation service error" in str(exc_info.value)
154+
155+
@patch("nemoguardrails.library.guardrails_ai.actions._load_validator_class")
156+
@patch("nemoguardrails.library.guardrails_ai.actions.Guard")
157+
def test_get_guard_creates_and_caches(self, mock_guard_class, mock_load_validator):
158+
"""Test that _get_guard creates and caches guards properly."""
159+
from nemoguardrails.library.guardrails_ai.actions import _get_guard
160+
161+
mock_validator_class = Mock()
162+
mock_validator_instance = Mock()
163+
mock_guard_instance = Mock()
164+
mock_guard = Mock()
165+
166+
mock_load_validator.return_value = mock_validator_class
167+
mock_validator_class.return_value = mock_validator_instance
168+
mock_guard_class.return_value = mock_guard
169+
mock_guard.use.return_value = mock_guard_instance
170+
171+
# clear cache
172+
import nemoguardrails.library.guardrails_ai.actions as actions
173+
174+
actions._guard_cache.clear()
175+
176+
# first call should create new guard
177+
result1 = _get_guard("toxic_language", threshold=0.5)
178+
179+
assert result1 == mock_guard_instance
180+
mock_validator_class.assert_called_once_with(threshold=0.5, on_fail="noop")
181+
mock_guard.use.assert_called_once_with(mock_validator_instance)
182+
183+
# reset mocks for second call
184+
mock_load_validator.reset_mock()
185+
mock_validator_class.reset_mock()
186+
mock_guard_class.reset_mock()
187+
188+
# second call with same params should use cache
189+
result2 = _get_guard("toxic_language", threshold=0.5)
190+
191+
assert result2 == mock_guard_instance
192+
# should not create new validator or guard
193+
mock_load_validator.assert_not_called()
194+
mock_validator_class.assert_not_called()
195+
mock_guard_class.assert_not_called()
196+
197+
@patch("nemoguardrails.library.guardrails_ai.registry.get_validator_info")
198+
def test_load_validator_class_unknown_validator(self, mock_get_info):
199+
"""Test error handling for unknown validators."""
200+
from nemoguardrails.library.guardrails_ai.actions import _load_validator_class
201+
from nemoguardrails.library.guardrails_ai.errors import GuardrailsAIConfigError
202+
203+
mock_get_info.side_effect = GuardrailsAIConfigError(
204+
"Unknown validator: unknown_validator"
205+
)
206+
207+
with pytest.raises(ImportError) as exc_info:
208+
_load_validator_class("unknown_validator")
209+
210+
assert "Failed to load validator unknown_validator" in str(exc_info.value)
211+
212+
def test_validate_guardrails_ai_signature(self):
213+
"""Test that validate_guardrails_ai has the expected signature."""
214+
from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai
215+
216+
sig = inspect.signature(validate_guardrails_ai)
217+
params = list(sig.parameters.keys())
218+
219+
assert "validator_name" in params
220+
assert "text" in params
221+
assert any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values())
222+
223+
def test_ValidationResult_type(self):
224+
"""Test that ValidationResult type is properly defined."""
225+
from nemoguardrails.library.guardrails_ai.actions import ValidationResult
226+
227+
# ValidationResult should be a type alias for Dict[str, Any]
228+
assert ValidationResult == Dict[str, Any]
229+
230+
@patch("nemoguardrails.library.guardrails_ai.actions._load_validator_class")
231+
@patch("nemoguardrails.library.guardrails_ai.actions.Guard")
232+
def test_guard_cache_key_generation(self, mock_guard_class, mock_load):
233+
"""Test that guard cache keys are generated correctly for different parameter combinations."""
234+
from nemoguardrails.library.guardrails_ai.actions import _get_guard
235+
236+
mock_validator_class = Mock()
237+
mock_guard_instance = Mock()
238+
mock_guard = Mock()
239+
240+
mock_load.return_value = mock_validator_class
241+
mock_guard_class.return_value = mock_guard
242+
mock_guard.use.return_value = mock_guard_instance
243+
244+
import nemoguardrails.library.guardrails_ai.actions as actions
245+
246+
actions._guard_cache.clear()
247+
248+
# create guards with different parameters
249+
_get_guard("toxic_language", threshold=0.5)
250+
_get_guard("toxic_language", threshold=0.8)
251+
_get_guard("detect_jailbreak", threshold=0.5)
252+
253+
assert len(actions._guard_cache) == 3

0 commit comments

Comments
 (0)