Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 83 additions & 85 deletions backend/app/api/routes/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,28 @@
from guardrails.validators import FailResult

from app.api.deps import AuthDep, SessionDep
from app.core.constants import REPHRASE_ON_FAIL_PREFIX
from app.core.guardrail_controller import build_guard, get_validator_config_models
from app.crud.request_log import RequestLogCrud
from app.crud.validator_log import ValidatorLogCrud
from app.models.guardrail_config import GuardrailInputRequest, GuardrailOutputRequest
from app.models.guardrail_config import GuardrailRequest, GuardrailResponse
from app.models.logging.request import RequestLogUpdate, RequestStatus
from app.models.logging.validator import ValidatorLog, ValidatorOutcome
from app.utils import APIResponse

router = APIRouter(prefix="/guardrails", tags=["guardrails"])

@router.post("/input/")
async def run_input_guardrails(
payload: GuardrailInputRequest,
@router.post(
"/",
response_model=APIResponse[GuardrailResponse],
response_model_exclude_none=True)
async def run_guardrails(
payload: GuardrailRequest,
session: SessionDep,
_: AuthDep,
):
request_log_crud = RequestLogCrud(session=session)
validator_log_crud = ValidatorLogCrud(session=session)
request_id = None

try:
request_id = UUID(payload.request_id)
Expand All @@ -35,38 +38,12 @@ async def run_input_guardrails(
return await _validate_with_guard(
payload.input,
payload.validators,
"safe_input",
request_log_crud,
request_log.id,
validator_log_crud,
)

@router.post("/output/")
async def run_output_guardrails(
payload: GuardrailOutputRequest,
session: SessionDep,
_: AuthDep,
):
request_log_crud = RequestLogCrud(session=session)
validator_log_crud = ValidatorLogCrud(session=session)
request_id = None

try:
request_id = UUID(payload.request_id)
except ValueError:
return APIResponse.failure_response(error="Invalid request_id")

request_log = request_log_crud.create(request_id, input_text=payload.output)
return await _validate_with_guard(
payload.output,
payload.validators,
"safe_output",
request_log_crud,
request_log.id,
validator_log_crud
)

@router.get("/validator/")
@router.get("/")
async def list_validators(_: AuthDep):
"""
Lists all validators and their parameters directly.
Expand All @@ -93,84 +70,105 @@ async def list_validators(_: AuthDep):
async def _validate_with_guard(
data: str,
validators: list,
response_field: str, # "safe_input" or "safe_output"
request_log_crud: RequestLogCrud,
request_log_id: UUID,
validator_log_crud: ValidatorLogCrud,
) -> APIResponse:
response_id = uuid.uuid4()
guard = None

try:
guard = build_guard(validators)
result = guard.validate(data)

if result.validated_output is not None:
request_log_crud.update(
request_log_id=request_log_id,
request_status=RequestStatus.SUCCESS,
request_log_update= RequestLogUpdate(
response_text=result.validated_output,
response_id=response_id
)
)

add_validator_logs(guard, request_log_id, validator_log_crud)
"""
Runs Guardrails validation on input/output data, persists request & validator logs,
and returns a structured APIResponse.

return APIResponse.success_response(
data={
"response_id": response_id,
response_field: result.validated_output,
}
)
This function treats validation failures as first-class outcomes (not exceptions),
while still safely handling unexpected runtime errors.
"""
response_id = uuid.uuid4()
guard: Guard | None = None

def _finalize(
*,
status: RequestStatus,
validated_output: str | None = None,
error_message: str | None = None,
) -> APIResponse:
"""
Single exit-point helper to ensure:
- request logs are always updated
- validator logs are written when available
- API responses are consistent
"""
response_text = (
validated_output if validated_output is not None else error_message
)
if response_text is None:
response_text = "Validation failed"

request_log_crud.update(
request_log_id=request_log_id,
request_status=RequestStatus.ERROR,
request_status=status,
request_log_update=RequestLogUpdate(
response_text=str(result),
response_text=response_text,
response_id=response_id,
),
)
add_validator_logs(guard, request_log_id, validator_log_crud)

if guard is not None:
add_validator_logs(guard, request_log_id, validator_log_crud)

rephrase_needed = (
validated_output is not None
and validated_output.startswith(REPHRASE_ON_FAIL_PREFIX)
)

response_model = GuardrailResponse(
response_id=response_id,
rephrase_needed=rephrase_needed,
safe_text=validated_output,
)

if status == RequestStatus.SUCCESS:
return APIResponse.success_response(data=response_model)

return APIResponse.failure_response(
data={
"response_id": response_id,
response_field: None,
},
error="Validation failed",
data=response_model,
error=response_text or "Validation failed",
)

except Exception as e:
request_log_crud.update(
request_log_id=request_log_id,
request_status=RequestStatus.ERROR,
request_log_update= RequestLogUpdate(
response_text=str(e),
response_id=response_id
)
try:
guard = build_guard(validators)
result = guard.validate(data)

# Case 1: validation passed OR failed-with-fix (on_fail=FIX)
if result.validated_output is not None:
return _finalize(
status=RequestStatus.SUCCESS,
validated_output=result.validated_output,
)

add_validator_logs(guard, request_log_id, validator_log_crud)
# Case 2: validation failed without a fix
return _finalize(
status=RequestStatus.ERROR,
error_message=str(result.error),
)

return APIResponse.failure_response(
data={
"response_id": response_id,
response_field: None,
},
error=str(e),
except Exception as exc:
# Case 3: unexpected system / runtime failure
return _finalize(
status=RequestStatus.ERROR,
error_message=str(exc),
)

def add_validator_logs(guard: Guard, request_log_id: UUID, validator_log_crud: ValidatorLogCrud):
if not guard or not guard.history or not guard.history.last:
history = getattr(guard, "history", None)
if not history:
return

call = guard.history.last
if not call.iterations:
last_call = getattr(history, "last", None)
if not last_call or not getattr(last_call, "iterations", None):
return

iteration = call.iterations[-1]
if not iteration.outputs or not iteration.outputs.validator_logs:
iteration = last_call.iterations[-1]
outputs = getattr(iteration, "outputs", None)
if not outputs or not getattr(outputs, "validator_logs", None):
return

for log in iteration.outputs.validator_logs:
Expand Down
4 changes: 3 additions & 1 deletion backend/app/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
LANG_HINDI = "hi"
LANG_ENGLISH = "en"
LABEL = "label"
SCORE = "score"
SCORE = "score"

REPHRASE_ON_FAIL_PREFIX = "Please rephrase the query without unsafe content."
7 changes: 6 additions & 1 deletion backend/app/core/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,9 @@ class BiasCategories(Enum):
Generic = "generic"
Healthcare = "healthcare"
Education = "education"
All = "all"
All = "all"

class GuardrailOnFail(Enum):
Exception = "exception"
Fix = "fix"
Rephrase = "rephrase"
6 changes: 6 additions & 0 deletions backend/app/core/on_fail_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from guardrails.validators import FailResult

from app.core.constants import REPHRASE_ON_FAIL_PREFIX

def rephrase_query_on_fail(value: str, fail_result: FailResult):
return f"{REPHRASE_ON_FAIL_PREFIX} {fail_result.error_message}"
32 changes: 18 additions & 14 deletions backend/app/models/base_validator_config.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
from typing import Any, Literal, Optional

from guardrails import OnFailAction
from guardrails.validators import Validator
from sqlmodel import SQLModel

ON_FAIL_STR = Literal["exception", "fix", "noop", "reask"]
from app.core.enum import GuardrailOnFail
from app.core.on_fail_actions import rephrase_query_on_fail


_ON_FAIL_MAP = {
GuardrailOnFail.Fix: OnFailAction.FIX,
GuardrailOnFail.Exception: OnFailAction.EXCEPTION,
GuardrailOnFail.Rephrase: rephrase_query_on_fail,
}

class BaseValidatorConfig(SQLModel):
on_fail: Optional[ON_FAIL_STR] = OnFailAction.FIX
on_fail: GuardrailOnFail = GuardrailOnFail.Fix

model_config = {"arbitrary_types_allowed": True}

def resolve_on_fail(self):
if self.on_fail is None:
return None

try:
return OnFailAction[self.on_fail.upper()]
except KeyError:
return _ON_FAIL_MAP[self.on_fail]
except KeyError as e:
raise ValueError(
f"Invalid on_fail value: {self.on_fail}. "
"Expected one of: exception, fix, noop, reask"
)
f"Invalid on_fail value: {self.on_fail}. Error {e}. " \
"Expected one of: exception, fix, rephrase."
)

def build(self) -> Any:
def build(self) -> Validator:
raise NotImplementedError(
f"{self.__class__.__name__} must implement build()"
)
)
13 changes: 7 additions & 6 deletions backend/app/models/guardrail_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Union, Annotated
from typing import Annotated, List, Optional, Union
from uuid import UUID

from sqlmodel import Field, SQLModel

Expand All @@ -20,12 +21,12 @@
Field(discriminator="type")
]

class GuardrailInputRequest(SQLModel):
class GuardrailRequest(SQLModel):
request_id: str
input: str
validators: List[ValidatorConfigItem]

class GuardrailOutputRequest(SQLModel):
request_id: str
output: str
validators: List[ValidatorConfigItem]
class GuardrailResponse(SQLModel):
response_id: UUID
rephrase_needed: bool = False
safe_text: Optional[str] = None
Loading