Skip to content

Commit 78304c2

Browse files
committed
Add a test
1 parent 7e57adb commit 78304c2

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/evals/test_dataset.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88

99
import pytest
10+
import yaml
1011
from dirty_equals import HasRepr, IsNumber
1112
from inline_snapshot import snapshot
1213
from pydantic import BaseModel, TypeAdapter
@@ -819,10 +820,29 @@ async def test_serialization_to_yaml(example_dataset: Dataset[TaskInput, TaskOut
819820
# Test loading back
820821
loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path)
821822
assert len(loaded_dataset.cases) == 2
823+
assert loaded_dataset.name == 'example'
822824
assert loaded_dataset.cases[0].name == 'case1'
823825
assert loaded_dataset.cases[0].inputs.query == 'What is 2+2?'
824826

825827

828+
async def test_deserializing_without_name(
829+
example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path
830+
):
831+
"""Test serializing a dataset to YAML."""
832+
# Save the dataset
833+
yaml_path = tmp_path / 'test_cases.yaml'
834+
example_dataset.to_file(yaml_path)
835+
836+
# Rewrite the file _without_ a name to test deserializing a name-less file
837+
obj = yaml.safe_load(yaml_path.read_text())
838+
obj.pop('name', None)
839+
yaml_path.write_text(yaml.dump(obj))
840+
841+
# Test loading results in the name coming from the filename stem
842+
loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path)
843+
assert loaded_dataset.name == 'test_cases'
844+
845+
826846
async def test_serialization_to_json(example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path):
827847
"""Test serializing a dataset to JSON."""
828848
json_path = tmp_path / 'test_cases.json'
@@ -854,6 +874,7 @@ def test_serialization_errors(tmp_path: Path):
854874
async def test_from_text():
855875
"""Test creating a dataset from text."""
856876
dataset_dict = {
877+
'name': 'my dataset',
857878
'cases': [
858879
{
859880
'name': '1',
@@ -873,6 +894,7 @@ async def test_from_text():
873894
}
874895

875896
loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict))
897+
assert loaded_dataset.name == 'my dataset'
876898
assert loaded_dataset.cases == snapshot(
877899
[
878900
Case(

0 commit comments

Comments
 (0)