6
6
7
7
import logging
8
8
import uuid
9
- from typing import Any
9
+ from typing import TYPE_CHECKING , Any
10
10
11
- from codeshield .cs import CodeShield , CodeShieldScanResult
11
+ if TYPE_CHECKING :
12
+ from codeshield .cs import CodeShieldScanResult
12
13
13
14
from llama_stack .apis .inference import Message
14
15
from llama_stack .apis .safety import (
@@ -59,6 +60,8 @@ async def run_shield(
59
60
if not shield :
60
61
raise ValueError (f"Shield { shield_id } not found" )
61
62
63
+ from codeshield .cs import CodeShield
64
+
62
65
text = "\n " .join ([interleaved_content_as_str (m .content ) for m in messages ])
63
66
log .info (f"Running CodeScannerShield on { text [50 :]} " )
64
67
result = await CodeShield .scan_code (text )
@@ -72,7 +75,7 @@ async def run_shield(
72
75
)
73
76
return RunShieldResponse (violation = violation )
74
77
75
- def get_moderation_object_results (self , scan_result : CodeShieldScanResult ) -> ModerationObjectResults :
78
+ def get_moderation_object_results (self , scan_result : " CodeShieldScanResult" ) -> ModerationObjectResults :
76
79
categories = {}
77
80
category_scores = {}
78
81
category_applied_input_types = {}
@@ -102,6 +105,8 @@ async def run_moderation(self, input: str | list[str], model: str) -> Moderation
102
105
inputs = input if isinstance (input , list ) else [input ]
103
106
results = []
104
107
108
+ from codeshield .cs import CodeShield
109
+
105
110
for text_input in inputs :
106
111
log .info (f"Running CodeScannerShield moderation on input: { text_input [:100 ]} ..." )
107
112
scan_result = await CodeShield .scan_code (text_input )
0 commit comments