Skip to content

Commit 036aacc

Browse files
committed
Add a test
1 parent e7782db commit 036aacc

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/evals/test_dataset.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88

99
import pytest
10+
import yaml
1011
from dirty_equals import HasRepr, IsNumber
1112
from inline_snapshot import snapshot
1213
from pydantic import BaseModel, TypeAdapter
@@ -820,10 +821,29 @@ async def test_serialization_to_yaml(example_dataset: Dataset[TaskInput, TaskOut
820821
# Test loading back
821822
loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path)
822823
assert len(loaded_dataset.cases) == 2
824+
assert loaded_dataset.name == 'example'
823825
assert loaded_dataset.cases[0].name == 'case1'
824826
assert loaded_dataset.cases[0].inputs.query == 'What is 2+2?'
825827

826828

829+
async def test_deserializing_without_name(
830+
example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path
831+
):
832+
"""Test that loading a dataset file without a name takes the name from the filename stem."""
833+
# Save the dataset
834+
yaml_path = tmp_path / 'test_cases.yaml'
835+
example_dataset.to_file(yaml_path)
836+
837+
# Rewrite the file _without_ a name to test deserializing a name-less file
838+
obj = yaml.safe_load(yaml_path.read_text())
839+
obj.pop('name', None)
840+
yaml_path.write_text(yaml.dump(obj))
841+
842+
# Loading the name-less file should set the dataset name to the filename stem
843+
loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path)
844+
assert loaded_dataset.name == 'test_cases'
845+
846+
827847
async def test_serialization_to_json(example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path):
828848
"""Test serializing a dataset to JSON."""
829849
json_path = tmp_path / 'test_cases.json'
@@ -855,6 +875,7 @@ def test_serialization_errors(tmp_path: Path):
855875
async def test_from_text():
856876
"""Test creating a dataset from text."""
857877
dataset_dict = {
878+
'name': 'my dataset',
858879
'cases': [
859880
{
860881
'name': '1',
@@ -874,6 +895,7 @@ async def test_from_text():
874895
}
875896

876897
loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict))
898+
assert loaded_dataset.name == 'my dataset'
877899
assert loaded_dataset.cases == snapshot(
878900
[
879901
Case(

0 commit comments

Comments
 (0)