|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from contextlib import ExitStack as DoesNotRaise |
| 4 | +from pathlib import Path |
4 | 5 |
|
| 6 | +import numpy as np |
5 | 7 | import pytest |
6 | 8 |
|
7 | | -from supervision import DetectionDataset |
8 | | -from tests.helpers import _create_detections |
| 9 | +from supervision import DetectionDataset, Detections |
| 10 | +from supervision.config import CLASS_NAME_DATA_FIELD |
| 11 | +from tests.helpers import _create_detections, create_yolo_dataset |
9 | 12 |
|
10 | 13 |
|
11 | 14 | @pytest.mark.parametrize( |
@@ -187,3 +190,97 @@ def test_dataset_merge( |
187 | 190 | with exception: |
188 | 191 | result = DetectionDataset.merge(dataset_list=dataset_list) |
189 | 192 | assert result == expected_result |
| 193 | + |
| 194 | + |
| 195 | +class TestClassNamePopulation: |
| 196 | + """Verify that DetectionDataset populates CLASS_NAME_DATA_FIELD on init.""" |
| 197 | + |
| 198 | + def test_class_name_populated_on_init(self) -> None: |
| 199 | + """Basic case: class_name data field is set from classes and class_id.""" |
| 200 | + dataset = DetectionDataset( |
| 201 | + classes=["dog", "cat"], |
| 202 | + images=["img1.png"], |
| 203 | + annotations={ |
| 204 | + "img1.png": _create_detections( |
| 205 | + xyxy=[[0, 0, 10, 10], [20, 20, 30, 30]], |
| 206 | + class_id=[0, 1], |
| 207 | + ), |
| 208 | + }, |
| 209 | + ) |
| 210 | + annotation = dataset.annotations["img1.png"] |
| 211 | + assert CLASS_NAME_DATA_FIELD in annotation.data |
| 212 | + np.testing.assert_array_equal( |
| 213 | + annotation.data[CLASS_NAME_DATA_FIELD], |
| 214 | + np.array(["dog", "cat"]), |
| 215 | + ) |
| 216 | + |
| 217 | + def test_class_name_with_empty_annotations(self) -> None: |
| 218 | + """Empty Detections should not raise an error.""" |
| 219 | + dataset = DetectionDataset( |
| 220 | + classes=["dog"], |
| 221 | + images=["img1.png"], |
| 222 | + annotations={"img1.png": Detections.empty()}, |
| 223 | + ) |
| 224 | + annotation = dataset.annotations["img1.png"] |
| 225 | + assert CLASS_NAME_DATA_FIELD in annotation.data |
| 226 | + assert len(annotation.data[CLASS_NAME_DATA_FIELD]) == 0 |
| 227 | + |
| 228 | + def test_class_name_with_empty_classes(self) -> None: |
| 229 | + """When classes is empty, class_name should not be populated.""" |
| 230 | + dataset = DetectionDataset( |
| 231 | + classes=[], |
| 232 | + images=[], |
| 233 | + annotations={}, |
| 234 | + ) |
| 235 | + assert len(dataset.annotations) == 0 |
| 236 | + |
| 237 | + def test_class_name_after_merge(self) -> None: |
| 238 | + """After merging datasets, class_name must match remapped class_id.""" |
| 239 | + ds1 = DetectionDataset( |
| 240 | + classes=["dog", "person"], |
| 241 | + images=["img1.png"], |
| 242 | + annotations={ |
| 243 | + "img1.png": _create_detections(xyxy=[[0, 0, 10, 10]], class_id=[0]), |
| 244 | + }, |
| 245 | + ) |
| 246 | + ds2 = DetectionDataset( |
| 247 | + classes=["cat"], |
| 248 | + images=["img2.png"], |
| 249 | + annotations={ |
| 250 | + "img2.png": _create_detections(xyxy=[[0, 0, 10, 10]], class_id=[0]), |
| 251 | + }, |
| 252 | + ) |
| 253 | + merged = DetectionDataset.merge([ds1, ds2]) |
| 254 | + |
| 255 | + # merged.classes is ["cat", "dog", "person"] |
| 256 | + # ds1's dog (0) -> dog (1), ds2's cat (0) -> cat (0) |
| 257 | + ann1 = merged.annotations["img1.png"] |
| 258 | + assert CLASS_NAME_DATA_FIELD in ann1.data |
| 259 | + np.testing.assert_array_equal( |
| 260 | + ann1.data[CLASS_NAME_DATA_FIELD], np.array(["dog"]) |
| 261 | + ) |
| 262 | + |
| 263 | + ann2 = merged.annotations["img2.png"] |
| 264 | + assert CLASS_NAME_DATA_FIELD in ann2.data |
| 265 | + np.testing.assert_array_equal( |
| 266 | + ann2.data[CLASS_NAME_DATA_FIELD], np.array(["cat"]) |
| 267 | + ) |
| 268 | + |
| 269 | + def test_class_name_from_yolo(self, tmp_path: Path) -> None: |
| 270 | + """Integration test: from_yolo should produce class_name data.""" |
| 271 | + dataset_info = create_yolo_dataset( |
| 272 | + str(tmp_path), num_images=2, classes=["cat", "dog"] |
| 273 | + ) |
| 274 | + dataset = DetectionDataset.from_yolo( |
| 275 | + images_directory_path=dataset_info["images_dir"], |
| 276 | + annotations_directory_path=dataset_info["labels_dir"], |
| 277 | + data_yaml_path=dataset_info["data_yaml_path"], |
| 278 | + ) |
| 279 | + |
| 280 | + for _, annotation in dataset.annotations.items(): |
| 281 | + if annotation.class_id is not None and len(annotation.class_id) > 0: |
| 282 | + assert CLASS_NAME_DATA_FIELD in annotation.data |
| 283 | + expected_names = np.array(dataset.classes)[annotation.class_id] |
| 284 | + np.testing.assert_array_equal( |
| 285 | + annotation.data[CLASS_NAME_DATA_FIELD], expected_names |
| 286 | + ) |
0 commit comments