
Gracefully handle errors in evals #2295

Open · wants to merge 16 commits into base: main
@@ -29,7 +29,9 @@ def evaluate_dataset():
report = dataset.evaluate_sync(infer_time_range)
print(report)

assertion_pass_rate = report.averages().assertions
averages = report.averages()
assert averages is not None
assertion_pass_rate = averages.assertions
assert assertion_pass_rate is not None, 'There should be at least one assertion'
assert assertion_pass_rate > 0.9, (
f'The assertion pass rate was {assertion_pass_rate:.1%}; it should be above 90%.'
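Since `report.averages()` can now return `None`, the snippet above guards the aggregate before asserting on it. A minimal sketch of the same pattern from the consumer side, reusing the placeholder `dataset` and `infer_time_range` names from that snippet (the exact condition under which `averages()` returns `None` is an assumption here, presumably when no case succeeds):

```python
report = dataset.evaluate_sync(infer_time_range)

# With this PR, task errors no longer abort the run; failed cases are collected instead.
assert not report.failures, f'Failed cases: {[f.name for f in report.failures]}'

averages = report.averages()
assert averages is not None, 'No successful cases to aggregate'
assert averages.assertions is not None, 'There should be at least one assertion'
assert averages.assertions > 0.9, f'Assertion pass rate was {averages.assertions:.1%}'
```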
217 changes: 132 additions & 85 deletions pydantic_evals/pydantic_evals/dataset.py
@@ -13,14 +13,15 @@
import inspect
import sys
import time
import traceback
import warnings
from collections.abc import Awaitable, Mapping, Sequence
from contextlib import AsyncExitStack, nullcontext
from contextvars import ContextVar
from dataclasses import dataclass, field
from inspect import iscoroutinefunction
from pathlib import Path
from typing import Any, Callable, Generic, Literal, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast

import anyio
import logfire_api
@@ -40,10 +41,14 @@
from .evaluators._run_evaluator import run_evaluator
from .evaluators.common import DEFAULT_EVALUATORS
from .evaluators.context import EvaluatorContext
from .evaluators.evaluator import EvaluatorFailure
from .evaluators.spec import EvaluatorSpec
from .otel import SpanTree
from .otel._context_subtree import context_subtree
from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure

if TYPE_CHECKING:
from tenacity import AsyncRetrying

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup # pragma: lax no cover
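The `AsyncRetrying` import above is guarded by `TYPE_CHECKING`, so `dataset.py` never imports tenacity at runtime; the caller constructs and passes the retry strategy. A short sketch of that idiom, hedged in that the `from __future__ import annotations` line is the standard way to keep such annotations lazy and is assumed here rather than shown in this diff:

```python
from __future__ import annotations  # keep annotations unevaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; tenacity is not required to import this module.
    from tenacity import AsyncRetrying


async def evaluate(retry: AsyncRetrying | None = None) -> None:
    """`retry` is whatever AsyncRetrying instance the caller built, or None."""
    ...
```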
@@ -84,6 +89,7 @@


_REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
_REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
_REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)


@@ -171,11 +177,6 @@ def __init__(
self.evaluators = list(evaluators)


# TODO: Consider making one or more of the following changes to this type:
# * Add `task: Callable[[InputsT], Awaitable[OutputT]` as a field
# * Add `inputs_type`, `output_type`, etc. as kwargs on `__init__`
# * Rename to `Evaluation`
# TODO: Allow `task` to be sync _or_ async
class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True):
"""A dataset of test [cases][pydantic_evals.Case].

@@ -263,6 +264,7 @@ async def evaluate(
name: str | None = None,
max_concurrency: int | None = None,
progress: bool = True,
retry: AsyncRetrying | None = None,
) -> EvaluationReport[InputsT, OutputT, MetadataT]:
"""Evaluates the test cases in the dataset using the given task.

@@ -277,6 +279,7 @@
max_concurrency: The maximum number of concurrent evaluations of the task to allow.
If None, all cases will be evaluated concurrently.
progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
retry: Optional retry configuration for the task execution.

Returns:
A report containing the results of the evaluation.
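A hedged usage sketch for the new `retry` parameter, assuming tenacity is installed; `dataset` and `infer_time_range` are placeholder names, and the `stop`/`wait` settings are arbitrary:

```python
from tenacity import AsyncRetrying, stop_after_attempt, wait_exponential

report = await dataset.evaluate(
    infer_time_range,
    retry=AsyncRetrying(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=0.5)),
)

# Cases whose task still raised after the retries were exhausted are reported
# as failures rather than aborting the evaluation.
for failure in report.failures:
    print(failure.name, failure.error_message)
```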
@@ -292,24 +295,30 @@

async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
async with limiter:
result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators, retry)
if progress_bar and task_id is not None: # pragma: no branch
progress_bar.update(task_id, advance=1)
return result

cases_and_failures = await task_group_gather(
[
lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
for i, case in enumerate(self.cases, 1)
]
)
report = EvaluationReport(
name=name,
cases=await task_group_gather(
[
lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
for i, case in enumerate(self.cases, 1)
]
),
cases=[x for x in cases_and_failures if isinstance(x, ReportCase)],
failures=[x for x in cases_and_failures if isinstance(x, ReportCaseFailure)],
)
# TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
eval_span.set_attribute('cases', _REPORT_CASES_ADAPTER.dump_python(report.cases))
# TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
eval_span.set_attribute('failures', _REPORT_CASE_FAILURES_ADAPTER.dump_python(report.failures))
# TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(report.averages()))
averages = report.averages()
if averages:
eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(averages))
return report

def evaluate_sync(
@@ -817,38 +826,55 @@ def record_attribute(self, name: str, value: Any) -> None:


async def _run_task(
task: Callable[[InputsT], Awaitable[OutputT] | OutputT], case: Case[InputsT, OutputT, MetadataT]
task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
case: Case[InputsT, OutputT, MetadataT],
retry: AsyncRetrying | None = None,
) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
"""Run a task on a case and return the context for evaluators.

Args:
task: The task to run.
case: The case to run the task on.
retry: The retry strategy to use.

Returns:
An EvaluatorContext containing the inputs, actual output, expected output, and metadata.

Raises:
Exception: Any exception raised by the task.
"""
task_run = _TaskRun()
if _CURRENT_TASK_RUN.get() is not None: # pragma: no cover
raise RuntimeError('A task run has already been entered. Task runs should not be nested')

# Note: the current behavior is for task execution errors to just bubble up all the way and kill the evaluation.
# Should we handle them for the user in some way? If so, I guess we'd want to do that here.
token = _CURRENT_TASK_RUN.set(task_run)
try:
with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
with context_subtree() as span_tree:
async def _run_once():
task_run_ = _TaskRun()
if _CURRENT_TASK_RUN.get() is not None: # pragma: no cover
raise RuntimeError('A task run has already been entered. Task runs should not be nested')

token = _CURRENT_TASK_RUN.set(task_run_)
try:
with (
_logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
context_subtree() as span_tree_,
):
t0 = time.perf_counter()
if iscoroutinefunction(task):
task_output = cast(OutputT, await task(case.inputs))
task_output_ = cast(OutputT, await task(case.inputs))
else:
task_output = cast(OutputT, await to_thread.run_sync(task, case.inputs))
task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
fallback_duration = time.perf_counter() - t0
finally:
_CURRENT_TASK_RUN.reset(token)
duration_ = _get_span_duration(task_span, fallback_duration)
return task_run_, task_output_, duration_, span_tree_
finally:
_CURRENT_TASK_RUN.reset(token)

async def _run_with_retries():
if retry:
async for attempt in retry:
with attempt:
return await _run_once()
# Note: the following line will be unreachable if retry is not None
return await _run_once()

task_run, task_output, duration, span_tree = await _run_with_retries()

if isinstance(span_tree, SpanTree): # pragma: no branch
# TODO: Question: Should we make this metric-attributes functionality more user-configurable in some way before merging?
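One detail of `_run_with_retries` above worth spelling out, hedged because it is tenacity behavior rather than something in this diff: when the strategy is exhausted, tenacity raises `RetryError` (unless it was built with `reraise=True`), and it is that exception which `_run_task_and_evaluators` below converts into a `ReportCaseFailure`. A self-contained illustration:

```python
import asyncio

from tenacity import AsyncRetrying, RetryError, stop_after_attempt


async def always_fails() -> None:
    raise ValueError('boom')


async def main() -> None:
    try:
        async for attempt in AsyncRetrying(stop=stop_after_attempt(2)):
            with attempt:
                await always_fails()
    except RetryError as exc:
        # The original ValueError is wrapped; reraise=True would surface it directly.
        print(type(exc).__name__, '->', type(exc.last_attempt.exception()).__name__)


asyncio.run(main())
```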
@@ -863,6 +889,7 @@ async def _run_task(
if not isinstance(v, (int, float)):
continue
# TODO: Revisit this choice to strip the prefix..
# TODO: Use the span-tracking-of-metrics functionality to simplify this implementation
if k.startswith('gen_ai.usage.details.'):
task_run.increment_metric(k.removeprefix('gen_ai.usage.details.'), v)
elif k.startswith('gen_ai.usage.'):
@@ -874,7 +901,7 @@
metadata=case.metadata,
expected_output=case.expected_output,
output=task_output,
duration=_get_span_duration(task_span, fallback_duration),
duration=duration,
_span_tree=span_tree,
attributes=task_run.attributes,
metrics=task_run.metrics,
@@ -886,72 +913,92 @@ async def _run_task_and_evaluators(
case: Case[InputsT, OutputT, MetadataT],
report_case_name: str,
dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
) -> ReportCase[InputsT, OutputT, MetadataT]:
retry: AsyncRetrying | None,
) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
"""Run a task on a case and evaluate the results.

Args:
task: The task to run.
case: The case to run the task on.
report_case_name: The name to use for this case in the report.
dataset_evaluators: Evaluators from the dataset to apply to this case.
retry: The retry strategy to use for running the task.

Returns:
A ReportCase containing the evaluation results.
"""
with _logfire.span(
'case: {case_name}',
task_name=get_unwrapped_function_name(task),
case_name=report_case_name,
inputs=case.inputs,
metadata=case.metadata,
expected_output=case.expected_output,
) as case_span:
t0 = time.time()
scoring_context = await _run_task(task, case)

case_span.set_attribute('output', scoring_context.output)
case_span.set_attribute('task_duration', scoring_context.duration)
case_span.set_attribute('metrics', scoring_context.metrics)
case_span.set_attribute('attributes', scoring_context.attributes)

evaluators = case.evaluators + dataset_evaluators
evaluator_outputs: list[EvaluationResult] = []
if evaluators:
evaluator_outputs_by_task = await task_group_gather(
[lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
)
evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]

assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))

context = case_span.context
if context is None: # pragma: no cover
trace_id = ''
span_id = ''
else:
trace_id = f'{context.trace_id:032x}'
span_id = f'{context.span_id:016x}'
fallback_duration = time.time() - t0

return ReportCase[InputsT, OutputT, MetadataT](
name=report_case_name,
inputs=case.inputs,
metadata=case.metadata,
expected_output=case.expected_output,
output=scoring_context.output,
metrics=scoring_context.metrics,
attributes=scoring_context.attributes,
scores=scores,
labels=labels,
assertions=assertions,
task_duration=scoring_context.duration,
total_duration=_get_span_duration(case_span, fallback_duration),
trace_id=trace_id,
span_id=span_id,
)
trace_id = ''
span_id = ''
try:
with _logfire.span(
'case: {case_name}',
task_name=get_unwrapped_function_name(task),
case_name=report_case_name,
inputs=case.inputs,
metadata=case.metadata,
expected_output=case.expected_output,
) as case_span:
context = case_span.context
if context is not None: # pragma: no cover
trace_id = f'{context.trace_id:032x}'
span_id = f'{context.span_id:016x}'

t0 = time.time()
scoring_context = await _run_task(task, case, retry)

case_span.set_attribute('output', scoring_context.output)
case_span.set_attribute('task_duration', scoring_context.duration)
case_span.set_attribute('metrics', scoring_context.metrics)
case_span.set_attribute('attributes', scoring_context.attributes)

evaluators = case.evaluators + dataset_evaluators
evaluator_outputs: list[EvaluationResult] = []
evaluator_failures: list[EvaluatorFailure] = []
if evaluators:
evaluator_outputs_by_task = await task_group_gather(
[lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
)
for outputs in evaluator_outputs_by_task:
if isinstance(outputs, EvaluatorFailure):
evaluator_failures.append(outputs)
else:
evaluator_outputs.extend(outputs)

assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))

fallback_duration = time.time() - t0

return ReportCase[InputsT, OutputT, MetadataT](
name=report_case_name,
inputs=case.inputs,
metadata=case.metadata,
expected_output=case.expected_output,
output=scoring_context.output,
metrics=scoring_context.metrics,
attributes=scoring_context.attributes,
scores=scores,
labels=labels,
assertions=assertions,
task_duration=scoring_context.duration,
total_duration=_get_span_duration(case_span, fallback_duration),
trace_id=trace_id,
span_id=span_id,
evaluator_failures=evaluator_failures,
)
except Exception as exc:
return ReportCaseFailure[InputsT, OutputT, MetadataT](
name=report_case_name,
inputs=case.inputs,
metadata=case.metadata,
expected_output=case.expected_output,
error_message=f'{type(exc).__name__}: {exc}',
error_stacktrace=traceback.format_exc(),
trace_id=trace_id,
span_id=span_id,
)


_evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])
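To make the two failure paths added here concrete, a hedged sketch of consuming them: task-level failures land in `report.failures` as `ReportCaseFailure` entries, while evaluator exceptions are attached to the successful case as `evaluator_failures`. Field names follow the constructors in this diff; `EvaluatorFailure`'s own fields are not shown above, so the sketch just prints the object:

```python
from pydantic_evals.reporting import EvaluationReport


def summarize_failures(report: EvaluationReport) -> None:
    # Cases where the task itself raised (after any retries were exhausted).
    for failure in report.failures:
        print(f'task failed for {failure.name}: {failure.error_message}')
        print(failure.error_stacktrace)

    # Cases that produced output but where one or more evaluators raised.
    for case in report.cases:
        for evaluator_failure in case.evaluator_failures:
            print(f'evaluator failed for {case.name}: {evaluator_failure}')
```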
4 changes: 3 additions & 1 deletion pydantic_evals/pydantic_evals/evaluators/__init__.py
@@ -10,7 +10,7 @@
Python,
)
from .context import EvaluatorContext
from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorOutput, EvaluatorSpec
from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorFailure, EvaluatorOutput, EvaluatorSpec

__all__ = (
# common
@@ -27,6 +27,8 @@
'EvaluatorContext',
# evaluator
'Evaluator',
'EvaluationReason',
'EvaluatorFailure',
'EvaluatorOutput',
'EvaluatorSpec',
'EvaluationReason',
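With `EvaluatorFailure` now exported, downstream code can import it from the public `evaluators` package; a minimal sketch mirroring the per-case `isinstance` split performed in `dataset.py`:

```python
from __future__ import annotations

from pydantic_evals.evaluators import EvaluationResult, EvaluatorFailure


def split_outputs(
    outputs: list[EvaluationResult | EvaluatorFailure],
) -> tuple[list[EvaluationResult], list[EvaluatorFailure]]:
    """Separate successful evaluator results from evaluator failures."""
    failures = [o for o in outputs if isinstance(o, EvaluatorFailure)]
    results = [o for o in outputs if not isinstance(o, EvaluatorFailure)]
    return results, failures
```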