Skip to content

Commit 6da7303

Browse files
committed
move imports to method
1 parent 0b429da commit 6da7303

File tree

4 files changed

+52
-19
lines changed

4 files changed

+52
-19
lines changed

llama_stack/providers/inline/safety/code_scanner/code_scanner.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
import logging
88
import uuid
9-
from typing import Any
9+
from typing import TYPE_CHECKING, Any
1010

11-
from codeshield.cs import CodeShield, CodeShieldScanResult
11+
if TYPE_CHECKING:
12+
from codeshield.cs import CodeShieldScanResult
1213

1314
from llama_stack.apis.inference import Message
1415
from llama_stack.apis.safety import (
@@ -59,6 +60,8 @@ async def run_shield(
5960
if not shield:
6061
raise ValueError(f"Shield {shield_id} not found")
6162

63+
from codeshield.cs import CodeShield
64+
6265
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
6366
log.info(f"Running CodeScannerShield on {text[50:]}")
6467
result = await CodeShield.scan_code(text)
@@ -72,7 +75,7 @@ async def run_shield(
7275
)
7376
return RunShieldResponse(violation=violation)
7477

75-
def get_moderation_object_results(self, scan_result: CodeShieldScanResult) -> ModerationObjectResults:
78+
def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
7679
categories = {}
7780
category_scores = {}
7881
category_applied_input_types = {}
@@ -102,6 +105,8 @@ async def run_moderation(self, input: str | list[str], model: str) -> Moderation
102105
inputs = input if isinstance(input, list) else [input]
103106
results = []
104107

108+
from codeshield.cs import CodeShield
109+
105110
for text_input in inputs:
106111
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
107112
scan_result = await CodeShield.scan_code(text_input)

llama_stack/providers/inline/safety/llama_guard/llama_guard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def create_moderation_object(self, model: str, unsafe_code: str | None = None) -
455455

456456
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
457457
"""Check if content is safe based on response and unsafe code."""
458-
if response.strip() == SAFE_RESPONSE:
458+
if response.strip().lower().startswith(SAFE_RESPONSE):
459459
return True
460460

461461
if unsafe_code:

pyproject.toml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ classifiers = [
2525
]
2626
dependencies = [
2727
"aiohttp",
28-
"fastapi>=0.115.0,<1.0", # server
29-
"fire", # for MCP in LLS client
28+
"fastapi>=0.115.0,<1.0", # server
29+
"fire", # for MCP in LLS client
3030
"httpx",
3131
"huggingface-hub>=0.34.0,<1.0",
3232
"jinja2>=3.1.6",
@@ -44,12 +44,12 @@ dependencies = [
4444
"tiktoken",
4545
"pillow",
4646
"h11>=0.16.0",
47-
"python-multipart>=0.0.20", # For fastapi Form
48-
"uvicorn>=0.34.0", # server
49-
"opentelemetry-sdk>=1.30.0", # server
47+
"python-multipart>=0.0.20", # For fastapi Form
48+
"uvicorn>=0.34.0", # server
49+
"opentelemetry-sdk>=1.30.0", # server
5050
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
51-
"aiosqlite>=0.21.0", # server - for metadata store
52-
"asyncpg", # for metadata store
51+
"aiosqlite>=0.21.0", # server - for metadata store
52+
"asyncpg", # for metadata store
5353
]
5454

5555
[project.optional-dependencies]
@@ -163,6 +163,7 @@ explicit = true
163163
[tool.uv.sources]
164164
torch = [{ index = "pytorch-cpu" }]
165165
torchvision = [{ index = "pytorch-cpu" }]
166+
llama-stack-client = { path = "../llama-stack-client-python" }
166167

167168
[tool.ruff]
168169
line-length = 120

uv.lock

Lines changed: 35 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)