Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,55 @@ def with_nmm(

return Detections.merge(result)

def transform(self, dataset, class_mapping: Optional[dict] = None) -> Detections:
"""
Remap and filter detections to match a target dataset's class set.

Args:
dataset: An object with a .classes attribute (list of class names).
class_mapping (dict, optional): Mapping from model class names to
dataset class names.

Returns:
Detections: A new Detections object with class names and IDs
remapped and filtered.
"""
# Get class names for each detection
class_names = self.data.get("class_name")
if class_names is None:
raise ValueError(
"Detections must have 'class_name' in .data to use transform()."
)
class_names = np.array(class_names)
# Remap class names if mapping is provided
if class_mapping is not None:
class_names = np.array(
[class_mapping.get(name, name) for name in class_names]
)
# Filter out detections whose class is not in dataset.classes
keep = np.isin(class_names, dataset.classes)
# Remap class_id to match dataset.classes
new_class_id = np.array(
[dataset.classes.index(name) for name in class_names[keep]]
)
# Build new Detections object
return Detections(
xyxy=self.xyxy[keep],
mask=self.mask[keep] if self.mask is not None else None,
confidence=self.confidence[keep] if self.confidence is not None else None,
class_id=new_class_id,
tracker_id=self.tracker_id[keep] if self.tracker_id is not None else None,
data={
k: (
np.array(v)[keep]
if isinstance(v, (list, np.ndarray)) and len(v) == len(self)
else v
)
for k, v in self.data.items()
},
metadata=self.metadata.copy(),
)


def merge_inner_detection_object_pair(
detections_1: Detections, detections_2: Detections
Expand Down
55 changes: 55 additions & 0 deletions test/detection/test_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from types import SimpleNamespace

import numpy as np
import pytest

from supervision.detection.core import Detections


def test_transform_remap_and_filter():
# Simulate a model that predicts 'dog', 'cat', 'eagle', 'car'
det = Detections(
xyxy=np.array([[0, 0, 1, 1], [1, 1, 2, 2], [2, 2, 3, 3], [3, 3, 4, 4]]),
class_id=np.array([0, 1, 2, 3]),
confidence=np.array([0.9, 0.8, 0.7, 0.6]),
data={"class_name": np.array(["dog", "cat", "eagle", "car"])},
)
# Dataset expects 'animal', 'bird', 'car' (in that order)
dataset = SimpleNamespace(classes=["animal", "bird", "car"])
class_mapping = {"dog": "animal", "cat": "animal", "eagle": "bird"}
det2 = det.transform(dataset, class_mapping=class_mapping)
# Only 'dog', 'cat', 'eagle', 'car' should remain,
# but 'dog' and 'cat' become 'animal', 'eagle' becomes 'bird'
assert set(det2.data["class_name"]) <= set([*dataset.classes, "car"])
assert all([name in dataset.classes for name in det2.data["class_name"]])
# class_id should be remapped to dataset.classes indices
for name, cid in zip(det2.data["class_name"], det2.class_id):
assert dataset.classes[cid] == name
# Only 'dog', 'cat', 'eagle', 'car' remain, but 'car' is already in dataset.classes
assert len(det2) == 4


def test_transform_no_class_mapping():
det = Detections(
xyxy=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
class_id=np.array([0, 1]),
confidence=np.array([0.9, 0.8]),
data={"class_name": np.array(["car", "truck"])},
)
dataset = SimpleNamespace(classes=["car"])
det2 = det.transform(dataset)
assert len(det2) == 1
assert det2.data["class_name"][0] == "car"
assert det2.class_id[0] == 0


def test_transform_raises_without_class_name():
det = Detections(
xyxy=np.array([[0, 0, 1, 1]]),
class_id=np.array([0]),
confidence=np.array([0.9]),
data={},
)
dataset = SimpleNamespace(classes=["car"])
with pytest.raises(ValueError):
det.transform(dataset)