Skip to content

Commit 7d78643

Browse files
authored
Merge pull request #1120 from guardrails-ai/ai-proxy-updates
updates for telemetry performance and streaming and server multi node support
2 parents 30eeb4c + 7386fae commit 7d78643

File tree

13 files changed

+124
-31
lines changed

13 files changed

+124
-31
lines changed

guardrails/api_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def stream_validate(
104104
)
105105
if line:
106106
json_output = json.loads(line)
107+
if json_output.get("error"):
108+
raise Exception(json_output.get("error").get("message"))
107109
yield IValidationOutcome.from_dict(json_output)
108110

109111
def get_history(self, guard_name: str, call_id: str):

guardrails/async_guard.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -574,13 +574,15 @@ async def _stream_server_call(
574574
validated_output=validated_output,
575575
validation_passed=(validation_output.validation_passed is True),
576576
)
577-
if validation_output:
578-
guard_history = self._api_client.get_history(
579-
self.name, validation_output.call_id
580-
)
581-
self.history.extend(
582-
[Call.from_interface(call) for call in guard_history]
583-
)
577+
# TODO re-enable this once we have a way to get history
578+
# from a multi-node server
579+
# if validation_output:
580+
# guard_history = self._api_client.get_history(
581+
# self.name, validation_output.call_id
582+
# )
583+
# self.history.extend(
584+
# [Call.from_interface(call) for call in guard_history]
585+
# )
584586
else:
585587
raise ValueError("AsyncGuard does not have an api client!")
586588

guardrails/classes/llm/llm_response.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@ def to_interface(self) -> ILLMResponse:
4747
stream_output = [str(so) for so in copy_2]
4848

4949
async_stream_output = None
50-
if self.async_stream_output:
50+
# don't do this again if it's already aiter-able — we're updating
51+
# ourselves here in memory, so repeating it
52+
# can cause issues
53+
if self.async_stream_output and not hasattr(
54+
self.async_stream_output, "__aiter__"
55+
):
5156
# tee doesn't work with async iterators
5257
# This may be destructive
5358
async_stream_output = []

guardrails/cli/create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def generate_template_config(
118118
guard_instantiations = []
119119

120120
for i, guard in enumerate(template["guards"]):
121-
guard_instantiations.append(f"guard{i} = Guard.from_dict(guards[{i}])")
121+
guard_instantiations.append(f"guard{i} = AsyncGuard.from_dict(guards[{i}])")
122122
guard_instantiations = "\n".join(guard_instantiations)
123123
# Interpolate variables
124124
output_content = template_content.format(

guardrails/cli/hub/template_config.py.template

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
from guardrails import Guard
3+
from guardrails import AsyncGuard, Guard
44
from guardrails.hub import {VALIDATOR_IMPORTS}
55

66
try:

guardrails/guard.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,10 +1215,12 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
12151215
error="The response from the server was empty!",
12161216
)
12171217

1218-
guard_history = self._api_client.get_history(
1219-
self.name, validation_output.call_id
1220-
)
1221-
self.history.extend([Call.from_interface(call) for call in guard_history])
1218+
# TODO re-enable this when we have history support in
1219+
# multi-node server environments
1220+
# guard_history = self._api_client.get_history(
1221+
# self.name, validation_output.call_id
1222+
# )
1223+
# self.history.extend([Call.from_interface(call) for call in guard_history])
12221224

12231225
validation_summaries = []
12241226
if self.history.last and self.history.last.iterations.last:
@@ -1281,13 +1283,15 @@ def _stream_server_call(
12811283
validated_output=validated_output,
12821284
validation_passed=(validation_output.validation_passed is True),
12831285
)
1284-
if validation_output:
1285-
guard_history = self._api_client.get_history(
1286-
self.name, validation_output.call_id
1287-
)
1288-
self.history.extend(
1289-
[Call.from_interface(call) for call in guard_history]
1290-
)
1286+
1287+
# TODO re-enable this when the server supports multi-node history
1288+
# if validation_output:
1289+
# guard_history = self._api_client.get_history(
1290+
# self.name, validation_output.call_id
1291+
# )
1292+
# self.history.extend(
1293+
# [Call.from_interface(call) for call in guard_history]
1294+
# )
12911295
else:
12921296
raise ValueError("Guard does not have an api client!")
12931297

guardrails/llm_providers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,10 @@ def _invoke_llm(
498498
),
499499
)
500500

501+
# these kwargs are Guardrails-only and should not be passed through to LLMs
502+
kwargs.pop("reask_prompt", None)
503+
kwargs.pop("reask_instructions", None)
504+
501505
response = completion(
502506
model=model,
503507
*args,
@@ -1088,6 +1092,10 @@ async def invoke_llm(
10881092
),
10891093
)
10901094

1095+
# these kwargs are Guardrails-only and should not be passed through to LLMs
1096+
kwargs.pop("reask_prompt", None)
1097+
kwargs.pop("reask_instructions", None)
1098+
10911099
response = await acompletion(
10921100
*args,
10931101
**kwargs,

guardrails/run/async_stream_runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ async def async_step(
148148
_ = self.is_last_chunk(chunk, api)
149149

150150
fragment += chunk_text
151+
151152
results = await validator_service.async_partial_validate(
152153
chunk_text,
153154
self.metadata,
@@ -157,7 +158,8 @@ async def async_step(
157158
"$",
158159
True,
159160
)
160-
validators = self.validation_map["$"] or []
161+
validators = self.validation_map.get("$", [])
162+
161163
# collect the result validated_chunk into validation progress
162164
# per validator
163165
for result in results:
@@ -210,7 +212,7 @@ async def async_step(
210212
validation_progress[validator_log.validator_name] += chunk
211213
# if there is an entry for every validator
212214
# run a merge and emit a validation outcome
213-
if len(validation_progress) == len(validators):
215+
if len(validation_progress) == len(validators) or len(validators) == 0:
214216
if refrain_triggered:
215217
current = ""
216218
else:

guardrails/telemetry/guard_tracing.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,18 @@ def trace_stream_guard(
145145
res = next(result) # type: ignore
146146
# FIXME: This should only be called once;
147147
# Accumulate the validated output and call at the end
148-
add_guard_attributes(guard_span, history, res)
149-
add_user_attributes(guard_span)
150-
yield res
148+
if not guard_span.is_recording():
149+
# Assuming you have a tracer instance
150+
tracer = get_tracer(__name__)
151+
# Create a new span and link it to the previous span
152+
with tracer.start_as_current_span(
153+
"stream_guard_span", # type: ignore
154+
links=[Link(guard_span.get_span_context())],
155+
) as new_span:
156+
guard_span = new_span
157+
add_guard_attributes(guard_span, history, res)
158+
add_user_attributes(guard_span)
159+
yield res
151160
except StopIteration:
152161
next_exists = False
153162

@@ -180,6 +189,7 @@ def trace_guard_execution(
180189
result, ValidationOutcome
181190
):
182191
return trace_stream_guard(guard_span, result, history)
192+
183193
add_guard_attributes(guard_span, history, result)
184194
add_user_attributes(guard_span)
185195
return result
@@ -204,14 +214,14 @@ async def trace_async_stream_guard(
204214
tracer = get_tracer(__name__)
205215
# Create a new span and link it to the previous span
206216
with tracer.start_as_current_span(
207-
"new_guard_span", # type: ignore
217+
"async_stream_span", # type: ignore
208218
links=[Link(guard_span.get_span_context())],
209219
) as new_span:
210220
guard_span = new_span
211221

212222
add_guard_attributes(guard_span, history, res)
213223
add_user_attributes(guard_span)
214-
yield res
224+
yield res
215225
except StopIteration:
216226
next_exists = False
217227
except StopAsyncIteration:

guardrails/telemetry/runner_tracing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,11 @@ def trace_call_wrapper(*args, **kwargs):
265265
) as call_span:
266266
try:
267267
response = fn(*args, **kwargs)
268+
if isinstance(response, LLMResponse) and (
269+
response.async_stream_output or response.stream_output
270+
):
271+
# TODO: Iterate, add a call attr each time
272+
return response
268273
add_call_attributes(call_span, response, *args, **kwargs)
269274
return response
270275
except Exception as e:

0 commit comments

Comments
 (0)