Commit 8940dc3

Zhitao Yu authored and facebook-github-bot committed
Fix SanitizeBoundingBoxes Handling of Semantic Masks (#9256)
Summary:

Background

Currently, torchvision.transforms.v2.SanitizeBoundingBoxes fails when used inside a v2.Compose that receives both bounding boxes and a semantic segmentation mask as inputs. The transform attempts to apply a per-box boolean validity mask to all tv_tensors.Mask objects, including semantic masks (shape [H, W]), resulting in a shape mismatch and a crash.

Error Example:

    IndexError: The shape of the mask [3] at index 0 does not match the shape of the indexed tensor [1080, 1920] at index 0

Expected Behavior

The transform should only sanitize masks that have a 1:1 mapping with bounding boxes (e.g., per-instance masks). Semantic masks (2D, shape [H, W]) should be passed through unchanged.

Task Objectives

Update SanitizeBoundingBoxes Logic:
- Detect whether a tv_tensors.Mask is a per-instance mask (shape [N, H, W] or [N, ...] where N == num_boxes) or a semantic mask (shape [H, W]).
- Only apply the per-box validity mask to per-instance masks.
- Pass through semantic masks unchanged.
- If a mask does not match the number of boxes, do not raise an error; instead, pass it through.
- Optionally, log a warning if a mask is skipped for sanitization due to shape mismatch.

Clarify Documentation:
- Update the docstring for SanitizeBoundingBoxes to explicitly state:
  - Only per-instance masks are sanitized.
  - Semantic masks are passed through unchanged.
  - The transform does not require users to pass masks to labels_getter for them to be sanitized.

Add/Update Unit Tests:
- Test with both per-instance masks and semantic masks in a v2.Compose.
- Ensure semantic masks are not sanitized and do not cause errors.
- Ensure per-instance masks are sanitized correctly.
- This can be added in TestSanitizeBoundingBoxes.

Backward Compatibility:
- Ensure that the change does not break existing datasets or user code that relies on current behavior.

Finally submit a PR with the changes and link the issue in the description.

Differential Revision: D85840801
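The shape rule stated in the task objectives can be sketched in plain Python. This is only an illustration of that rule, not the code in this commit: the diff itself catches IndexError rather than inspecting shapes, and the helper name `is_per_instance_mask` is hypothetical.

```python
from typing import Sequence


def is_per_instance_mask(mask_shape: Sequence[int], num_boxes: int) -> bool:
    """Hypothetical helper (not part of torchvision).

    Treats a mask as per-instance when it has a leading instance dimension
    ([N, H, W] or [N, ...]) whose size equals the number of boxes.
    Anything else, including a 2D semantic mask [H, W], is passed through.
    """
    return len(mask_shape) >= 3 and mask_shape[0] == num_boxes


num_boxes = 3
for shape in [(3, 1080, 1920), (1080, 1920), (5, 256, 128)]:
    action = "sanitize" if is_per_instance_mask(shape, num_boxes) else "pass through"
    print(shape, "->", action)
```

Under this assumed rule, only the first shape is sanitized; the semantic mask and the mask with five instances are left alone.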
1 parent ca22124 commit 8940dc3

2 files changed: 90 additions, 3 deletions

test/test_transforms_v2.py

Lines changed: 76 additions & 0 deletions
@@ -7348,6 +7348,82 @@ def test_no_label(self):
         assert isinstance(out_img, tv_tensors.Image)
         assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
 
+    def test_semantic_masks_passthrough(self):
+        # Test that semantic masks (2D) pass through unchanged
+        H, W = 256, 128
+        boxes = tv_tensors.BoundingBoxes(
+            [[0, 0, 50, 50], [60, 60, 100, 100]],
+            format=tv_tensors.BoundingBoxFormat.XYXY,
+            canvas_size=(H, W),
+        )
+
+        # Create semantic segmentation mask (H, W) - should NOT be sanitized
+        semantic_mask = tv_tensors.Mask(torch.randint(0, 10, size=(H, W)))
+
+        sample = {
+            "boxes": boxes,
+            "semantic_mask": semantic_mask,
+        }
+
+        out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)
+
+        # Check that semantic mask passed through unchanged
+        assert isinstance(out["semantic_mask"], tv_tensors.Mask)
+        assert out["semantic_mask"].shape == (H, W)
+        assert_equal(out["semantic_mask"], semantic_mask)
+
+    def test_masks_with_mismatched_shape_passthrough(self):
+        # Test that masks with shapes that don't match the number of boxes are passed through
+        H, W = 256, 128
+        boxes = tv_tensors.BoundingBoxes(
+            [[0, 0, 10, 10], [20, 20, 30, 30], [50, 50, 60, 60]],
+            format=tv_tensors.BoundingBoxFormat.XYXY,
+            canvas_size=(H, W),
+        )
+
+        # Create masks with different number of instances than boxes
+        mismatched_masks = tv_tensors.Mask(torch.randint(0, 2, size=(5, H, W)))  # 5 masks but 3 boxes
+
+        sample = {
+            "boxes": boxes,
+            "masks": mismatched_masks,
+        }
+
+        # Should not raise an error, masks should pass through unchanged
+        out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)
+
+        assert isinstance(out["masks"], tv_tensors.Mask)
+        assert out["masks"].shape == (5, H, W)
+        assert_equal(out["masks"], mismatched_masks)
+
+    def test_per_instance_masks_sanitized(self):
+        # Test that per-instance masks (N, H, W) are correctly sanitized
+        H, W = 256, 128
+        boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=10, min_area=10)
+        valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]
+        num_boxes = boxes.shape[0]
+
+        # Create per-instance masks matching the number of boxes
+        per_instance_masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W)))
+        labels = torch.arange(num_boxes)
+
+        sample = {
+            "boxes": boxes,
+            "masks": per_instance_masks,
+            "labels": labels,
+        }
+
+        out = transforms.SanitizeBoundingBoxes(min_size=10, min_area=10)(sample)
+
+        # Check that masks were sanitized correctly
+        assert isinstance(out["masks"], tv_tensors.Mask)
+        assert out["masks"].shape[0] == len(valid_indices)
+        assert out["masks"].shape[0] == out["boxes"].shape[0] == out["labels"].shape[0]
+
+        # Verify correct masks were kept
+        for i, valid_idx in enumerate(valid_indices):
+            assert_equal(out["masks"][i], per_instance_masks[valid_idx])
+
     def test_errors_transform(self):
         good_bbox = tv_tensors.BoundingBoxes(
             [[0, 0, 10, 10]],

torchvision/transforms/v2/_misc.py

Lines changed: 14 additions & 3 deletions
@@ -369,6 +369,12 @@ class SanitizeBoundingBoxes(Transform):
     It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
     (see ``labels_getter`` parameter).
 
+    .. note::
+        **Mask handling**: This transform automatically detects and sanitizes per-instance masks
+        (shape ``[N, H, W]`` where N matches the number of bounding boxes). Semantic segmentation masks
+        (shape ``[H, W]``) or masks with mismatched dimensions are passed through unchanged.
+        You do not need to add masks to ``labels_getter`` for them to be sanitized.
+
     It is recommended to call it at the end of a pipeline, before passing the
     input to the models. It is critical to call this transform if
     :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
@@ -456,12 +462,17 @@ def forward(self, *inputs: Any) -> Any:
 
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
-        is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
+        is_bounding_boxes = isinstance(inpt, tv_tensors.BoundingBoxes)
+        is_mask = isinstance(inpt, tv_tensors.Mask)
 
-        if not (is_label or is_bounding_boxes_or_mask):
+        if not (is_label or is_bounding_boxes or is_mask):
             return inpt
 
-        output = inpt[params["valid"]]
+        try:
+            output = inpt[params["valid"]]
+        except (IndexError):
+            # If indexing fails (e.g., shape mismatch), pass through unchanged
+            return inpt
 
         if is_label:
             return output
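
The try/except in the hunk above relies on how PyTorch boolean indexing behaves when the mask length does not match the first dimension of the indexed tensor. A minimal sketch on plain tensors, assuming only `torch` is installed, may help when reviewing it. The function `sanitize_or_passthrough` is a hypothetical stand-in for the patched indexing step, not torchvision code, and the tensor sizes are arbitrary.

```python
import torch


def sanitize_or_passthrough(inpt: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the patched step: index, or return input on IndexError."""
    try:
        return inpt[valid]
    except IndexError:
        # Boolean mask length differs from inpt.shape[0]: pass through unchanged
        return inpt


valid = torch.tensor([True, False, True])  # 3 boxes, 2 of them valid

semantic = torch.zeros(1080, 1920, dtype=torch.uint8)   # [H, W]
per_instance = torch.zeros(3, 64, 64, dtype=torch.uint8)  # [N, H, W], N == 3
mismatched = torch.zeros(5, 64, 64, dtype=torch.uint8)    # 5 masks, 3 boxes
coincidental = torch.zeros(3, 64, dtype=torch.uint8)      # 2D, but H == number of boxes

print(sanitize_or_passthrough(semantic, valid).shape)      # torch.Size([1080, 1920])
print(sanitize_or_passthrough(per_instance, valid).shape)  # torch.Size([2, 64, 64])
print(sanitize_or_passthrough(mismatched, valid).shape)    # torch.Size([5, 64, 64])
print(sanitize_or_passthrough(coincidental, valid).shape)  # torch.Size([2, 64])
```

The last case suggests a point worth checking against the real transform: in this plain-tensor sketch, a 2D mask whose first dimension happens to equal the number of boxes is indexed rather than passed through, because no IndexError is raised.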
