Skip to content

Commit 7519ab4

Browse files
authored
feat: Code scanner Provider impl for moderations api (#3100)
# What does this PR do? Add CodeScanner implementations ## Test Plan `SAFETY_MODEL=CodeScanner LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama` This PR needs to land after #3098.
1 parent 27d6bec commit 7519ab4

File tree

9 files changed

+144
-24
lines changed

9 files changed

+144
-24
lines changed

llama_stack/core/routers/safety.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66

77
from typing import Any
88

9-
from llama_stack.apis.inference import (
10-
Message,
11-
)
9+
from llama_stack.apis.inference import Message
1210
from llama_stack.apis.safety import RunShieldResponse, Safety
1311
from llama_stack.apis.safety.safety import ModerationObject
1412
from llama_stack.apis.shields import Shield
@@ -68,6 +66,7 @@ async def get_shield_id(self, model: str) -> str:
6866
list_shields_response = await self.routing_table.list_shields()
6967

7068
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
69+
7170
if not matches:
7271
raise ValueError(f"No shield associated with provider_resource id {model}")
7372
if len(matches) > 1:

llama_stack/distributions/ci-tests/build.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ distribution_spec:
2828
- provider_type: inline::localfs
2929
safety:
3030
- provider_type: inline::llama-guard
31+
- provider_type: inline::code-scanner
3132
agents:
3233
- provider_type: inline::meta-reference
3334
telemetry:

llama_stack/distributions/ci-tests/run.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ providers:
135135
provider_type: inline::llama-guard
136136
config:
137137
excluded_categories: []
138+
- provider_id: code-scanner
139+
provider_type: inline::code-scanner
138140
agents:
139141
- provider_id: meta-reference
140142
provider_type: inline::meta-reference
@@ -223,6 +225,9 @@ shields:
223225
- shield_id: llama-guard
224226
provider_id: ${env.SAFETY_MODEL:+llama-guard}
225227
provider_shield_id: ${env.SAFETY_MODEL:=}
228+
- shield_id: code-scanner
229+
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
230+
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
226231
vector_dbs: []
227232
datasets: []
228233
scoring_fns: []

llama_stack/distributions/starter/build.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ distribution_spec:
2828
- provider_type: inline::localfs
2929
safety:
3030
- provider_type: inline::llama-guard
31+
- provider_type: inline::code-scanner
3132
agents:
3233
- provider_type: inline::meta-reference
3334
telemetry:

llama_stack/distributions/starter/run.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ providers:
135135
provider_type: inline::llama-guard
136136
config:
137137
excluded_categories: []
138+
- provider_id: code-scanner
139+
provider_type: inline::code-scanner
138140
agents:
139141
- provider_id: meta-reference
140142
provider_type: inline::meta-reference
@@ -223,6 +225,9 @@ shields:
223225
- shield_id: llama-guard
224226
provider_id: ${env.SAFETY_MODEL:+llama-guard}
225227
provider_shield_id: ${env.SAFETY_MODEL:=}
228+
- shield_id: code-scanner
229+
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
230+
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
226231
vector_dbs: []
227232
datasets: []
228233
scoring_fns: []

llama_stack/distributions/starter/starter.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,14 @@
1515
ToolGroupInput,
1616
)
1717
from llama_stack.core.utils.dynamic import instantiate_class_type
18-
from llama_stack.distributions.template import (
19-
DistributionTemplate,
20-
RunConfigSettings,
21-
)
18+
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
2219
from llama_stack.providers.datatypes import RemoteProviderSpec
2320
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
2421
from llama_stack.providers.inline.inference.sentence_transformers import (
2522
SentenceTransformersInferenceConfig,
2623
)
2724
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
28-
from llama_stack.providers.inline.vector_io.milvus.config import (
29-
MilvusVectorIOConfig,
30-
)
25+
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
3126
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
3227
SQLiteVectorIOConfig,
3328
)
@@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
119114
BuildProvider(provider_type="remote::pgvector"),
120115
],
121116
"files": [BuildProvider(provider_type="inline::localfs")],
122-
"safety": [BuildProvider(provider_type="inline::llama-guard")],
117+
"safety": [
118+
BuildProvider(provider_type="inline::llama-guard"),
119+
BuildProvider(provider_type="inline::code-scanner"),
120+
],
123121
"agents": [BuildProvider(provider_type="inline::meta-reference")],
124122
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
125123
"post_training": [BuildProvider(provider_type="inline::huggingface")],
@@ -170,6 +168,11 @@ def get_distribution_template() -> DistributionTemplate:
170168
provider_id="${env.SAFETY_MODEL:+llama-guard}",
171169
provider_shield_id="${env.SAFETY_MODEL:=}",
172170
),
171+
ShieldInput(
172+
shield_id="code-scanner",
173+
provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
174+
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
175+
),
173176
]
174177

175178
return DistributionTemplate(

llama_stack/providers/inline/safety/code_scanner/code_scanner.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
# the root directory of this source tree.
66

77
import logging
8-
from typing import Any
8+
import uuid
9+
from typing import TYPE_CHECKING, Any
10+
11+
if TYPE_CHECKING:
12+
from codeshield.cs import CodeShieldScanResult
913

1014
from llama_stack.apis.inference import Message
1115
from llama_stack.apis.safety import (
@@ -14,6 +18,7 @@
1418
SafetyViolation,
1519
ViolationLevel,
1620
)
21+
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
1722
from llama_stack.apis.shields import Shield
1823
from llama_stack.providers.utils.inference.prompt_adapter import (
1924
interleaved_content_as_str,
@@ -24,8 +29,8 @@
2429
log = logging.getLogger(__name__)
2530

2631
ALLOWED_CODE_SCANNER_MODEL_IDS = [
27-
"CodeScanner",
28-
"CodeShield",
32+
"code-scanner",
33+
"code-shield",
2934
]
3035

3136

@@ -69,3 +74,55 @@ async def run_shield(
6974
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
7075
)
7176
return RunShieldResponse(violation=violation)
77+
78+
def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
    """Translate a CodeShield scan result into a ModerationObjectResults.

    Each detected issue's pattern_id becomes a flagged category with score 1.0
    applied to "text" input; a clean scan yields an unflagged result with
    empty category maps and no user message.

    :param scan_result: result object returned by CodeShield.scan_code
    :return: moderation result mirroring the scan outcome
    """
    categories: dict[str, bool] = {}
    category_scores: dict[str, float] = {}
    category_applied_input_types: dict[str, list[str]] = {}
    user_message = None
    metadata: dict[str, Any] = {}

    flagged = scan_result.is_insecure
    if flagged:
        issues = scan_result.issues_found
        # Compute the pattern-id list once and reuse it for every derived field
        # (the original recomputed this comprehension for the metadata join).
        pattern_ids = [issue.pattern_id for issue in issues]
        categories = dict.fromkeys(pattern_ids, True)
        category_scores = dict.fromkeys(pattern_ids, 1.0)
        category_applied_input_types = {pattern_id: ["text"] for pattern_id in pattern_ids}
        user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in issues])}"
        metadata = {"violation_type": ",".join(pattern_ids)}

    return ModerationObjectResults(
        flagged=flagged,
        categories=categories,
        category_scores=category_scores,
        category_applied_input_types=category_applied_input_types,
        user_message=user_message,
        metadata=metadata,
    )
103+
104+
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
    """Run CodeShield-based moderation over one or more code snippets.

    Accepts a single string or a list of strings; each item is scanned
    independently and contributes one entry to the returned results list.
    A scanner failure is logged and reported as an unflagged result carrying
    the error in metadata, so legitimate requests are not blocked.

    :param input: code snippet(s) to scan
    :param model: model identifier echoed back on the ModerationObject
    :return: moderation object with one result per input snippet
    """
    # Imported lazily so the provider can load without codeshield installed.
    from codeshield.cs import CodeShield

    snippets = input if isinstance(input, list) else [input]
    scan_outcomes = []

    for snippet in snippets:
        log.info(f"Running CodeScannerShield moderation on input: {snippet[:100]}...")
        try:
            outcome = self.get_moderation_object_results(await CodeShield.scan_code(snippet))
        except Exception as e:
            log.error(f"CodeShield.scan_code failed: {e}")
            # create safe fallback response on scanner failure to avoid blocking legitimate requests
            outcome = ModerationObjectResults(
                flagged=False,
                categories={},
                category_scores={},
                category_applied_input_types={},
                user_message=None,
                metadata={"scanner_error": str(e)},
            )
        scan_outcomes.append(outcome)

    return ModerationObject(id=str(uuid.uuid4()), model=model, results=scan_outcomes)

llama_stack/providers/inline/safety/llama_guard/llama_guard.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
from typing import Any
1212

1313
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
14-
from llama_stack.apis.inference import (
15-
Inference,
16-
Message,
17-
UserMessage,
18-
)
14+
from llama_stack.apis.inference import Inference, Message, UserMessage
1915
from llama_stack.apis.safety import (
2016
RunShieldResponse,
2117
Safety,
@@ -72,7 +68,6 @@
7268
}
7369
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
7470

75-
7671
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
7772
CAT_VIOLENT_CRIMES,
7873
CAT_NON_VIOLENT_CRIMES,
@@ -460,7 +455,7 @@ def create_moderation_object(self, model: str, unsafe_code: str | None = None) -
460455

461456
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
462457
"""Check if content is safe based on response and unsafe code."""
463-
if response.strip() == SAFE_RESPONSE:
458+
if response.strip().lower().startswith(SAFE_RESPONSE):
464459
return True
465460

466461
if unsafe_code:

tests/integration/safety/test_safety.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ def data_url_from_image(file_path):
2727

2828
@pytest.fixture(scope="session")
def code_scanner_shield_id(available_shields):
    """Return the code-scanner shield id, skipping the test session when it is absent."""
    if "code-scanner" not in available_shields:
        pytest.skip("code-scanner shield is not available. Skipping.")
    return "code-scanner"
3434

3535

3636
def test_unsafe_examples(client_with_models, shield_id):
@@ -163,6 +163,60 @@ def test_safety_with_code_scanner(client_with_models, code_scanner_shield_id, mo
163163
assert response.violation.user_message == "Sorry, I found security concerns in the code."
164164

165165

166+
def test_unsafe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
167+
code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
168+
if not code_safety_enabled:
169+
pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
170+
171+
insecure_code = [
172+
"""
173+
Here is an example of code that hashes a password using MD5:
174+
175+
```
176+
import hashlib
177+
password = 'securepassword123'.encode()
178+
hashed_password = hashlib.md5(password).hexdigest()
179+
print("Insecure hashed password:", hashed_password)
180+
```
181+
"""
182+
]
183+
moderation_object = client_with_models.moderations.create(
184+
input=insecure_code,
185+
model=code_scanner_shield_id,
186+
)
187+
assert moderation_object.results[0].flagged is True, f"Code scanner should have flagged {insecure_code} as insecure"
188+
assert all(value is True for value in moderation_object.results[0].categories.values()), (
189+
"Code scanner shield should have detected code insecure category"
190+
)
191+
192+
193+
def test_safe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
    """Moderations API must leave benign string-slicing code unflagged via the code-scanner shield."""
    if not (CODE_SCANNER_ENABLED_PROVIDERS & model_providers):
        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")

    secure_code = [
        """
    Extract the first 5 characters from a string:
    ```
    text = "Hello World"
    first_five = text[:5]
    print(first_five)  # Output: "Hello"

    # Safe handling for strings shorter than 5 characters
    def get_first_five(text):
        return text[:5] if text else ""
    ```
    """
    ]
    moderation_object = client_with_models.moderations.create(
        input=secure_code,
        model=code_scanner_shield_id,
    )

    assert moderation_object.results[0].flagged is False, "Code scanner should not have flagged the code as insecure"
218+
219+
166220
# We can use an instance of the LlamaGuard shield to detect attempts to misuse
167221
# the interpreter as this is one of the existing categories it checks for
168222
def test_safety_with_code_interpreter_abuse(client_with_models, shield_id):

0 commit comments

Comments
 (0)