Commit b5b5f5b
chore: add mypy prompt guard (#2678)
# What does this PR do?

This PR adds static type coverage to `llama-stack`. Part of #2647.

## Test Plan

Signed-off-by: Mustafa Elbehery <[email protected]>
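One class of error this coverage catches shows up in the diff below: `run_shield` previously declared `params: dict[str, Any] = None`, which mypy rejects because `None` is not a valid `dict[str, Any]` default. A minimal standalone sketch of the problem and two type-correct alternatives (illustrative function names, not the real llama-stack signatures):

```python
from typing import Any


# Under mypy's default no-implicit-optional behavior, this default is
# rejected: `None` is not compatible with `dict[str, Any]`.
def run_shield_bad(params: dict[str, Any] = None) -> None:  # mypy error
    ...


# Fix 1: require the argument, as this commit does.
def run_shield_required(params: dict[str, Any]) -> None:
    ...


# Fix 2: declare the optionality explicitly.
def run_shield_optional(params: dict[str, Any] | None = None) -> None:
    ...
```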
1 parent: 7448a4a · commit: b5b5f5b

2 files changed: 8 additions, 4 deletions

llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -15,6 +15,7 @@
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
@@ -32,6 +33,8 @@
 
 
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    shield_store: ShieldStore
+
     def __init__(self, config: PromptGuardConfig, _deps) -> None:
         self.config = config
 
@@ -53,7 +56,7 @@ async def run_shield(
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any],
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
@@ -117,8 +120,10 @@ async def run(self, messages: list[Message]) -> RunShieldResponse:
         elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
             violation = SafetyViolation(
                 violation_level=ViolationLevel.ERROR,
-                violation_type=f"prompt_injection:malicious={score_malicious}",
-                violation_return_message="Sorry, I cannot do this.",
+                user_message="Sorry, I cannot do this.",
+                metadata={
+                    "violation_type": f"prompt_injection:malicious={score_malicious}",
+                },
             )
 
         return RunShieldResponse(violation=violation)
```
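The other fix in this file is the class-level `shield_store: ShieldStore` annotation, the usual way to tell mypy about an attribute that the stack injects at runtime rather than assigning in `__init__`. A minimal sketch of the pattern, with stand-in types in place of the real llama-stack classes:

```python
from typing import Protocol


class Shield:
    """Stand-in for llama_stack.apis.shields.Shield."""


class ShieldStore(Protocol):
    """Stand-in for the real shield store interface."""

    async def get_shield(self, identifier: str) -> Shield: ...


class PromptGuardSafetyImpl:
    # A bare class-level annotation creates no runtime attribute; it
    # only tells mypy that the attribute exists and what type it has.
    shield_store: ShieldStore

    async def run_shield(self, shield_id: str) -> Shield:
        # Without the annotation above, mypy would report something like:
        # '"PromptGuardSafetyImpl" has no attribute "shield_store"'
        return await self.shield_store.get_shield(shield_id)
```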

pyproject.toml

Lines changed: 0 additions & 1 deletion
```diff
@@ -266,7 +266,6 @@ exclude = [
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",
-    "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",
```
