Skip to content

Commit eab5154

Browse files
feat: add XGB rails
1 parent a5a0518 commit eab5154

File tree

11 files changed

+397
-0
lines changed

11 files changed

+397
-0
lines changed

docs/user-guides/community/xgb.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# XGB Detectors Integration
2+
3+
XGB Detectors utilizes [XGBoost machine learning models](https://xgboost.readthedocs.io/en/stable/tutorials/model.html) to detect harmful content in data. Currently, only
4+
the spam text detector, trained by the [Red Hat TrustyAI team](https://github.com/trustyai-explainability), is available for guardrailing use.
5+
6+
## Setup
7+
8+
Update your `config.yaml` file to include XGB detectors:
9+
10+
**Spam detection config**
11+
```
12+
rails:
13+
config:
14+
xgb:
15+
input:
16+
detectors:
17+
- SPAM
18+
output:
19+
detectors:
20+
- SPAM
21+
input:
22+
flows:
23+
- xgb detect on input
24+
output:
25+
flows:
26+
- xgb detect on output
27+
```
28+
The detection flows block the input or output text when spam is detected.
29+
30+
## Usage
31+
32+
Once configured, the XGB Guardrails integration will automatically:
33+
34+
1. Detect spam in inputs to the LLM
35+
2. Detect spam in outputs from the LLM
36+
37+
## Error Handling
38+
39+
If the inference request to the XGB spam model fails, the system will assume spam is present as a precautionary measure.
40+
41+
## Notes
42+
43+
For more information on TrustyAI and its projects, please visit the TrustyAI [documentation](https://trustyai.org/docs/main/main).

examples/configs/xgb/config.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
models:
2+
- type: main
3+
engine: hf_pipeline_flan
4+
model: "google-bert/bert-base-cased"
5+
rails:
6+
config:
7+
xgb:
8+
input:
9+
detectors:
10+
- SPAM
11+
output:
12+
detectors:
13+
- SPAM
14+
input:
15+
flows:
16+
- xgb detect on input
17+
output:
18+
flows:
19+
- xgb detect on output
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.

nemoguardrails/library/xgb/actions.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from nemoguardrails.actions import action
17+
from nemoguardrails.library.xgb.inference import xgb_inference
18+
from nemoguardrails.rails.llm.config import RailsConfig
19+
20+
21+
@action()
async def xgb_detect(
    source: str,
    text: str,
    config: RailsConfig,
    **kwargs,
):
    """Run the XGB detectors configured for ``source`` against ``text``.

    Args:
        source: Which rail is being checked; must be ``"input"`` or ``"output"``.
        text: The text to classify.
        config: The rails configuration; detectors are read from
            ``config.rails.config.xgb.<source>.detectors``.
        **kwargs: Extra keyword arguments; accepted and ignored.

    Returns:
        A list with one result dict per detector that flagged the text, i.e.
        whose ``"allowed"`` entry is false. The list is empty when nothing was
        flagged, so the flows' ``if $detection`` check only fires on an actual
        detection.

    Raises:
        ValueError: If ``source`` is not a valid flow, if no detectors are
            configured for it, or if a configured detector is not supported.
    """
    # Validate the source first: it is used as an attribute name below, so an
    # invalid value would otherwise surface as an AttributeError.
    valid_sources = ["input", "output"]
    if source not in valid_sources:
        raise ValueError(
            f"XGB detectors can only be defined in the following flows: {valid_sources}. "
            f"The current flow, '{source}' is not allowed."
        )

    # Either level may be missing or None (both fields are Optional), in which
    # case the lookup falls through to the "detectors not found" error below.
    xgb_config = getattr(config.rails.config, "xgb", None)
    source_config = getattr(xgb_config, source, None)

    enabled_detectors = getattr(source_config, "detectors", None)
    if enabled_detectors is None:
        raise ValueError(
            f"Could not find 'detectors' in source_config: {source_config}"
        )

    valid_detectors = ["SPAM"]
    for detector in enabled_detectors:
        if detector not in valid_detectors:
            raise ValueError(
                f"XGB detectors can only be defined in the following detectors: {valid_detectors}. "
                f"The current detector, '{detector}' is not allowed."
            )

    xgb_response = xgb_inference(
        text,
        enabled_detectors,
    )

    # xgb_inference reports every detector, including the ones that found the
    # text safe. Returning that list unfiltered would be truthy for any enabled
    # detector and make the rails block every message, so keep only the
    # entries that were actually flagged.
    return [
        detection
        for detection in xgb_response
        if not detection.get("allowed", False)
    ]

nemoguardrails/library/xgb/flows.co

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#### XGB DETECTION RAILS ####
2+
3+
# INPUT RAILS
4+
5+
# Runs the XGB detectors on the user message; when the action reports a
# detection the bot answers with "inform answer unknown" and the flow aborts.
flow xgb detect on input
  """Check if the user content has harmful content"""
  $detection = await XGBDetectAction(source="input", text=$user_message)

  if $detection
    bot inform answer unknown
    abort

# OUTPUT RAILS

# Same check as the input rail, applied to the bot message.
flow xgb detect on output
  """Check if the bot output has harmful content"""
  $detection = await XGBDetectAction(source="output", text=$bot_message)

  if $detection
    bot inform answer unknown
    abort
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#### XGB DETECTION RAILS ####
2+
3+
# INPUT RAILS
4+
5+
# Runs the XGB detectors on the user message; when the action reports a
# detection the bot answers with "inform answer unknown" and processing stops.
define subflow xgb detect on input
  $detection = execute xgb_detect(source="input", text=$user_message)

  if $detection
    bot inform answer unknown
    stop

# OUTPUT RAILS

# The output rail must inspect the bot message, not the user message.
define subflow xgb detect on output
  $detection = execute xgb_detect(source="output", text=$bot_message)

  if $detection
    bot inform answer unknown
    stop
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import pickle
18+
from typing import List
19+
20+
log = logging.getLogger(__name__)

# Maps a detector name to the pickled model and vectorizer it needs.
MODEL_REGISTRY = {
    "SPAM": {
        "model_path": "nemoguardrails/library/xgb/model_artifacts/model.pkl",
        "vectorizer_path": "nemoguardrails/library/xgb/model_artifacts/vectorizer.pkl",
    }
}

# Unpickled artifacts keyed by absolute path, so each file is read from disk
# at most once per process instead of on every inference call.
_ARTIFACT_CACHE = {}


def _load_artifact(path: str):
    """Return the unpickled object stored at ``path``, loading it only once."""
    import os  # function-scope import; the module's import block is untouched

    resolved = path
    if not os.path.exists(resolved):
        # The registry paths are relative, so they resolve against the current
        # working directory. When that fails, look for the same file name in
        # the ``model_artifacts`` directory next to this module.
        fallback = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "model_artifacts",
            os.path.basename(path),
        )
        if os.path.exists(fallback):
            resolved = fallback
    resolved = os.path.abspath(resolved)

    if resolved not in _ARTIFACT_CACHE:
        log.debug("Loading XGB artifact from %s", resolved)
        with open(resolved, "rb") as f:
            # NOTE(security): pickle.load can execute arbitrary code. Only
            # artifacts from a trusted source must ever be placed at the
            # registry paths.
            _ARTIFACT_CACHE[resolved] = pickle.load(f)
    return _ARTIFACT_CACHE[resolved]


def xgb_inference(text: str, enabled_detectors: List[str]):
    """Classify ``text`` with every detector in ``enabled_detectors``.

    Args:
        text: The text to classify.
        enabled_detectors: Detector names; each must be a key of
            ``MODEL_REGISTRY``.

    Returns:
        A list with one dict per detector, in the order given, with the keys
        ``"allowed"`` (True when the model predicts class 0), ``"score"`` (the
        highest class probability) and ``"prediction"`` (``"safe"`` or the
        detector name). The list is empty when no detectors are enabled.

    Raises:
        ValueError: If a detector is not in ``MODEL_REGISTRY`` or if the model
            or vectorizer fails during inference; the original error is kept
            as the exception cause.
        OSError: If an artifact file cannot be opened.
    """
    detections = []
    for detector in enabled_detectors:
        model_info = MODEL_REGISTRY.get(detector)
        if not model_info:
            raise ValueError(
                f"XGB detector '{detector}' is not configured in the MODEL_REGISTRY."
            )
        model = _load_artifact(model_info["model_path"])
        vectorizer = _load_artifact(model_info["vectorizer_path"])

        try:
            X_vec = vectorizer.transform([text])
            prediction = model.predict(X_vec)[0]
            probability = model.predict_proba(X_vec)[0]

            # Class 0 is treated as the "safe" class.
            is_safe = prediction == 0
            confidence = max(probability)

            result = {
                "allowed": bool(is_safe),
                "score": float(confidence),
                "prediction": "safe" if is_safe else detector,
            }
        except Exception as e:
            # Chain the cause so the underlying failure stays visible.
            raise ValueError(
                f"Error during XGBoost inference for detector '{detector}': {e}"
            ) from e
        else:
            detections.append(result)
    return detections
Binary file not shown.
Binary file not shown.

nemoguardrails/rails/llm/config.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,28 @@ class FiddlerGuardrails(BaseModel):
293293
)
294294

295295

296+
class XGBDetectors(BaseModel):
    """Detector selection for one XGBoost guardrail section."""

    # Names of the detectors to run; an empty list when not configured.
    detectors: List[str] = Field(
        description="The list of detectors to use.",
        default_factory=list,
    )
303+
304+
305+
class XGBDetection(BaseModel):
    """XGBoost guardrail settings, with separate input and output sections."""

    # Detectors applied on the input rail; defaults to an empty XGBDetectors.
    input: Optional[XGBDetectors] = Field(
        description="XGBoost configuration for an Input Guardrail",
        default_factory=XGBDetectors,
    )
    # Detectors applied on the output rail; defaults to an empty XGBDetectors.
    output: Optional[XGBDetectors] = Field(
        description="XGBoost configuration for an Output Guardrail",
        default_factory=XGBDetectors,
    )
316+
317+
296318
class MessageTemplate(BaseModel):
297319
"""Template for a message structure."""
298320

@@ -805,6 +827,11 @@ class RailsConfigData(BaseModel):
805827
description="Configuration for Clavata.",
806828
)
807829

830+
xgb: Optional[XGBDetection] = Field(
831+
default_factory=XGBDetection,
832+
description="Configuration for XGBoost Guardrails.",
833+
)
834+
808835

809836
class Rails(BaseModel):
810837
"""Configuration of specific rails."""

0 commit comments

Comments
 (0)