diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 5fa2b7b037..b3d968632e 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1435,6 +1435,55 @@ def with_nmm( return Detections.merge(result) + def transform(self, dataset, class_mapping: Optional[dict] = None) -> Detections: + """ + Remap and filter detections to match a target dataset's class set. + + Args: + dataset: An object with a .classes attribute (list of class names). + class_mapping (dict, optional): Mapping from model class names to + dataset class names. + + Returns: + Detections: A new Detections object with class names and IDs + remapped and filtered. + """ + # Get class names for each detection + class_names = self.data.get("class_name") + if class_names is None: + raise ValueError( + "Detections must have 'class_name' in .data to use transform()." + ) + class_names = np.array(class_names) + # Remap class names if mapping is provided + if class_mapping is not None: + class_names = np.array( + [class_mapping.get(name, name) for name in class_names] + ) + # Filter out detections whose class is not in dataset.classes + keep = np.isin(class_names, dataset.classes) + # Remap class_id to match dataset.classes + new_class_id = np.array( + [dataset.classes.index(name) for name in class_names[keep]] + ) + # Build new Detections object + return Detections( + xyxy=self.xyxy[keep], + mask=self.mask[keep] if self.mask is not None else None, + confidence=self.confidence[keep] if self.confidence is not None else None, + class_id=new_class_id, + tracker_id=self.tracker_id[keep] if self.tracker_id is not None else None, + data={ + k: ( + np.array(v)[keep] + if isinstance(v, (list, np.ndarray)) and len(v) == len(self) + else v + ) + for k, v in self.data.items() + }, + metadata=self.metadata.copy(), + ) + def merge_inner_detection_object_pair( detections_1: Detections, detections_2: Detections diff --git a/test/detection/test_transform.py b/test/detection/test_transform.py new file mode 100644 index 0000000000..ed3c9c6af7 --- /dev/null +++ b/test/detection/test_transform.py @@ -0,0 +1,55 @@ +from types import SimpleNamespace + +import numpy as np +import pytest + +from supervision.detection.core import Detections + + +def test_transform_remap_and_filter(): + # Simulate a model that predicts 'dog', 'cat', 'eagle', 'car' + det = Detections( + xyxy=np.array([[0, 0, 1, 1], [1, 1, 2, 2], [2, 2, 3, 3], [3, 3, 4, 4]]), + class_id=np.array([0, 1, 2, 3]), + confidence=np.array([0.9, 0.8, 0.7, 0.6]), + data={"class_name": np.array(["dog", "cat", "eagle", "car"])}, + ) + # Dataset expects 'animal', 'bird', 'car' (in that order) + dataset = SimpleNamespace(classes=["animal", "bird", "car"]) + class_mapping = {"dog": "animal", "cat": "animal", "eagle": "bird"} + det2 = det.transform(dataset, class_mapping=class_mapping) + # Only 'dog', 'cat', 'eagle', 'car' should remain, + # but 'dog' and 'cat' become 'animal', 'eagle' becomes 'bird' + assert set(det2.data["class_name"]) <= set([*dataset.classes, "car"]) + assert all([name in dataset.classes for name in det2.data["class_name"]]) + # class_id should be remapped to dataset.classes indices + for name, cid in zip(det2.data["class_name"], det2.class_id): + assert dataset.classes[cid] == name + # Only 'dog', 'cat', 'eagle', 'car' remain, but 'car' is already in dataset.classes + assert len(det2) == 4 + + +def test_transform_no_class_mapping(): + det = Detections( + xyxy=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]), + class_id=np.array([0, 1]), + confidence=np.array([0.9, 0.8]), + data={"class_name": np.array(["car", "truck"])}, + ) + dataset = SimpleNamespace(classes=["car"]) + det2 = det.transform(dataset) + assert len(det2) == 1 + assert det2.data["class_name"][0] == "car" + assert det2.class_id[0] == 0 + + +def test_transform_raises_without_class_name(): + det = Detections( + xyxy=np.array([[0, 0, 1, 1]]), + class_id=np.array([0]), + confidence=np.array([0.9]), + data={}, + ) + dataset = SimpleNamespace(classes=["car"]) + with pytest.raises(ValueError): + det.transform(dataset)