|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Tests for Guardrails AI integration - updated to match current implementation.""" |
| 17 | + |
| 18 | +import inspect |
| 19 | +from typing import Any, Dict |
| 20 | +from unittest.mock import Mock, patch |
| 21 | + |
| 22 | +import pytest |
| 23 | + |
| 24 | + |
| 25 | +class TestGuardrailsAIIntegration: |
| 26 | + """Test suite for Guardrails AI integration with current implementation.""" |
| 27 | + |
| 28 | + def test_module_imports_without_guardrails(self): |
| 29 | + """Test that modules can be imported even without guardrails package.""" |
| 30 | + from nemoguardrails.library.guardrails_ai.actions import ( |
| 31 | + _get_guard, |
| 32 | + guardrails_ai_validation_mapping, |
| 33 | + validate_guardrails_ai, |
| 34 | + ) |
| 35 | + from nemoguardrails.library.guardrails_ai.registry import VALIDATOR_REGISTRY |
| 36 | + |
| 37 | + assert callable(validate_guardrails_ai) |
| 38 | + assert callable(guardrails_ai_validation_mapping) |
| 39 | + assert isinstance(VALIDATOR_REGISTRY, dict) |
| 40 | + |
| 41 | + def test_validator_registry_structure(self): |
| 42 | + """Test that the validator registry has the expected structure.""" |
| 43 | + from nemoguardrails.library.guardrails_ai.registry import VALIDATOR_REGISTRY |
| 44 | + |
| 45 | + assert isinstance(VALIDATOR_REGISTRY, dict) |
| 46 | + assert len(VALIDATOR_REGISTRY) >= 6 |
| 47 | + |
| 48 | + expected_validators = [ |
| 49 | + "toxic_language", |
| 50 | + "detect_jailbreak", |
| 51 | + "guardrails_pii", |
| 52 | + "competitor_check", |
| 53 | + "restrict_to_topic", |
| 54 | + "provenance_llm", |
| 55 | + ] |
| 56 | + |
| 57 | + for validator in expected_validators: |
| 58 | + assert validator in VALIDATOR_REGISTRY |
| 59 | + validator_info = VALIDATOR_REGISTRY[validator] |
| 60 | + assert "module" in validator_info |
| 61 | + assert "class" in validator_info |
| 62 | + assert "hub_path" in validator_info |
| 63 | + assert "default_params" in validator_info |
| 64 | + assert isinstance(validator_info["default_params"], dict) |
| 65 | + |
| 66 | + def test_validation_mapping_function(self): |
| 67 | + """Test the validation mapping function with current interface.""" |
| 68 | + from nemoguardrails.library.guardrails_ai.actions import ( |
| 69 | + guardrails_ai_validation_mapping, |
| 70 | + ) |
| 71 | + |
| 72 | + mock_result = Mock() |
| 73 | + mock_result.validation_passed = True |
| 74 | + result1 = {"validation_result": mock_result} |
| 75 | + mapped1 = guardrails_ai_validation_mapping(result1) |
| 76 | + assert mapped1["valid"] is True |
| 77 | + assert mapped1["validation_result"] == mock_result |
| 78 | + |
| 79 | + mock_result2 = Mock() |
| 80 | + mock_result2.validation_passed = False |
| 81 | + result2 = {"validation_result": mock_result2} |
| 82 | + mapped2 = guardrails_ai_validation_mapping(result2) |
| 83 | + assert mapped2["valid"] is False |
| 84 | + |
| 85 | + result3 = {"validation_result": {"validation_passed": True}} |
| 86 | + mapped3 = guardrails_ai_validation_mapping(result3) |
| 87 | + assert mapped3["valid"] is True |
| 88 | + |
| 89 | + @patch("nemoguardrails.library.guardrails_ai.actions._get_guard") |
| 90 | + def test_validate_guardrails_ai_success(self, mock_get_guard): |
| 91 | + """Test successful validation with current interface.""" |
| 92 | + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai |
| 93 | + |
| 94 | + mock_guard = Mock() |
| 95 | + mock_validation_result = Mock() |
| 96 | + mock_validation_result.validation_passed = True |
| 97 | + mock_guard.validate.return_value = mock_validation_result |
| 98 | + mock_get_guard.return_value = mock_guard |
| 99 | + |
| 100 | + result = validate_guardrails_ai( |
| 101 | + validator_name="toxic_language", |
| 102 | + text="Hello, this is a safe message", |
| 103 | + threshold=0.5, |
| 104 | + ) |
| 105 | + |
| 106 | + assert "validation_result" in result |
| 107 | + assert result["validation_result"] == mock_validation_result |
| 108 | + mock_guard.validate.assert_called_once_with( |
| 109 | + "Hello, this is a safe message", metadata={} |
| 110 | + ) |
| 111 | + mock_get_guard.assert_called_once_with("toxic_language", threshold=0.5) |
| 112 | + |
| 113 | + @patch("nemoguardrails.library.guardrails_ai.actions._get_guard") |
| 114 | + def test_validate_guardrails_ai_with_metadata(self, mock_get_guard): |
| 115 | + """Test validation with metadata parameter.""" |
| 116 | + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai |
| 117 | + |
| 118 | + mock_guard = Mock() |
| 119 | + mock_validation_result = Mock() |
| 120 | + mock_validation_result.validation_passed = False |
| 121 | + mock_guard.validate.return_value = mock_validation_result |
| 122 | + mock_get_guard.return_value = mock_guard |
| 123 | + |
| 124 | + metadata = {"source": "user_input"} |
| 125 | + result = validate_guardrails_ai( |
| 126 | + validator_name="detect_jailbreak", |
| 127 | + text="Some text", |
| 128 | + metadata=metadata, |
| 129 | + threshold=0.8, |
| 130 | + ) |
| 131 | + |
| 132 | + assert "validation_result" in result |
| 133 | + assert result["validation_result"] == mock_validation_result |
| 134 | + mock_guard.validate.assert_called_once_with("Some text", metadata=metadata) |
| 135 | + mock_get_guard.assert_called_once_with("detect_jailbreak", threshold=0.8) |
| 136 | + |
| 137 | + @patch("nemoguardrails.library.guardrails_ai.actions._get_guard") |
| 138 | + def test_validate_guardrails_ai_error_handling(self, mock_get_guard): |
| 139 | + """Test error handling in validation.""" |
| 140 | + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai |
| 141 | + from nemoguardrails.library.guardrails_ai.errors import ( |
| 142 | + GuardrailsAIValidationError, |
| 143 | + ) |
| 144 | + |
| 145 | + mock_guard = Mock() |
| 146 | + mock_guard.validate.side_effect = Exception("Validation service error") |
| 147 | + mock_get_guard.return_value = mock_guard |
| 148 | + |
| 149 | + with pytest.raises(GuardrailsAIValidationError) as exc_info: |
| 150 | + validate_guardrails_ai(validator_name="toxic_language", text="Any text") |
| 151 | + |
| 152 | + assert "Validation failed" in str(exc_info.value) |
| 153 | + assert "Validation service error" in str(exc_info.value) |
| 154 | + |
| 155 | + @patch("nemoguardrails.library.guardrails_ai.actions._load_validator_class") |
| 156 | + @patch("nemoguardrails.library.guardrails_ai.actions.Guard") |
| 157 | + def test_get_guard_creates_and_caches(self, mock_guard_class, mock_load_validator): |
| 158 | + """Test that _get_guard creates and caches guards properly.""" |
| 159 | + from nemoguardrails.library.guardrails_ai.actions import _get_guard |
| 160 | + |
| 161 | + mock_validator_class = Mock() |
| 162 | + mock_validator_instance = Mock() |
| 163 | + mock_guard_instance = Mock() |
| 164 | + mock_guard = Mock() |
| 165 | + |
| 166 | + mock_load_validator.return_value = mock_validator_class |
| 167 | + mock_validator_class.return_value = mock_validator_instance |
| 168 | + mock_guard_class.return_value = mock_guard |
| 169 | + mock_guard.use.return_value = mock_guard_instance |
| 170 | + |
| 171 | + # clear cache |
| 172 | + import nemoguardrails.library.guardrails_ai.actions as actions |
| 173 | + |
| 174 | + actions._guard_cache.clear() |
| 175 | + |
| 176 | + # first call should create new guard |
| 177 | + result1 = _get_guard("toxic_language", threshold=0.5) |
| 178 | + |
| 179 | + assert result1 == mock_guard_instance |
| 180 | + mock_validator_class.assert_called_once_with(threshold=0.5, on_fail="noop") |
| 181 | + mock_guard.use.assert_called_once_with(mock_validator_instance) |
| 182 | + |
| 183 | + # reset mocks for second call |
| 184 | + mock_load_validator.reset_mock() |
| 185 | + mock_validator_class.reset_mock() |
| 186 | + mock_guard_class.reset_mock() |
| 187 | + |
| 188 | + # second call with same params should use cache |
| 189 | + result2 = _get_guard("toxic_language", threshold=0.5) |
| 190 | + |
| 191 | + assert result2 == mock_guard_instance |
| 192 | + # should not create new validator or guard |
| 193 | + mock_load_validator.assert_not_called() |
| 194 | + mock_validator_class.assert_not_called() |
| 195 | + mock_guard_class.assert_not_called() |
| 196 | + |
| 197 | + @patch("nemoguardrails.library.guardrails_ai.registry.get_validator_info") |
| 198 | + def test_load_validator_class_unknown_validator(self, mock_get_info): |
| 199 | + """Test error handling for unknown validators.""" |
| 200 | + from nemoguardrails.library.guardrails_ai.actions import _load_validator_class |
| 201 | + from nemoguardrails.library.guardrails_ai.errors import GuardrailsAIConfigError |
| 202 | + |
| 203 | + mock_get_info.side_effect = GuardrailsAIConfigError( |
| 204 | + "Unknown validator: unknown_validator" |
| 205 | + ) |
| 206 | + |
| 207 | + with pytest.raises(ImportError) as exc_info: |
| 208 | + _load_validator_class("unknown_validator") |
| 209 | + |
| 210 | + assert "Failed to load validator unknown_validator" in str(exc_info.value) |
| 211 | + |
| 212 | + def test_validate_guardrails_ai_signature(self): |
| 213 | + """Test that validate_guardrails_ai has the expected signature.""" |
| 214 | + from nemoguardrails.library.guardrails_ai.actions import validate_guardrails_ai |
| 215 | + |
| 216 | + sig = inspect.signature(validate_guardrails_ai) |
| 217 | + params = list(sig.parameters.keys()) |
| 218 | + |
| 219 | + assert "validator_name" in params |
| 220 | + assert "text" in params |
| 221 | + assert any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values()) |
| 222 | + |
| 223 | + def test_ValidationResult_type(self): |
| 224 | + """Test that ValidationResult type is properly defined.""" |
| 225 | + from nemoguardrails.library.guardrails_ai.actions import ValidationResult |
| 226 | + |
| 227 | + # ValidationResult should be a type alias for Dict[str, Any] |
| 228 | + assert ValidationResult == Dict[str, Any] |
| 229 | + |
| 230 | + @patch("nemoguardrails.library.guardrails_ai.actions._load_validator_class") |
| 231 | + @patch("nemoguardrails.library.guardrails_ai.actions.Guard") |
| 232 | + def test_guard_cache_key_generation(self, mock_guard_class, mock_load): |
| 233 | + """Test that guard cache keys are generated correctly for different parameter combinations.""" |
| 234 | + from nemoguardrails.library.guardrails_ai.actions import _get_guard |
| 235 | + |
| 236 | + mock_validator_class = Mock() |
| 237 | + mock_guard_instance = Mock() |
| 238 | + mock_guard = Mock() |
| 239 | + |
| 240 | + mock_load.return_value = mock_validator_class |
| 241 | + mock_guard_class.return_value = mock_guard |
| 242 | + mock_guard.use.return_value = mock_guard_instance |
| 243 | + |
| 244 | + import nemoguardrails.library.guardrails_ai.actions as actions |
| 245 | + |
| 246 | + actions._guard_cache.clear() |
| 247 | + |
| 248 | + # create guards with different parameters |
| 249 | + _get_guard("toxic_language", threshold=0.5) |
| 250 | + _get_guard("toxic_language", threshold=0.8) |
| 251 | + _get_guard("detect_jailbreak", threshold=0.5) |
| 252 | + |
| 253 | + assert len(actions._guard_cache) == 3 |
0 commit comments