
Commit 51fec9f

Update evals attributes (#2924)
1 parent efaf6be commit 51fec9f

2 files changed: +80 additions, -19 deletions

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 40 additions & 11 deletions
@@ -98,6 +98,7 @@ class _DatasetModel(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forb

     # $schema is included to avoid validation fails from the `$schema` key, see `_add_json_schema` below for context
     json_schema_path: str | None = Field(default=None, alias='$schema')
+    name: str | None = None
     cases: list[_CaseModel[InputsT, OutputT, MetadataT]]
     evaluators: list[EvaluatorSpec] = Field(default_factory=list)

@@ -218,6 +219,8 @@ async def main():
     ```
     """

+    name: str | None = None
+    """Optional name of the dataset."""
     cases: list[Case[InputsT, OutputT, MetadataT]]
     """List of test cases in the dataset."""
     evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = []

@@ -226,12 +229,14 @@ async def main():
     def __init__(
         self,
         *,
+        name: str | None = None,
         cases: Sequence[Case[InputsT, OutputT, MetadataT]],
         evaluators: Sequence[Evaluator[InputsT, OutputT, MetadataT]] = (),
     ):
         """Initialize a new dataset with test cases and optional evaluators.

         Args:
+            name: Optional name for the dataset.
             cases: Sequence of test cases to include in the dataset.
             evaluators: Optional sequence of evaluators to apply to all cases in the dataset.
         """

@@ -244,10 +249,12 @@ def __init__(
             case_names.add(case.name)

         super().__init__(
+            name=name,
             cases=cases,
             evaluators=list(evaluators),
         )

+    # TODO in v2: Make everything not required keyword-only
     async def evaluate(
         self,
         task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],

@@ -256,6 +263,8 @@ async def evaluate(
         progress: bool = True,
         retry_task: RetryConfig | None = None,
         retry_evaluators: RetryConfig | None = None,
+        *,
+        task_name: str | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

@@ -265,28 +274,38 @@ async def evaluate(
         Args:
             task: The task to evaluate. This should be a callable that takes the inputs of the case
                 and returns the output.
-            name: The name of the task being evaluated, this is used to identify the task in the report.
-                If omitted, the name of the task function will be used.
+            name: The name of the experiment being run, this is used to identify the experiment in the report.
+                If omitted, the task_name will be used; if that is not specified, the name of the task function is used.
             max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                 If None, all cases will be evaluated concurrently.
             progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
             retry_task: Optional retry configuration for the task execution.
             retry_evaluators: Optional retry configuration for evaluator execution.
+            task_name: Optional override to the name of the task being executed, otherwise the name of the task
+                function will be used.

         Returns:
             A report containing the results of the evaluation.
         """
-        name = name or get_unwrapped_function_name(task)
+        task_name = task_name or get_unwrapped_function_name(task)
+        name = name or task_name
         total_cases = len(self.cases)
         progress_bar = Progress() if progress else None

         limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()

         with (
-            logfire_span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span,
+            logfire_span(
+                'evaluate {name}',
+                name=name,
+                task_name=task_name,
+                dataset_name=self.name,
+                n_cases=len(self.cases),
+                **{'gen_ai.operation.name': 'experiment'},  # pyright: ignore[reportArgumentType]
+            ) as eval_span,
             progress_bar or nullcontext(),
         ):
-            task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None
+            task_id = progress_bar.add_task(f'Evaluating {task_name}', total=total_cases) if progress_bar else None

             async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                 async with limiter:

@@ -357,7 +376,7 @@ def evaluate_sync(
         return get_event_loop().run_until_complete(
             self.evaluate(
                 task,
-                name=name,
+                task_name=name,
                 max_concurrency=max_concurrency,
                 progress=progress,
                 retry_task=retry_task,

@@ -474,7 +493,7 @@ def from_file(

         raw = Path(path).read_text()
         try:
-            return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types)
+            return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types, default_name=path.stem)
         except ValidationError as e:  # pragma: no cover
             raise ValueError(f'{path} contains data that does not match the schema for {cls.__name__}:\n{e}.') from e

@@ -484,6 +503,8 @@ def from_text(
         contents: str,
         fmt: Literal['yaml', 'json'] = 'yaml',
         custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
+        *,
+        default_name: str | None = None,
     ) -> Self:
         """Load a dataset from a string.

@@ -492,6 +513,7 @@ def from_text(
             fmt: Format of the content. Must be either 'yaml' or 'json'.
             custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
                 These are additional evaluators beyond the default ones.
+            default_name: Default name of the dataset, to be used if not specified in the serialized contents.

         Returns:
             A new Dataset instance parsed from the string.

@@ -501,24 +523,27 @@
         """
         if fmt == 'yaml':
             loaded = yaml.safe_load(contents)
-            return cls.from_dict(loaded, custom_evaluator_types)
+            return cls.from_dict(loaded, custom_evaluator_types, default_name=default_name)
         else:
             dataset_model_type = cls._serialization_type()
             dataset_model = dataset_model_type.model_validate_json(contents)
-            return cls._from_dataset_model(dataset_model, custom_evaluator_types)
+            return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name)

     @classmethod
     def from_dict(
         cls,
         data: dict[str, Any],
         custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
+        *,
+        default_name: str | None = None,
     ) -> Self:
         """Load a dataset from a dictionary.

         Args:
             data: Dictionary representation of the dataset.
             custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
                 These are additional evaluators beyond the default ones.
+            default_name: Default name of the dataset, to be used if not specified in the data.

         Returns:
             A new Dataset instance created from the dictionary.

@@ -528,19 +553,21 @@
         """
         dataset_model_type = cls._serialization_type()
         dataset_model = dataset_model_type.model_validate(data)
-        return cls._from_dataset_model(dataset_model, custom_evaluator_types)
+        return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name)

     @classmethod
     def _from_dataset_model(
         cls,
         dataset_model: _DatasetModel[InputsT, OutputT, MetadataT],
         custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
+        default_name: str | None = None,
     ) -> Self:
         """Create a Dataset from a _DatasetModel.

         Args:
             dataset_model: The _DatasetModel to convert.
             custom_evaluator_types: Custom evaluator classes to register for deserialization.
+            default_name: Default name of the dataset, to be used if the value is `None` in the provided model.

         Returns:
             A new Dataset instance created from the _DatasetModel.

@@ -577,7 +604,9 @@ def _from_dataset_model(
             cases.append(row)
         if errors:
             raise ExceptionGroup(f'{len(errors)} error(s) loading evaluators from registry', errors[:3])
-        result = cls(cases=cases)
+        result = cls(name=dataset_model.name, cases=cases)
+        if result.name is None:
+            result.name = default_name
         result.evaluators = dataset_evaluators
         return result
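Taken together, the dataset.py changes add an optional `name` to `Dataset` (serialized via `_DatasetModel`), a keyword-only `task_name` override on `evaluate`, a `default_name` fallback for the deserialization helpers, and richer span attributes (`task_name`, `dataset_name`, `gen_ai.operation.name`). The sketch below shows how the new parameters could be used; it is illustrative only, and the dataset contents, task, and names are made up rather than taken from this commit:

import asyncio

from pydantic_evals import Case, Dataset


async def answer(question: str) -> str:
    # Stand-in task under evaluation (illustrative only).
    return '4'


async def main():
    # `name` is the new optional dataset name; it is emitted as `dataset_name`
    # on the `evaluate {name}` span and round-trips through (de)serialization.
    dataset = Dataset(
        name='arithmetic',
        cases=[Case(name='case1', inputs='What is 2+2?', expected_output='4')],
    )
    # `task_name` (keyword-only) overrides the task function's name; the experiment
    # `name` falls back to `task_name`, and then to the function name, when omitted.
    report = await dataset.evaluate(answer, task_name='answer_v1')
    report.print()


asyncio.run(main())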
tests/evals/test_dataset.py

Lines changed: 40 additions & 8 deletions
@@ -7,6 +7,7 @@
 from typing import Any

 import pytest
+import yaml
 from dirty_equals import HasRepr, IsNumber
 from inline_snapshot import snapshot
 from pydantic import BaseModel, TypeAdapter

@@ -106,7 +107,7 @@ def example_cases() -> list[Case[TaskInput, TaskOutput, TaskMetadata]]:
 def example_dataset(
     example_cases: list[Case[TaskInput, TaskOutput, TaskMetadata]],
 ) -> Dataset[TaskInput, TaskOutput, TaskMetadata]:
-    return Dataset[TaskInput, TaskOutput, TaskMetadata](cases=example_cases)
+    return Dataset[TaskInput, TaskOutput, TaskMetadata](name='example', cases=example_cases)


 @pytest.fixture

@@ -820,10 +821,29 @@ async def test_serialization_to_yaml(example_dataset: Dataset[TaskInput, TaskOut
     # Test loading back
     loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path)
     assert len(loaded_dataset.cases) == 2
+    assert loaded_dataset.name == 'example'
     assert loaded_dataset.cases[0].name == 'case1'
     assert loaded_dataset.cases[0].inputs.query == 'What is 2+2?'


+async def test_deserializing_without_name(
+    example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path
+):
"""Test serializing a dataset to YAML."""
+    # Save the dataset
+    yaml_path = tmp_path / 'test_cases.yaml'
+    example_dataset.to_file(yaml_path)
+
+    # Rewrite the file _without_ a name to test deserializing a name-less file
+    obj = yaml.safe_load(yaml_path.read_text())
+    obj.pop('name', None)
+    yaml_path.write_text(yaml.dump(obj))
+
+    # Test loading results in the name coming from the filename stem
+    loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path)
+    assert loaded_dataset.name == 'test_cases'
+
+
 async def test_serialization_to_json(example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path):
     """Test serializing a dataset to JSON."""
     json_path = tmp_path / 'test_cases.json'

@@ -855,6 +875,7 @@ def test_serialization_errors(tmp_path: Path):
 async def test_from_text():
     """Test creating a dataset from text."""
     dataset_dict = {
+        'name': 'my dataset',
         'cases': [
             {
                 'name': '1',

@@ -874,6 +895,7 @@ async def test_from_text():
     }

     loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict))
+    assert loaded_dataset.name == 'my dataset'
     assert loaded_dataset.cases == snapshot(
         [
             Case(

@@ -1241,7 +1263,7 @@ async def test_dataset_evaluate_with_custom_name(example_dataset: Dataset[TaskIn
     async def task(inputs: TaskInput) -> TaskOutput:
         return TaskOutput(answer=inputs.query.upper())

-    report = await example_dataset.evaluate(task, name='custom_task')
+    report = await example_dataset.evaluate(task, task_name='custom_task')
     assert report.name == 'custom_task'


@@ -1491,16 +1513,26 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput:
         (
             'evaluate {name}',
             {
-                'name': 'mock_async_task',
-                'n_cases': 2,
                 'assertion_pass_rate': 1.0,
-                'logfire.msg_template': 'evaluate {name}',
-                'logfire.msg': 'evaluate mock_async_task',
-                'logfire.span_type': 'span',
+                'dataset_name': 'example',
+                'gen_ai.operation.name': 'experiment',
                 'logfire.json_schema': {
+                    'properties': {
+                        'assertion_pass_rate': {},
+                        'dataset_name': {},
+                        'gen_ai.operation.name': {},
+                        'n_cases': {},
+                        'name': {},
+                        'task_name': {},
+                    },
                     'type': 'object',
-                    'properties': {'name': {}, 'n_cases': {}, 'assertion_pass_rate': {}},
                 },
+                'logfire.msg': 'evaluate mock_async_task',
+                'logfire.msg_template': 'evaluate {name}',
+                'logfire.span_type': 'span',
+                'n_cases': 2,
+                'name': 'mock_async_task',
+                'task_name': 'mock_async_task',
             },
         ),
         (
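For the deserialization side, the tests above rely on `from_file` passing the file's stem as `default_name`. Below is a minimal sketch of the same fallback through `from_text`; the YAML contents and the `'arithmetic'` name are made up for illustration and not part of this commit:

from typing import Any

from pydantic_evals import Dataset

# YAML contents that omit the optional top-level `name` key (illustrative data).
yaml_contents = """\
cases:
  - name: case1
    inputs: What is 2+2?
    expected_output: '4'
"""

# `from_text` now accepts a keyword-only `default_name`, used only when the
# serialized contents carry no name; `from_file` supplies the file stem the same way.
dataset = Dataset[str, str, Any].from_text(yaml_contents, default_name='arithmetic')
assert dataset.name == 'arithmetic'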