51 changes: 40 additions & 11 deletions pydantic_evals/pydantic_evals/dataset.py
@@ -98,6 +98,7 @@ class _DatasetModel(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forb

# $schema is included to avoid validation fails from the `$schema` key, see `_add_json_schema` below for context
json_schema_path: str | None = Field(default=None, alias='$schema')
name: str | None = None
cases: list[_CaseModel[InputsT, OutputT, MetadataT]]
evaluators: list[EvaluatorSpec] = Field(default_factory=list)

@@ -218,6 +219,8 @@ async def main():
```
"""

name: str | None = None
"""Optional name of the dataset."""
cases: list[Case[InputsT, OutputT, MetadataT]]
"""List of test cases in the dataset."""
evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = []
@@ -226,12 +229,14 @@ def __init__(
def __init__(
self,
*,
name: str | None = None,
cases: Sequence[Case[InputsT, OutputT, MetadataT]],
evaluators: Sequence[Evaluator[InputsT, OutputT, MetadataT]] = (),
):
"""Initialize a new dataset with test cases and optional evaluators.

Args:
name: Optional name for the dataset.
cases: Sequence of test cases to include in the dataset.
evaluators: Optional sequence of evaluators to apply to all cases in the dataset.
"""
@@ -244,10 +249,12 @@ def __init__(
case_names.add(case.name)

super().__init__(
name=name,
cases=cases,
evaluators=list(evaluators),
)

# TODO in v2: Make everything not required keyword-only
async def evaluate(
self,
task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
@@ -256,6 +263,8 @@ async def evaluate(
progress: bool = True,
retry_task: RetryConfig | None = None,
retry_evaluators: RetryConfig | None = None,
*,
task_name: str | None = None,
) -> EvaluationReport[InputsT, OutputT, MetadataT]:
"""Evaluates the test cases in the dataset using the given task.

@@ -265,28 +274,38 @@ async def evaluate(
Args:
task: The task to evaluate. This should be a callable that takes the inputs of the case
and returns the output.
name: The name of the task being evaluated, this is used to identify the task in the report.
If omitted, the name of the task function will be used.
name: The name of the experiment being run, this is used to identify the experiment in the report.
If omitted, the task_name will be used; if that is not specified, the name of the task function is used.
max_concurrency: The maximum number of concurrent evaluations of the task to allow.
If None, all cases will be evaluated concurrently.
progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
retry_task: Optional retry configuration for the task execution.
retry_evaluators: Optional retry configuration for evaluator execution.
task_name: Optional override to the name of the task being executed, otherwise the name of the task
function will be used.

Returns:
A report containing the results of the evaluation.
"""
name = name or get_unwrapped_function_name(task)
task_name = task_name or get_unwrapped_function_name(task)
name = name or task_name
total_cases = len(self.cases)
progress_bar = Progress() if progress else None

limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()

with (
logfire_span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span,
logfire_span(
'evaluate {name}',
name=name,
task_name=task_name,
dataset_name=self.name,
n_cases=len(self.cases),
**{'gen_ai.operation.name': 'experiment'}, # pyright: ignore[reportArgumentType]
) as eval_span,
progress_bar or nullcontext(),
):
task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None
task_id = progress_bar.add_task(f'Evaluating {task_name}', total=total_cases) if progress_bar else None

async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
async with limiter:
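Taken together, the resolution above is: `task_name` falls back to the unwrapped task function's name, and the experiment `name` falls back to `task_name`. Below is a minimal sketch of calling the new keyword-only parameter, assuming this branch of pydantic_evals; the `answer` task and the string inputs/outputs are illustrative placeholders, not code from this PR.

```python
import asyncio

from pydantic_evals import Case, Dataset

# New optional Dataset.name; it is also recorded as `dataset_name` on the evaluate span.
dataset = Dataset(
    name='arithmetic',
    cases=[Case(name='case1', inputs='What is 2+2?', expected_output='4')],
)


async def answer(question: str) -> str:
    # Placeholder task under evaluation.
    return '4'


async def main():
    # `task_name` overrides the function name ('answer'); the experiment name then
    # falls back to `task_name`, so the resulting report is named 'answer-v1'.
    report = await dataset.evaluate(answer, task_name='answer-v1')
    print(report)


asyncio.run(main())
```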
@@ -357,7 +376,7 @@ def evaluate_sync(
return get_event_loop().run_until_complete(
self.evaluate(
task,
name=name,
task_name=name,
max_concurrency=max_concurrency,
progress=progress,
retry_task=retry_task,
@@ -474,7 +493,7 @@ def from_file(

raw = Path(path).read_text()
try:
return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types)
return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types, default_name=path.stem)
except ValidationError as e: # pragma: no cover
raise ValueError(f'{path} contains data that does not match the schema for {cls.__name__}:\n{e}.') from e

@@ -484,6 +503,8 @@ def from_text(
contents: str,
fmt: Literal['yaml', 'json'] = 'yaml',
custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
*,
default_name: str | None = None,
) -> Self:
"""Load a dataset from a string.

@@ -492,6 +513,7 @@ def from_text(
fmt: Format of the content. Must be either 'yaml' or 'json'.
custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
These are additional evaluators beyond the default ones.
default_name: Default name of the dataset, to be used if not specified in the serialized contents.

Returns:
A new Dataset instance parsed from the string.
@@ -501,24 +523,27 @@
"""
if fmt == 'yaml':
loaded = yaml.safe_load(contents)
return cls.from_dict(loaded, custom_evaluator_types)
return cls.from_dict(loaded, custom_evaluator_types, default_name=default_name)
else:
dataset_model_type = cls._serialization_type()
dataset_model = dataset_model_type.model_validate_json(contents)
return cls._from_dataset_model(dataset_model, custom_evaluator_types)
return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name)

@classmethod
def from_dict(
cls,
data: dict[str, Any],
custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
*,
default_name: str | None = None,
) -> Self:
"""Load a dataset from a dictionary.

Args:
data: Dictionary representation of the dataset.
custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
These are additional evaluators beyond the default ones.
default_name: Default name of the dataset, to be used if not specified in the data.

Returns:
A new Dataset instance created from the dictionary.
@@ -528,19 +553,21 @@
"""
dataset_model_type = cls._serialization_type()
dataset_model = dataset_model_type.model_validate(data)
return cls._from_dataset_model(dataset_model, custom_evaluator_types)
return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name)

@classmethod
def _from_dataset_model(
cls,
dataset_model: _DatasetModel[InputsT, OutputT, MetadataT],
custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
default_name: str | None = None,
) -> Self:
"""Create a Dataset from a _DatasetModel.

Args:
dataset_model: The _DatasetModel to convert.
custom_evaluator_types: Custom evaluator classes to register for deserialization.
default_name: Default name of the dataset, to be used if the value is `None` in the provided model.

Returns:
A new Dataset instance created from the _DatasetModel.
@@ -577,7 +604,9 @@ def _from_dataset_model(
cases.append(row)
if errors:
raise ExceptionGroup(f'{len(errors)} error(s) loading evaluators from registry', errors[:3])
result = cls(cases=cases)
result = cls(name=dataset_model.name, cases=cases)
if result.name is None:
result.name = default_name
result.evaluators = dataset_evaluators
return result

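The `default_name` plumbing above means a dataset loaded from disk picks up the file stem when the serialized data carries no `name` key. A short sketch of that fallback, again assuming this branch; the file name and the single case are illustrative.

```python
from pathlib import Path

import yaml

from pydantic_evals import Dataset

# Write a serialized dataset with no top-level `name` key.
path = Path('eval_cases.yaml')
path.write_text(
    yaml.dump({'cases': [{'name': 'case1', 'inputs': 'What is 2+2?', 'expected_output': '4'}]})
)

dataset = Dataset.from_file(path)
# from_file passes default_name=path.stem to from_text, and _from_dataset_model
# applies it only because the file itself provided no name.
assert dataset.name == 'eval_cases'
```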
48 changes: 40 additions & 8 deletions tests/evals/test_dataset.py
@@ -7,6 +7,7 @@
from typing import Any

import pytest
import yaml
from dirty_equals import HasRepr, IsNumber
from inline_snapshot import snapshot
from pydantic import BaseModel, TypeAdapter
@@ -106,7 +107,7 @@ def example_cases() -> list[Case[TaskInput, TaskOutput, TaskMetadata]]:
def example_dataset(
example_cases: list[Case[TaskInput, TaskOutput, TaskMetadata]],
) -> Dataset[TaskInput, TaskOutput, TaskMetadata]:
return Dataset[TaskInput, TaskOutput, TaskMetadata](cases=example_cases)
return Dataset[TaskInput, TaskOutput, TaskMetadata](name='example', cases=example_cases)


@pytest.fixture
@@ -820,10 +821,29 @@ async def test_serialization_to_yaml(example_dataset: Dataset[TaskInput, TaskOut
# Test loading back
loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path)
assert len(loaded_dataset.cases) == 2
assert loaded_dataset.name == 'example'
assert loaded_dataset.cases[0].name == 'case1'
assert loaded_dataset.cases[0].inputs.query == 'What is 2+2?'


async def test_deserializing_without_name(
example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path
):
"""Test serializing a dataset to YAML."""
# Save the dataset
yaml_path = tmp_path / 'test_cases.yaml'
example_dataset.to_file(yaml_path)

# Rewrite the file _without_ a name to test deserializing a name-less file
obj = yaml.safe_load(yaml_path.read_text())
obj.pop('name', None)
yaml_path.write_text(yaml.dump(obj))

# Test loading results in the name coming from the filename stem
loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path)
assert loaded_dataset.name == 'test_cases'


async def test_serialization_to_json(example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path):
"""Test serializing a dataset to JSON."""
json_path = tmp_path / 'test_cases.json'
Expand Down Expand Up @@ -855,6 +875,7 @@ def test_serialization_errors(tmp_path: Path):
async def test_from_text():
"""Test creating a dataset from text."""
dataset_dict = {
'name': 'my dataset',
'cases': [
{
'name': '1',
@@ -874,6 +895,7 @@ async def test_from_text():
}

loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict))
assert loaded_dataset.name == 'my dataset'
assert loaded_dataset.cases == snapshot(
[
Case(
@@ -1241,7 +1263,7 @@ async def test_dataset_evaluate_with_custom_name(example_dataset: Dataset[TaskIn
async def task(inputs: TaskInput) -> TaskOutput:
return TaskOutput(answer=inputs.query.upper())

report = await example_dataset.evaluate(task, name='custom_task')
report = await example_dataset.evaluate(task, task_name='custom_task')
assert report.name == 'custom_task'


Expand Down Expand Up @@ -1491,16 +1513,26 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput:
(
'evaluate {name}',
{
'name': 'mock_async_task',
'n_cases': 2,
'assertion_pass_rate': 1.0,
'logfire.msg_template': 'evaluate {name}',
'logfire.msg': 'evaluate mock_async_task',
'logfire.span_type': 'span',
'dataset_name': 'example',
'gen_ai.operation.name': 'experiment',
'logfire.json_schema': {
'properties': {
'assertion_pass_rate': {},
'dataset_name': {},
'gen_ai.operation.name': {},
'n_cases': {},
'name': {},
'task_name': {},
},
'type': 'object',
'properties': {'name': {}, 'n_cases': {}, 'assertion_pass_rate': {}},
},
'logfire.msg': 'evaluate mock_async_task',
'logfire.msg_template': 'evaluate {name}',
'logfire.span_type': 'span',
'n_cases': 2,
'name': 'mock_async_task',
'task_name': 'mock_async_task',
},
),
(