
feat: Code scanner Provider impl for moderations api #3100


Open · wants to merge 14 commits into base: main
5 changes: 2 additions & 3 deletions llama_stack/core/routers/safety.py
@@ -6,9 +6,7 @@

from typing import Any

from llama_stack.apis.inference import (
    Message,
)
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
@@ -68,6 +66,7 @@ async def get_shield_id(self, model: str) -> str:
        list_shields_response = await self.routing_table.list_shields()

        matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]

        if not matches:
            raise ValueError(f"No shield associated with provider_resource_id {model}")
        if len(matches) > 1:
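To illustrate the lookup above: a run_moderation request carries a `model` string, and the router resolves it to the shield whose `provider_resource_id` matches. A minimal sketch, using plain dicts as stand-ins for the shield objects returned by `list_shields()`:

# Hypothetical registry entries; the real ones come from self.routing_table.list_shields().
shields = [
    {"identifier": "llama-guard", "provider_resource_id": "meta-llama/Llama-Guard-3-8B"},
    {"identifier": "code-scanner", "provider_resource_id": "code-scanner"},
]

model = "code-scanner"
matches = [s["identifier"] for s in shields if model == s["provider_resource_id"]]
assert matches == ["code-scanner"]  # exactly one match: this shield serves the request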
1 change: 1 addition & 0 deletions llama_stack/distributions/ci-tests/build.yaml
@@ -28,6 +28,7 @@ distribution_spec:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:
5 changes: 5 additions & 0 deletions llama_stack/distributions/ci-tests/run.yaml
@@ -134,6 +134,8 @@ providers:
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  - provider_id: code-scanner
    provider_type: inline::code-scanner
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
@@ -215,6 +217,9 @@ shields:
- shield_id: llama-guard
  provider_id: ${env.SAFETY_MODEL:+llama-guard}
  provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
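The `provider_id` and `provider_shield_id` values above use bash-style conditional expansion: `${env.VAR:+value}` resolves to `value` only when `VAR` is set, and `${env.VAR:=default}` falls back to a default (empty here) when it is not. A minimal Python sketch of those semantics, with illustrative helper names that are not part of llama-stack:

import os

def expand_if_set(var: str, value: str) -> str | None:
    # ${env.VAR:+value} -> `value` only when VAR is set and non-empty
    return value if os.environ.get(var) else None

def expand_or_default(var: str, default: str = "") -> str:
    # ${env.VAR:=default} -> VAR's value, else the default (empty above)
    return os.environ.get(var) or default

# With CODE_SCANNER_MODEL unset, the code-scanner shield resolves to no
# provider_id and stays inactive; exporting the variable enables it.
print(expand_if_set("CODE_SCANNER_MODEL", "code-scanner"))  # None or "code-scanner"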
1 change: 1 addition & 0 deletions llama_stack/distributions/starter/build.yaml
@@ -28,6 +28,7 @@ distribution_spec:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:
5 changes: 5 additions & 0 deletions llama_stack/distributions/starter/run.yaml
@@ -134,6 +134,8 @@ providers:
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  - provider_id: code-scanner
    provider_type: inline::code-scanner
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
@@ -215,6 +217,9 @@ shields:
- shield_id: llama-guard
  provider_id: ${env.SAFETY_MODEL:+llama-guard}
  provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
19 changes: 11 additions & 8 deletions llama_stack/distributions/starter/starter.py
@@ -15,19 +15,14 @@
    ToolGroupInput,
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import (
    DistributionTemplate,
    RunConfigSettings,
)
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.milvus.config import (
    MilvusVectorIOConfig,
)
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
    SQLiteVectorIOConfig,
)
@@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
BuildProvider(provider_type="remote::pgvector"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"safety": [
BuildProvider(provider_type="inline::llama-guard"),
BuildProvider(provider_type="inline::code-scanner"),
],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")],
@@ -167,6 +165,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="${env.SAFETY_MODEL:+llama-guard}",
provider_shield_id="${env.SAFETY_MODEL:=}",
),
ShieldInput(
shield_id="code-scanner",
provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
),
]

return DistributionTemplate(
@@ -5,7 +5,11 @@
# the root directory of this source tree.

import logging
from typing import Any
import uuid
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from codeshield.cs import CodeShieldScanResult

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
@@ -14,6 +18,7 @@
    SafetyViolation,
    ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
@@ -24,8 +29,8 @@
log = logging.getLogger(__name__)

ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
"code-scanner",
"code-shield",
]


@@ -69,3 +74,55 @@ async def run_shield(
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
)
return RunShieldResponse(violation=violation)

    def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
        categories = {}
        category_scores = {}
        category_applied_input_types = {}

        flagged = scan_result.is_insecure
        user_message = None
        metadata = {}

        if scan_result.is_insecure:
            pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
            categories = dict.fromkeys(pattern_ids, True)
            category_scores = dict.fromkeys(pattern_ids, 1.0)
            category_applied_input_types = {key: ["text"] for key in pattern_ids}
            user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
            metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}

        return ModerationObjectResults(
            flagged=flagged,
            categories=categories,
            category_scores=category_scores,
            category_applied_input_types=category_applied_input_types,
            user_message=user_message,
            metadata=metadata,
        )

    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
        inputs = input if isinstance(input, list) else [input]
        results = []

        # Lazy import: codeshield is an optional dependency of this provider.
        from codeshield.cs import CodeShield

        for text_input in inputs:
            log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
            try:
                scan_result = await CodeShield.scan_code(text_input)
                moderation_result = self.get_moderation_object_results(scan_result)
            except Exception as e:
                log.error(f"CodeShield.scan_code failed: {e}")
                # create safe fallback response on scanner failure to avoid blocking legitimate requests
                moderation_result = ModerationObjectResults(
                    flagged=False,
                    categories={},
                    category_scores={},
                    category_applied_input_types={},
                    user_message=None,
                    metadata={"scanner_error": str(e)},
                )
            results.append(moderation_result)

        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
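Taken together, a client call against this provider might look like the sketch below, mirroring the integration tests later in this PR; `client` stands for an already-configured llama-stack client and is assumed, not shown in this diff:

insecure_code = """
import hashlib
hashed = hashlib.md5(b"secret").hexdigest()  # weak hash, should be flagged
"""

moderation = client.moderations.create(
    input=[insecure_code],
    model="code-scanner",
)

result = moderation.results[0]
if result.flagged:
    # categories and category_scores are keyed by CodeShield pattern ids
    # (the exact ids depend on the scanner's rule set)
    print(result.user_message)
    print(result.metadata["violation_type"])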
@@ -11,11 +11,7 @@
from typing import Any

from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import (
    Inference,
    Message,
    UserMessage,
)
from llama_stack.apis.inference import Inference, Message, UserMessage
from llama_stack.apis.safety import (
    RunShieldResponse,
    Safety,
@@ -72,7 +68,6 @@
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}


DEFAULT_LG_V3_SAFETY_CATEGORIES = [
    CAT_VIOLENT_CRIMES,
    CAT_NON_VIOLENT_CRIMES,
@@ -460,7 +455,7 @@ def create_moderation_object(self, model: str, unsafe_code: str | None = None) -

    def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
        """Check if content is safe based on response and unsafe code."""
        if response.strip() == SAFE_RESPONSE:
        if response.strip().lower().startswith(SAFE_RESPONSE):
            return True

        if unsafe_code:
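The relaxed comparison above tolerates case and formatting variation in the model's verdict. A small sketch of the new behavior, assuming `SAFE_RESPONSE` is the lowercase string "safe" (its definition sits outside this hunk):

SAFE_RESPONSE = "safe"  # assumption: the constant's actual value is not shown in this diff

def is_content_safe(response: str) -> bool:
    return response.strip().lower().startswith(SAFE_RESPONSE)

assert is_content_safe("safe")
assert is_content_safe("Safe")            # previously rejected: the old exact match was case-sensitive
assert not is_content_safe("unsafe\nS8")  # violations still fall through to the unsafe_code check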
60 changes: 57 additions & 3 deletions tests/integration/safety/test_safety.py
@@ -27,10 +27,10 @@ def data_url_from_image(file_path):

@pytest.fixture(scope="session")
def code_scanner_shield_id(available_shields):
if "CodeScanner" in available_shields:
return "CodeScanner"
if "code-scanner" in available_shields:
return "code-scanner"

pytest.skip("CodeScanner shield is not available. Skipping.")
pytest.skip("code-scanner shield is not available. Skipping.")


def test_unsafe_examples(client_with_models, shield_id):
@@ -163,6 +163,60 @@ def test_safety_with_code_scanner(client_with_models, code_scanner_shield_id, model_providers):
    assert response.violation.user_message == "Sorry, I found security concerns in the code."


def test_unsafe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
    code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
    if not code_safety_enabled:
        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")

    insecure_code = [
        """
Here is an example of code that hashes a password using MD5:

```
import hashlib
password = 'securepassword123'.encode()
hashed_password = hashlib.md5(password).hexdigest()
print("Insecure hashed password:", hashed_password)
```
"""
    ]
    moderation_object = client_with_models.moderations.create(
        input=insecure_code,
        model=code_scanner_shield_id,
    )
    assert moderation_object.results[0].flagged is True, f"Code scanner should have flagged {insecure_code} as insecure"
    assert all(value is True for value in moderation_object.results[0].categories.values()), (
        "Code scanner shield should have detected code insecure category"
    )


def test_safe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
    code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
    if not code_safety_enabled:
        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")

    secure_code = [
        """
Extract the first 5 characters from a string:
```
text = "Hello World"
first_five = text[:5]
print(first_five)  # Output: "Hello"

# Safe handling for strings shorter than 5 characters
def get_first_five(text):
    return text[:5] if text else ""
```
"""
    ]
    moderation_object = client_with_models.moderations.create(
        input=secure_code,
        model=code_scanner_shield_id,
    )

    assert moderation_object.results[0].flagged is False, "Code scanner should not have flagged the code as insecure"


# We can use an instance of the LlamaGuard shield to detect attempts to misuse
# the interpreter as this is one of the existing categories it checks for
def test_safety_with_code_interpreter_abuse(client_with_models, shield_id):