@@ -398,6 +398,107 @@ def oriented_box_iou_batch(
398398 return ious
399399
400400
401+ def compact_mask_iou_batch (
402+ masks_true : Any ,
403+ masks_detection : Any ,
404+ overlap_metric : OverlapMetric = OverlapMetric .IOU ,
405+ ) -> npt .NDArray [np .floating ]:
406+ """Compute pairwise overlap between two :class:`CompactMask` collections.
407+
408+ Avoids materialising full ``(N, H, W)`` arrays by:
409+
410+ 1. Vectorised bounding-box pre-filter — pairs whose boxes do not overlap
411+ get IoU = 0 without any mask decoding.
412+ 2. Sub-crop decoding — for overlapping pairs, only the intersection region
413+ of each crop is decoded and compared.
414+ 3. Crop caching — each individual crop is decoded at most once even when it
415+ participates in many pairs.
416+
417+ The result is numerically identical to running the dense
418+ :func:`mask_iou_batch` on ``np.asarray(masks_true)`` /
419+ ``np.asarray(masks_detection)``.
420+
421+ Args:
422+ masks_true: :class:`~supervision.detection.compact_mask.CompactMask`
423+ holding the ground-truth masks.
424+ masks_detection: :class:`~supervision.detection.compact_mask.CompactMask`
425+ holding the detection masks.
426+ overlap_metric: :class:`OverlapMetric` — ``IOU`` or ``IOS``.
427+
428+ Returns:
429+ Float array of shape ``(N1, N2)`` with pairwise overlap values.
430+ """
431+ n1 : int = len (masks_true )
432+ n2 : int = len (masks_detection )
433+ result : npt .NDArray [np .floating ] = np .zeros ((n1 , n2 ), dtype = float )
434+
435+ if n1 == 0 or n2 == 0 :
436+ return result
437+
438+ areas_a : npt .NDArray [np .int64 ] = masks_true .area
439+ areas_b : npt .NDArray [np .int64 ] = masks_detection .area
440+
441+ # Inclusive per-mask bounding boxes from stored offsets + crop shapes.
442+ # offsets: (N, 2) → (x1, y1); crop_shapes: (N, 2) → (h, w)
443+ x1a : npt .NDArray [np .int32 ] = masks_true ._offsets [:, 0 ]
444+ y1a : npt .NDArray [np .int32 ] = masks_true ._offsets [:, 1 ]
445+ x2a : npt .NDArray [np .int32 ] = x1a + masks_true ._crop_shapes [:, 1 ] - 1
446+ y2a : npt .NDArray [np .int32 ] = y1a + masks_true ._crop_shapes [:, 0 ] - 1
447+
448+ x1b : npt .NDArray [np .int32 ] = masks_detection ._offsets [:, 0 ]
449+ y1b : npt .NDArray [np .int32 ] = masks_detection ._offsets [:, 1 ]
450+ x2b : npt .NDArray [np .int32 ] = x1b + masks_detection ._crop_shapes [:, 1 ] - 1
451+ y2b : npt .NDArray [np .int32 ] = y1b + masks_detection ._crop_shapes [:, 0 ] - 1
452+
453+ # Pairwise intersection bounding box — shape (N1, N2).
454+ ix1 : npt .NDArray [np .int32 ] = np .maximum (x1a [:, None ], x1b [None , :])
455+ iy1 : npt .NDArray [np .int32 ] = np .maximum (y1a [:, None ], y1b [None , :])
456+ ix2 : npt .NDArray [np .int32 ] = np .minimum (x2a [:, None ], x2b [None , :])
457+ iy2 : npt .NDArray [np .int32 ] = np .minimum (y2a [:, None ], y2b [None , :])
458+ bbox_overlap : npt .NDArray [np .bool_ ] = (ix1 <= ix2 ) & (iy1 <= iy2 )
459+
460+ # Decode each crop at most once, even if it participates in many pairs.
461+ crops_a : dict [int , npt .NDArray [np .bool_ ]] = {}
462+ crops_b : dict [int , npt .NDArray [np .bool_ ]] = {}
463+
464+ for idx_pair in np .argwhere (bbox_overlap ):
465+ i , j = int (idx_pair [0 ]), int (idx_pair [1 ])
466+
467+ if i not in crops_a :
468+ crops_a [i ] = masks_true .crop (i )
469+ if j not in crops_b :
470+ crops_b [j ] = masks_detection .crop (j )
471+
472+ lx1 = int (ix1 [i , j ])
473+ ly1 = int (iy1 [i , j ])
474+ lx2 = int (ix2 [i , j ])
475+ ly2 = int (iy2 [i , j ])
476+
477+ ox_a , oy_a = int (x1a [i ]), int (y1a [i ])
478+ sub_a = crops_a [i ][ly1 - oy_a : ly2 - oy_a + 1 , lx1 - ox_a : lx2 - ox_a + 1 ]
479+
480+ ox_b , oy_b = int (x1b [j ]), int (y1b [j ])
481+ sub_b = crops_b [j ][ly1 - oy_b : ly2 - oy_b + 1 , lx1 - ox_b : lx2 - ox_b + 1 ]
482+
483+ inter = int (np .logical_and (sub_a , sub_b ).sum ())
484+ area_a_i = int (areas_a [i ])
485+ area_b_j = int (areas_b [j ])
486+
487+ if overlap_metric == OverlapMetric .IOU :
488+ union = area_a_i + area_b_j - inter
489+ result [i , j ] = inter / union if union > 0 else 0.0
490+ elif overlap_metric == OverlapMetric .IOS :
491+ small = min (area_a_i , area_b_j )
492+ result [i , j ] = inter / small if small > 0 else 0.0
493+ else :
494+ raise ValueError (
495+ f"overlap_metric { overlap_metric } is not supported, "
496+ "only 'IOU' and 'IOS' are supported"
497+ )
498+
499+ return result
500+
501+
401502def _mask_iou_batch_split (
402503 masks_true : npt .NDArray [Any ],
403504 masks_detection : npt .NDArray [Any ],
@@ -461,16 +562,36 @@ def mask_iou_batch(
461562 Compute Intersection over Union (IoU) of two sets of masks -
462563 `masks_true` and `masks_detection`.
463564
565+ Accepts both dense ``(N, H, W)`` boolean arrays and
566+ :class:`~supervision.detection.compact_mask.CompactMask` objects.
567+ When both inputs are :class:`~supervision.detection.compact_mask.CompactMask`,
568+ the computation uses :func:`compact_mask_iou_batch` to avoid materialising
569+ full ``(N, H, W)`` arrays.
570+
464571 Args:
465- masks_true (np.ndarray): 3D `np.ndarray` representing ground-truth masks.
466- masks_detection (np.ndarray): 3D `np.ndarray` representing detection masks.
572+ masks_true (np.ndarray): 3D `np.ndarray` representing ground-truth masks,
573+ or a :class:`~supervision.detection.compact_mask.CompactMask`.
574+ masks_detection (np.ndarray): 3D `np.ndarray` representing detection masks,
575+ or a :class:`~supervision.detection.compact_mask.CompactMask`.
467576 overlap_metric (OverlapMetric): Metric used to compute the degree of overlap
468577 between pairs of masks (e.g., IoU, IoS).
469578 memory_limit (int): memory limit in MB, default is 1024 * 5 MB (5GB).
579+ Ignored when both inputs are CompactMask.
470580
471581 Returns:
472582 np.ndarray: Pairwise IoU of masks from `masks_true` and `masks_detection`.
473583 """
584+ from supervision .detection .compact_mask import CompactMask
585+
586+ if isinstance (masks_true , CompactMask ) and isinstance (masks_detection , CompactMask ):
587+ return compact_mask_iou_batch (masks_true , masks_detection , overlap_metric )
588+
589+ # Materialise any CompactMask that was passed alongside a dense array.
590+ if isinstance (masks_true , CompactMask ):
591+ masks_true = np .asarray (masks_true )
592+ if isinstance (masks_detection , CompactMask ):
593+ masks_detection = np .asarray (masks_detection )
594+
474595 memory = (
475596 masks_true .shape [0 ]
476597 * masks_true .shape [1 ]
@@ -546,11 +667,18 @@ def mask_non_max_suppression(
546667 if columns == 5 :
547668 predictions = np .c_ [predictions , np .zeros (rows )]
548669
670+ from supervision .detection .compact_mask import CompactMask
671+
549672 sort_index = predictions [:, 4 ].argsort ()[::- 1 ]
550673 predictions = predictions [sort_index ]
551674 masks = masks [sort_index ]
552- masks_resized = resize_masks (masks , mask_dimension )
553- ious = mask_iou_batch (masks_resized , masks_resized , overlap_metric )
675+
676+ if isinstance (masks , CompactMask ):
677+ # CompactMask IoU is computed directly on RLE crops — no resize needed.
678+ ious = compact_mask_iou_batch (masks , masks , overlap_metric )
679+ else :
680+ masks_resized = resize_masks (masks , mask_dimension )
681+ ious = mask_iou_batch (masks_resized , masks_resized , overlap_metric )
554682 categories = predictions [:, 5 ]
555683
556684 keep = np .ones (rows , dtype = bool )
@@ -710,7 +838,16 @@ def mask_non_max_merge(
710838 AssertionError: If `iou_threshold` is not within the closed
711839 range from `0` to `1`.
712840 """
713- masks_resized = resize_masks (masks , mask_dimension )
841+ from supervision .detection .compact_mask import CompactMask
842+
843+ if isinstance (masks , CompactMask ):
844+ # _group_overlapping_masks needs dense arrays for logical_or union merging;
845+ # materialise to a downscaled dense array to keep memory reasonable.
846+ masks = resize_masks (np .asarray (masks ), mask_dimension )
847+ else :
848+ masks = resize_masks (masks , mask_dimension )
849+ masks_resized = masks
850+
714851 if predictions .shape [1 ] == 5 :
715852 return _group_overlapping_masks (
716853 predictions , masks_resized , iou_threshold , overlap_metric
0 commit comments