diff --git a/docs/user-guides/community/xgb.md b/docs/user-guides/community/xgb.md new file mode 100644 index 000000000..c14658127 --- /dev/null +++ b/docs/user-guides/community/xgb.md @@ -0,0 +1,43 @@ +# XGB Detectors Integration + +XGB Detectors utilizes [XGBoost machine learning models](https://xgboost.readthedocs.io/en/stable/tutorials/model.html) to detect harmful content in data. Currently, only +the spam text detector, trained by the [Red Hat TrustyAI team](https://github.com/trustyai-explainability), is available for guardrailing use. + +## Setup + +Update your `config.yaml` file to include XGB detectors: + +**Spam detection config** +``` +rails: + config: + xgb: + input: + detectors: + - SPAM + output: + detectors: + - SPAM + input: + flows: + - xgb detect on input + output: + flows: + - xgb detect on output +``` +The detection flow will not let the input and output text pass if spam is detected. + +## Usage + +Once configured, the XGB Guardrails integration will automatically: + +1. Detect spam in inputs to the LLM +3. Detect spam in outputs from the LLM + +## Error Handling + +If the inference request to the XGB spam model fails, the system will assume spam is present as a precautionary measure. + +## Notes + +For more information on TrustyAI and its projects, please visit the TrustyAI [documentation](https://trustyai.org/docs/main/main). diff --git a/examples/configs/xgb/config.yml b/examples/configs/xgb/config.yml new file mode 100644 index 000000000..72043983f --- /dev/null +++ b/examples/configs/xgb/config.yml @@ -0,0 +1,19 @@ +models: + - type: main + engine: hf_pipeline_gpt2 + model: "openai-community/gpt2" +rails: + config: + xgb: + input: + detectors: + - SPAM + output: + detectors: + - SPAM + input: + flows: + - xgb detect on input + output: + flows: + - xgb detect on output diff --git a/nemoguardrails/library/xgb/__init__.py b/nemoguardrails/library/xgb/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/library/xgb/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/library/xgb/actions.py b/nemoguardrails/library/xgb/actions.py new file mode 100644 index 000000000..515c02074 --- /dev/null +++ b/nemoguardrails/library/xgb/actions.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemoguardrails.actions import action +from nemoguardrails.library.xgb.inference import xgb_inference +from nemoguardrails.rails.llm.config import RailsConfig + + +@action() +async def xgb_detect( + source: str, + text: str, + config: RailsConfig, + **kwargs, +): + xgb_config = getattr(config.rails.config, "xgb") + source_config = getattr(xgb_config, source) + + enabled_detectors = getattr(source_config, "detectors", None) + if enabled_detectors is None: + raise ValueError( + f"Could not find 'detectors' in source_config: {source_config}" + ) + valid_detectors = ["SPAM"] + for detector in enabled_detectors: + if detector not in valid_detectors: + raise ValueError( + f"XGB detectors can only be defined in the following detectors: {valid_detectors}. " + f"The current detector, '{detector}' is not allowed." + ) + + valid_sources = ["input", "output"] + if source not in valid_sources: + raise ValueError( + f"XGB detectors can only be defined in the following flows: {valid_sources}. " + f"The current flow, '{source} is not allowed." + ) + + xgb_response = xgb_inference( + text, + enabled_detectors, + ) + + return xgb_response diff --git a/nemoguardrails/library/xgb/flows.co b/nemoguardrails/library/xgb/flows.co new file mode 100644 index 000000000..772968930 --- /dev/null +++ b/nemoguardrails/library/xgb/flows.co @@ -0,0 +1,21 @@ +#### XGB DETECTION RAILS #### + +# INPUT RAILS + +flow xgb detect on input +"""Check if the user content has harmful content" + $detection = await XGBDetectAction(source="input", text=$user_message) + + if $detection + bot inform answer unknown + abort + +# OUTPUT RAILS + +flow xgb detect on output +"""Check if the bot output has harmful content" + $detection = await XGBDetectAction(source="output", text=$bot_message) + + if $detection + bot inform answer unknown + abort diff --git a/nemoguardrails/library/xgb/flows.v1.co b/nemoguardrails/library/xgb/flows.v1.co new file mode 100644 index 000000000..f8629e939 --- /dev/null +++ b/nemoguardrails/library/xgb/flows.v1.co @@ -0,0 +1,18 @@ +#### XGB DETECTION RAILS #### + +# INPUT RAILS + +define subflow xgb detect on input + $detection = execute xgb_detect(source="input", text=$user_message) + + if $detection + bot inform answer unknown + stop + +# OUTPUT RAILS +define subflow xgb detect on output + $detection = execute xgb_detect(source="output", text=$user_message) + + if $detection + bot inform answer unknown + stop diff --git a/nemoguardrails/library/xgb/inference.py b/nemoguardrails/library/xgb/inference.py new file mode 100644 index 000000000..2afe32165 --- /dev/null +++ b/nemoguardrails/library/xgb/inference.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import pickle +from typing import List + +log = logging.getLogger(__name__) +MODEL_REGISTRY = { + "SPAM": { + "model_path": "nemoguardrails/library/xgb/model_artifacts/model.pkl", + "vectorizer_path": "nemoguardrails/library/xgb/model_artifacts/vectorizer.pkl", + } +} + + +def xgb_inference(text: str, enabled_detectors: List[str]): + detections = [] + for detector in enabled_detectors: + model_info = MODEL_REGISTRY.get(detector) + if not model_info: + raise ValueError( + f"XGB detector '{detector}' is not configured in the MODEL_REGISTRY." + ) + model_path = model_info["model_path"] + vectorizer_path = model_info["vectorizer_path"] + with open(model_path, "rb") as f: + model = pickle.load(f) + with open(vectorizer_path, "rb") as f: + vectorizer = pickle.load(f) + + try: + X_vec = vectorizer.transform([text]) + prediction = model.predict(X_vec)[0] + probability = model.predict_proba(X_vec)[0] + + is_safe = prediction == 0 + confidence = max(probability) + + detections.append( + { + "allowed": bool(is_safe), + "score": float(confidence), + "prediction": "safe" if is_safe else detector, + } + ) + + except Exception as e: + raise ValueError( + f"Error during XGBoost inference for detector '{detector}': {e}" + ) + return detections diff --git a/nemoguardrails/library/xgb/model_artifacts/model.pkl b/nemoguardrails/library/xgb/model_artifacts/model.pkl new file mode 100644 index 000000000..ec8c224fa Binary files /dev/null and b/nemoguardrails/library/xgb/model_artifacts/model.pkl differ diff --git a/nemoguardrails/library/xgb/model_artifacts/vectorizer.pkl b/nemoguardrails/library/xgb/model_artifacts/vectorizer.pkl new file mode 100644 index 000000000..2827707f1 Binary files /dev/null and b/nemoguardrails/library/xgb/model_artifacts/vectorizer.pkl differ diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index ffdd10220..cc7d3d617 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -293,6 +293,28 @@ class FiddlerGuardrails(BaseModel): ) +class XGBDetectors(BaseModel): + """Configuration for XGBoost detectors.""" + + detectors: List[str] = Field( + default_factory=list, + description="The list of detectors to use.", + ) + + +class XGBDetection(BaseModel): + """Configuration for XGBoost detectors.""" + + input: Optional[XGBDetectors] = Field( + default_factory=XGBDetectors, + description="XGBoost configuration for an Input Guardrail", + ) + output: Optional[XGBDetectors] = Field( + default_factory=XGBDetectors, + description="XGBoost configuration for an Output Guardrail", + ) + + class MessageTemplate(BaseModel): """Template for a message structure.""" @@ -805,6 +827,11 @@ class RailsConfigData(BaseModel): description="Configuration for Clavata.", ) + xgb: Optional[XGBDetection] = Field( + default_factory=XGBDetection, + description="Configuration for XGBoost Guardrails.", + ) + class Rails(BaseModel): """Configuration of specific rails.""" diff --git a/pyproject.toml b/pyproject.toml index 78908c7af..cc44899ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ repository = "https://github.com/NVIDIA/NeMo-Guardrails" nemoguardrails = "nemoguardrails.__main__:app" [tool.poetry.dependencies] -python = ">=3.9,!=3.9.7,<3.14" +python = ">=3.10,!=3.9.7,<3.14" aiohttp = ">=3.10.11" annoy = ">=1.17.3" fastapi = ">=0.103.0," @@ -101,6 +101,11 @@ google-cloud-language = { version = ">=2.14.0", optional = true } # jailbreak injection yara-python = { version = "^4.5.1", optional = true } +# xgb +xgboost = "^3.0.2" +scikit-learn = "^1.7.1" +huggingface-hub = "^0.34.3" + [tool.poetry.extras] sdd = ["presidio-analyzer", "presidio-anonymizer"] eval = ["tqdm", "numpy", "streamlit", "tornado"] diff --git a/tests/test_xgb_detection.py b/tests/test_xgb_detection.py new file mode 100644 index 000000000..1713e3d38 --- /dev/null +++ b/tests/test_xgb_detection.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemoguardrails import RailsConfig +from nemoguardrails.actions.actions import ActionResult, action +from tests.utils import TestChat + + +@action() +def retrieve_relevant_chunks(): + context_updates = {"relevant_chunks": "Mock retrieve context"} + return ActionResult( + return_value=context_updates["relevant_chunks"], + context_updates=context_updates, + ) + + +COLANG_CONTENT = """ + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." +""" + + +@pytest.mark.unit +def test_xgb_spam_detection_no_active_spam_detection(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + xgb: + input: + detectors: + - SPAM + input: + flows: + - xgb detect on input + """, + colang_content=COLANG_CONTENT, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! Nice to meet you."', + ], + ) + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat >> "Hi!" + chat << "Hi! Nice to meet you." + + +@pytest.mark.unit +def test_xgb_spam_detection_input(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + xgb: + input: + detectors: + - SPAM + input: + flows: + - xgb detect on input + """, + colang_content=COLANG_CONTENT, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! Nice to meet you"', + ], + ) + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + + chat >> "Hi! GENT! We are trying to contact you. Last weekends draw shows that you won a £1000 prize GUARANTEED. Call 09064012160.Claim Code K52. Valid 12hrs only. 150ppm!" + + chat << "I can't answer that." + + +@pytest.mark.unit +def test_xgb_spam_detection_output(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + xgb: + output: + detectors: + - SPAM + output: + flows: + - xgb detect on output + """, + colang_content=COLANG_CONTENT, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + """ 'GENT! We are trying to contact you. Last weekends draw shows that you won a £1000 prize GUARANTEED. + Call 09064012160.Claim Code K52. Valid 12hrs only. 150ppm!' """, + ], + ) + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat >> "Hi!" + chat << "I can't answer that."