Skip to content

Commit 969e002

Browse files
committed
feat: implement memory-efficient IoU and NMS with CompactMask integration
- Add `compact_mask_iou_batch` for optimised IoU computation on RLE crops (avoiding full (N, H, W) arrays). - Enhance `mask_iou_batch` and NMS routines to support CompactMask inputs. - Introduce `compact_masks` parameter in `InferenceSlicer` for end-to-end CompactMask handling. - Update docstrings across affected components to reflect CompactMask integration.
1 parent ad6ceb7 commit 969e002

File tree

3 files changed

+186
-6
lines changed

3 files changed

+186
-6
lines changed

src/supervision/detection/compact_mask.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,24 @@ class CompactMask:
131131
* ``np.asarray(mask)`` → dense ``(N, H, W)`` bool array (numpy interop).
132132
* ``mask.shape``, ``mask.dtype``, ``mask.area`` — match the dense API.
133133
134+
:class:`CompactMask` is **not** a drop-in ``np.ndarray`` replacement.
135+
When you need to call arbitrary ndarray methods (``astype``, ``reshape``,
136+
``ravel``, ``any``, ``all``, …) call :meth:`to_dense` first:
137+
``cm.to_dense().astype(np.uint8)``. :meth:`to_dense` is the single
138+
explicit materialisation boundary.
139+
140+
.. note:: **RLE encoding incompatibility with pycocotools / COCO API**
141+
142+
:class:`CompactMask` uses **row-major (C-order)** run-lengths scoped
143+
to each mask's bounding-box crop. The COCO API (pycocotools) uses
144+
**column-major (Fortran-order)** run-lengths scoped to the **full
145+
image**. The two formats are not interchangeable: you cannot pass a
146+
:class:`CompactMask` RLE directly to ``maskUtils.iou()`` or
147+
``maskUtils.decode()``, and you cannot load a COCO RLE dict into a
148+
:class:`CompactMask` without re-encoding. Use
149+
:meth:`to_dense` to obtain a standard boolean array, then pass it to
150+
pycocotools if needed.
151+
134152
Args:
135153
rles: List of N int32 run-length arrays.
136154
crop_shapes: Array of shape ``(N, 2)`` — ``(crop_h, crop_w)`` per mask.

src/supervision/detection/tools/inference_slicer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ class InferenceSlicer:
8282
overlap_metric (OverlapMetric or str): Metric to compute overlap
8383
(`IOU` or `IOS`).
8484
thread_workers (int): Number of threads for concurrent slice inference.
85+
compact_masks (bool): If ``True``, dense ``(N, H, W)`` boolean mask
86+
arrays returned by the callback are immediately converted to a
87+
:class:`~supervision.detection.compact_mask.CompactMask`. This
88+
keeps masks in run-length-encoded form for the entire pipeline —
89+
merge, NMS, and annotation — avoiding the large ``(N, H, W)``
90+
allocations that cause OOM on high-resolution images with many
91+
objects. IoU and NMS are computed directly on the RLE crops
92+
without ever materialising a full ``(N, H, W)`` array.
93+
Defaults to ``False`` for backward compatibility.
8594
8695
Raises:
8796
ValueError: If `slice_wh` or `overlap_wh` are invalid or inconsistent.
@@ -130,6 +139,7 @@ def __init__(
130139
iou_threshold: float = 0.5,
131140
overlap_metric: OverlapMetric | str = OverlapMetric.IOU,
132141
thread_workers: int = 1,
142+
compact_masks: bool = False,
133143
):
134144
slice_wh_norm = self._normalize_slice_wh(slice_wh)
135145
overlap_wh_norm = self._normalize_overlap_wh(overlap_wh)
@@ -143,6 +153,7 @@ def __init__(
143153
self.overlap_filter = OverlapFilter.from_value(overlap_filter)
144154
self.callback = callback
145155
self.thread_workers = thread_workers
156+
self.compact_masks = compact_masks
146157

147158
def __call__(self, image: ImageType) -> Detections:
148159
"""
@@ -204,8 +215,22 @@ def _run_callback(self, image: ImageType, offset: np.ndarray) -> Detections:
204215
"""
205216
image_slice: ImageType = crop_image(image=image, xyxy=offset)
206217
detections = self.callback(image_slice)
207-
resolution_wh = get_image_resolution_wh(image)
208218

219+
if (
220+
self.compact_masks
221+
and detections.mask is not None
222+
and isinstance(detections.mask, np.ndarray)
223+
):
224+
from supervision.detection.compact_mask import CompactMask
225+
226+
slice_w, slice_h = get_image_resolution_wh(image_slice)
227+
detections.mask = CompactMask.from_dense(
228+
detections.mask,
229+
detections.xyxy,
230+
image_shape=(slice_h, slice_w),
231+
)
232+
233+
resolution_wh = get_image_resolution_wh(image)
209234
detections = move_detections(
210235
detections=detections,
211236
offset=offset[:2],

src/supervision/detection/utils/iou_and_nms.py

Lines changed: 142 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,107 @@ def oriented_box_iou_batch(
398398
return ious
399399

400400

401+
def compact_mask_iou_batch(
    masks_true: Any,
    masks_detection: Any,
    overlap_metric: OverlapMetric = OverlapMetric.IOU,
) -> npt.NDArray[np.floating]:
    """Compute pairwise overlap between two :class:`CompactMask` collections.

    Avoids materialising full ``(N, H, W)`` arrays by:

    1. Vectorised bounding-box pre-filter — pairs whose boxes do not overlap
       get IoU = 0 without any mask decoding.
    2. Sub-crop decoding — for overlapping pairs, only the intersection region
       of each crop is decoded and compared.
    3. Crop caching — each individual crop is decoded at most once even when it
       participates in many pairs.

    The result is numerically identical to running the dense
    :func:`mask_iou_batch` on ``np.asarray(masks_true)`` /
    ``np.asarray(masks_detection)``.

    Args:
        masks_true: :class:`~supervision.detection.compact_mask.CompactMask`
            holding the ground-truth masks.
        masks_detection: :class:`~supervision.detection.compact_mask.CompactMask`
            holding the detection masks.
        overlap_metric: :class:`OverlapMetric` — ``IOU`` or ``IOS``.

    Returns:
        Float array of shape ``(N1, N2)`` with pairwise overlap values.

    Raises:
        ValueError: If ``overlap_metric`` is neither ``IOU`` nor ``IOS``.
    """
    # Validate the metric up front. Previously this check only ran inside the
    # pair loop, so an unsupported metric was silently ignored (all-zero
    # result) whenever no bounding boxes overlapped or either input was empty.
    if overlap_metric not in (OverlapMetric.IOU, OverlapMetric.IOS):
        raise ValueError(
            f"overlap_metric {overlap_metric} is not supported, "
            "only 'IOU' and 'IOS' are supported"
        )
    # Hoist the enum comparison out of the hot per-pair loop.
    use_iou: bool = overlap_metric == OverlapMetric.IOU

    n1: int = len(masks_true)
    n2: int = len(masks_detection)
    result: npt.NDArray[np.floating] = np.zeros((n1, n2), dtype=float)

    if n1 == 0 or n2 == 0:
        return result

    areas_a: npt.NDArray[np.int64] = masks_true.area
    areas_b: npt.NDArray[np.int64] = masks_detection.area

    # Inclusive per-mask bounding boxes from stored offsets + crop shapes.
    # offsets: (N, 2) → (x1, y1); crop_shapes: (N, 2) → (h, w)
    x1a: npt.NDArray[np.int32] = masks_true._offsets[:, 0]
    y1a: npt.NDArray[np.int32] = masks_true._offsets[:, 1]
    x2a: npt.NDArray[np.int32] = x1a + masks_true._crop_shapes[:, 1] - 1
    y2a: npt.NDArray[np.int32] = y1a + masks_true._crop_shapes[:, 0] - 1

    x1b: npt.NDArray[np.int32] = masks_detection._offsets[:, 0]
    y1b: npt.NDArray[np.int32] = masks_detection._offsets[:, 1]
    x2b: npt.NDArray[np.int32] = x1b + masks_detection._crop_shapes[:, 1] - 1
    y2b: npt.NDArray[np.int32] = y1b + masks_detection._crop_shapes[:, 0] - 1

    # Pairwise intersection bounding box — shape (N1, N2).
    ix1: npt.NDArray[np.int32] = np.maximum(x1a[:, None], x1b[None, :])
    iy1: npt.NDArray[np.int32] = np.maximum(y1a[:, None], y1b[None, :])
    ix2: npt.NDArray[np.int32] = np.minimum(x2a[:, None], x2b[None, :])
    iy2: npt.NDArray[np.int32] = np.minimum(y2a[:, None], y2b[None, :])
    bbox_overlap: npt.NDArray[np.bool_] = (ix1 <= ix2) & (iy1 <= iy2)

    # Decode each crop at most once, even if it participates in many pairs.
    crops_a: dict[int, npt.NDArray[np.bool_]] = {}
    crops_b: dict[int, npt.NDArray[np.bool_]] = {}

    for idx_pair in np.argwhere(bbox_overlap):
        i, j = int(idx_pair[0]), int(idx_pair[1])

        if i not in crops_a:
            crops_a[i] = masks_true.crop(i)
        if j not in crops_b:
            crops_b[j] = masks_detection.crop(j)

        lx1 = int(ix1[i, j])
        ly1 = int(iy1[i, j])
        lx2 = int(ix2[i, j])
        ly2 = int(iy2[i, j])

        # Slice the intersection window out of each cached crop, converting
        # image coordinates to crop-local by subtracting the crop's offset.
        ox_a, oy_a = int(x1a[i]), int(y1a[i])
        sub_a = crops_a[i][ly1 - oy_a : ly2 - oy_a + 1, lx1 - ox_a : lx2 - ox_a + 1]

        ox_b, oy_b = int(x1b[j]), int(y1b[j])
        sub_b = crops_b[j][ly1 - oy_b : ly2 - oy_b + 1, lx1 - ox_b : lx2 - ox_b + 1]

        inter = int(np.logical_and(sub_a, sub_b).sum())
        area_a_i = int(areas_a[i])
        area_b_j = int(areas_b[j])

        if use_iou:
            union = area_a_i + area_b_j - inter
            result[i, j] = inter / union if union > 0 else 0.0
        else:
            # IOS: intersection over the smaller of the two mask areas.
            smaller = min(area_a_i, area_b_j)
            result[i, j] = inter / smaller if smaller > 0 else 0.0

    return result
500+
501+
401502
def _mask_iou_batch_split(
402503
masks_true: npt.NDArray[Any],
403504
masks_detection: npt.NDArray[Any],
@@ -461,16 +562,36 @@ def mask_iou_batch(
461562
Compute Intersection over Union (IoU) of two sets of masks -
462563
`masks_true` and `masks_detection`.
463564
565+
Accepts both dense ``(N, H, W)`` boolean arrays and
566+
:class:`~supervision.detection.compact_mask.CompactMask` objects.
567+
When both inputs are :class:`~supervision.detection.compact_mask.CompactMask`,
568+
the computation uses :func:`compact_mask_iou_batch` to avoid materialising
569+
full ``(N, H, W)`` arrays.
570+
464571
Args:
465-
masks_true (np.ndarray): 3D `np.ndarray` representing ground-truth masks.
466-
masks_detection (np.ndarray): 3D `np.ndarray` representing detection masks.
572+
masks_true (np.ndarray): 3D `np.ndarray` representing ground-truth masks,
573+
or a :class:`~supervision.detection.compact_mask.CompactMask`.
574+
masks_detection (np.ndarray): 3D `np.ndarray` representing detection masks,
575+
or a :class:`~supervision.detection.compact_mask.CompactMask`.
467576
overlap_metric (OverlapMetric): Metric used to compute the degree of overlap
468577
between pairs of masks (e.g., IoU, IoS).
469578
memory_limit (int): memory limit in MB, default is 1024 * 5 MB (5GB).
579+
Ignored when both inputs are CompactMask.
470580
471581
Returns:
472582
np.ndarray: Pairwise IoU of masks from `masks_true` and `masks_detection`.
473583
"""
584+
from supervision.detection.compact_mask import CompactMask
585+
586+
if isinstance(masks_true, CompactMask) and isinstance(masks_detection, CompactMask):
587+
return compact_mask_iou_batch(masks_true, masks_detection, overlap_metric)
588+
589+
# Materialise any CompactMask that was passed alongside a dense array.
590+
if isinstance(masks_true, CompactMask):
591+
masks_true = np.asarray(masks_true)
592+
if isinstance(masks_detection, CompactMask):
593+
masks_detection = np.asarray(masks_detection)
594+
474595
memory = (
475596
masks_true.shape[0]
476597
* masks_true.shape[1]
@@ -546,11 +667,18 @@ def mask_non_max_suppression(
546667
if columns == 5:
547668
predictions = np.c_[predictions, np.zeros(rows)]
548669

670+
from supervision.detection.compact_mask import CompactMask
671+
549672
sort_index = predictions[:, 4].argsort()[::-1]
550673
predictions = predictions[sort_index]
551674
masks = masks[sort_index]
552-
masks_resized = resize_masks(masks, mask_dimension)
553-
ious = mask_iou_batch(masks_resized, masks_resized, overlap_metric)
675+
676+
if isinstance(masks, CompactMask):
677+
# CompactMask IoU is computed directly on RLE crops — no resize needed.
678+
ious = compact_mask_iou_batch(masks, masks, overlap_metric)
679+
else:
680+
masks_resized = resize_masks(masks, mask_dimension)
681+
ious = mask_iou_batch(masks_resized, masks_resized, overlap_metric)
554682
categories = predictions[:, 5]
555683

556684
keep = np.ones(rows, dtype=bool)
@@ -710,7 +838,16 @@ def mask_non_max_merge(
710838
AssertionError: If `iou_threshold` is not within the closed
711839
range from `0` to `1`.
712840
"""
713-
masks_resized = resize_masks(masks, mask_dimension)
841+
from supervision.detection.compact_mask import CompactMask
842+
843+
if isinstance(masks, CompactMask):
844+
# _group_overlapping_masks needs dense arrays for logical_or union merging;
845+
# materialise to a downscaled dense array to keep memory reasonable.
846+
masks = resize_masks(np.asarray(masks), mask_dimension)
847+
else:
848+
masks = resize_masks(masks, mask_dimension)
849+
masks_resized = masks
850+
714851
if predictions.shape[1] == 5:
715852
return _group_overlapping_masks(
716853
predictions, masks_resized, iou_threshold, overlap_metric

0 commit comments

Comments
 (0)