
Commit a1e71a9

Merge pull request #75 from guardrails-ai/dtam/feature/fast_api
migrate from flask to fast api for uvicorn and asgi support
2 parents f20f793 + a2d9921 commit a1e71a9

35 files changed: +1294 −1599

Dockerfile

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ COPY . .
 EXPOSE 8000

 # This is our start command; yours might be different.
-# The guardrails-api is a standard Flask application.
-# You can use whatever production server you want that support Flask.
+# The guardrails-api is a standard FastAPI application.
+# You can use whatever production server you want that supports FastAPI.
 # Here we use gunicorn
 CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --workers=2 'guardrails_api.app:create_app(".env", "sample-config.py")'
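
A note on the start command: gunicorn's default workers speak WSGI, while FastAPI is an ASGI framework, which is what the uvicorn half of this migration provides. One common way to serve a FastAPI app under gunicorn is uvicorn's worker class; the variant below is a sketch, not part of this diff, and assumes uvicorn is installed in the image:

# Sketch (not in this commit): same app factory, but with an ASGI-capable worker class.
CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --workers=2 \
    --worker-class uvicorn.workers.UvicornWorker \
    'guardrails_api.app:create_app(".env", "sample-config.py")'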

guardrails_api/api/guards.py

Lines changed: 324 additions & 0 deletions
@@ -0,0 +1,324 @@
import json
import os
import inspect
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request, APIRouter
from fastapi.responses import JSONResponse, StreamingResponse
from urllib.parse import unquote_plus
from guardrails import AsyncGuard, Guard
from guardrails.classes import ValidationOutcome
from opentelemetry.trace import Span
from guardrails_api_client import Guard as GuardStruct
from guardrails_api.clients.cache_client import CacheClient
from guardrails_api.clients.memory_guard_client import MemoryGuardClient
from guardrails_api.clients.pg_guard_client import PGGuardClient
from guardrails_api.clients.postgres_client import postgres_is_enabled
from guardrails_api.utils.get_llm_callable import get_llm_callable
from guardrails_api.utils.openai import (
    outcome_to_chat_completion,
    outcome_to_stream_response,
)
from guardrails_api.utils.handle_error import handle_error
from string import Template

# if no pg_host is set, use in memory guards
if postgres_is_enabled():
    guard_client = PGGuardClient()
else:
    guard_client = MemoryGuardClient()
    # Will be defined at runtime
    import config  # noqa

    exports = config.__dir__()
    for export_name in exports:
        export = getattr(config, export_name)
        is_guard = isinstance(export, Guard)
        if is_guard:
            guard_client.create_guard(export)

cache_client = CacheClient()

cache_client.initialize()

router = APIRouter()

@router.get("/guards")
@handle_error
async def get_guards():
    guards = guard_client.get_guards()
    return [g.to_dict() for g in guards]

@router.post("/guards")
@handle_error
async def create_guard(guard: GuardStruct):
    if not postgres_is_enabled():
        raise HTTPException(
            status_code=501,
            detail="POST /guards is not implemented for in-memory guards.",
        )
    new_guard = guard_client.create_guard(guard)
    return new_guard.to_dict()

@router.get("/guards/{guard_name}")
@handle_error
async def get_guard(guard_name: str, asOf: Optional[str] = None):
    decoded_guard_name = unquote_plus(guard_name)
    guard = guard_client.get_guard(decoded_guard_name, asOf)
    if guard is None:
        raise HTTPException(
            status_code=404,
            detail=f"A Guard with the name {decoded_guard_name} does not exist!",
        )
    return guard.to_dict()

@router.put("/guards/{guard_name}")
@handle_error
async def update_guard(guard_name: str, guard: GuardStruct):
    if not postgres_is_enabled():
        raise HTTPException(
            status_code=501,
            detail="PUT /<guard_name> is not implemented for in-memory guards.",
        )
    decoded_guard_name = unquote_plus(guard_name)
    updated_guard = guard_client.upsert_guard(decoded_guard_name, guard)
    return updated_guard.to_dict()

@router.delete("/guards/{guard_name}")
@handle_error
async def delete_guard(guard_name: str):
    if not postgres_is_enabled():
        raise HTTPException(
            status_code=501,
            detail="DELETE /<guard_name> is not implemented for in-memory guards.",
        )
    decoded_guard_name = unquote_plus(guard_name)
    guard = guard_client.delete_guard(decoded_guard_name)
    return guard.to_dict()

@router.post("/guards/{guard_name}/openai/v1/chat/completions")
@handle_error
async def openai_v1_chat_completions(guard_name: str, request: Request):
    payload = await request.json()
    decoded_guard_name = unquote_plus(guard_name)
    guard_struct = guard_client.get_guard(decoded_guard_name)
    if guard_struct is None:
        raise HTTPException(
            status_code=404,
            detail=f"A Guard with the name {decoded_guard_name} does not exist!",
        )

    guard = (
        Guard.from_dict(guard_struct.to_dict())
        if not isinstance(guard_struct, Guard)
        else guard_struct
    )
    stream = payload.get("stream", False)
    has_tool_gd_tool_call = any(
        tool.get("function", {}).get("name") == "gd_response_tool"
        for tool in payload.get("tools", [])
    )

    if not stream:
        validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload)
        llm_response = guard.history.last.iterations.last.outputs.llm_response_info
        result = outcome_to_chat_completion(
            validation_outcome=validation_outcome,
            llm_response=llm_response,
            has_tool_gd_tool_call=has_tool_gd_tool_call,
        )
        return JSONResponse(content=result)
    else:

        async def openai_streamer():
            guard_stream = guard(num_reasks=0, **payload)
            for result in guard_stream:
                chunk = json.dumps(
                    outcome_to_stream_response(validation_outcome=result)
                )
                yield f"data: {chunk}\n\n"
            yield "\n"

        return StreamingResponse(openai_streamer(), media_type="text/event-stream")

@router.post("/guards/{guard_name}/validate")
@handle_error
async def validate(guard_name: str, request: Request):
    payload = await request.json()
    openai_api_key = request.headers.get(
        "x-openai-api-key", os.environ.get("OPENAI_API_KEY")
    )
    decoded_guard_name = unquote_plus(guard_name)
    guard_struct = guard_client.get_guard(decoded_guard_name)

    llm_output = payload.pop("llmOutput", None)
    num_reasks = payload.pop("numReasks", None)
    prompt_params = payload.pop("promptParams", {})
    llm_api = payload.pop("llmApi", None)
    args = payload.pop("args", [])
    stream = payload.pop("stream", False)

    payload["api_key"] = payload.get("api_key", openai_api_key)

    if llm_api is not None:
        llm_api = get_llm_callable(llm_api)
        if openai_api_key is None:
            raise HTTPException(
                status_code=400,
                detail="Cannot perform calls to OpenAI without an api key.",
            )

    guard = guard_struct
    is_async = inspect.iscoroutinefunction(llm_api)

    if not isinstance(guard_struct, Guard):
        if is_async:
            guard = AsyncGuard.from_dict(guard_struct.to_dict())
        else:
            guard: Guard = Guard.from_dict(guard_struct.to_dict())
    elif is_async:
        guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict())

    if llm_api is None and num_reasks and num_reasks > 1:
        raise HTTPException(
            status_code=400,
            detail="Cannot perform re-asks without an LLM API. Specify llm_api when calling guard(...).",
        )

    if llm_output is not None:
        if stream:
            raise HTTPException(
                status_code=400, detail="Streaming is not supported for parse calls!"
            )
        result: ValidationOutcome = guard.parse(
            llm_output=llm_output,
            num_reasks=num_reasks,
            prompt_params=prompt_params,
            llm_api=llm_api,
            **payload,
        )
    else:
        if stream:

            async def guard_streamer():
                guard_stream = guard(
                    llm_api=llm_api,
                    prompt_params=prompt_params,
                    num_reasks=num_reasks,
                    stream=stream,
                    *args,
                    **payload,
                )
                for result in guard_stream:
                    validation_output = ValidationOutcome.from_guard_history(
                        guard.history.last
                    )
                    yield validation_output, result

            async def validate_streamer(guard_iter):
                async for validation_output, result in guard_iter:
                    fragment_dict = result.to_dict()
                    fragment_dict["error_spans"] = [
                        json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
                        for x in guard.error_spans_in_output()
                    ]
                    yield json.dumps(fragment_dict) + "\n"

                call = guard.history.last
                final_validation_output = ValidationOutcome(
                    callId=call.id,
                    validation_passed=result.validation_passed,
                    validated_output=result.validated_output,
                    history=guard.history,
                    raw_llm_output=result.raw_llm_output,
                )
                final_output_dict = final_validation_output.to_dict()
                final_output_dict["error_spans"] = [
                    json.dumps({"start": x.start, "end": x.end, "reason": x.reason})
                    for x in guard.error_spans_in_output()
                ]
                yield json.dumps(final_output_dict) + "\n"

                serialized_history = [call.to_dict() for call in guard.history]
                cache_key = f"{guard.name}-{final_validation_output.call_id}"
                await cache_client.set(cache_key, serialized_history, 300)

            return StreamingResponse(
                validate_streamer(guard_streamer()), media_type="application/json"
            )
        else:
            if inspect.iscoroutinefunction(guard):
                result: ValidationOutcome = await guard(
                    llm_api=llm_api,
                    prompt_params=prompt_params,
                    num_reasks=num_reasks,
                    *args,
                    **payload,
                )
            else:
                result: ValidationOutcome = guard(
                    llm_api=llm_api,
                    prompt_params=prompt_params,
                    num_reasks=num_reasks,
                    *args,
                    **payload,
                )

            serialized_history = [call.to_dict() for call in guard.history]
            cache_key = f"{guard.name}-{result.call_id}"
            await cache_client.set(cache_key, serialized_history, 300)
            return result.to_dict()

@router.get("/guards/{guard_name}/history/{call_id}")
@handle_error
async def guard_history(guard_name: str, call_id: str):
    cache_key = f"{guard_name}-{call_id}"
    return await cache_client.get(cache_key)

def collect_telemetry(
    *,
    guard: Guard,
    validate_span: Span,
    validation_output: ValidationOutcome,
    prompt_params: Dict[str, Any],
    result: ValidationOutcome,
):
    # Below is all telemetry collection and
    # should have no impact on what is returned to the user
    prompt = guard.history.last.inputs.prompt
    if prompt:
        prompt = Template(prompt).safe_substitute(**prompt_params)
        validate_span.set_attribute("prompt", prompt)

    instructions = guard.history.last.inputs.instructions
    if instructions:
        instructions = Template(instructions).safe_substitute(**prompt_params)
        validate_span.set_attribute("instructions", instructions)

    validate_span.set_attribute("validation_status", guard.history.last.status)
    validate_span.set_attribute("raw_llm_output", result.raw_llm_output)

    # Use the serialization from the class instead of re-writing it
    valid_output: str = (
        json.dumps(validation_output.validated_output)
        if isinstance(validation_output.validated_output, dict)
        else str(validation_output.validated_output)
    )
    validate_span.set_attribute("validated_output", valid_output)

    validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed)

    num_of_reasks = (
        guard.history.last.iterations.length - 1
        if guard.history.last.iterations.length > 0
        else 0
    )
    validate_span.set_attribute("num_of_reasks", num_of_reasks)
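
For orientation, here is a minimal client-side sketch of the two main routes added above. The host, guard name, model, and payload values are illustrative assumptions, not part of this commit; the llmOutput/numReasks keys and the x-openai-api-key header mirror what validate() reads.

# Hypothetical client calls, assuming a server at http://localhost:8000
# with a guard registered under the name "my-guard".
import requests
from openai import OpenAI

# Parse-style call: validate pre-generated output without an LLM round trip.
resp = requests.post(
    "http://localhost:8000/guards/my-guard/validate",
    json={"llmOutput": "some model output to validate", "numReasks": 0},
    headers={"x-openai-api-key": "sk-..."},  # optional; the server falls back to OPENAI_API_KEY
)
print(resp.json())

# The OpenAI-compatible route can be driven by the standard openai client,
# pointing base_url at the guard's /openai/v1 prefix.
client = OpenAI(
    base_url="http://localhost:8000/guards/my-guard/openai/v1",
    api_key="sk-...",  # placeholder; the model name below is also an assumption
)
completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)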
