Skip to content

Commit 7d3674e

Browse files
aymuos15 and ericspod authored
Add 3D support and confusion matrix output to PanopticQualityMetric (#8684)
Fixes #5702 ### Description This PR adds two features requested in issue #5702: 1. **3D Support**: `PanopticQualityMetric` now accepts both 4D tensors (B2HW for 2D data) and 5D tensors (B2HWD for 3D data). Previously, only 2D inputs were supported. 2. **Confusion Matrix Output**: Added `return_confusion_matrix` parameter to `PanopticQualityMetric`. When set to `True`, the `aggregate()` method returns raw confusion matrix values (tp, fp, fn, iou_sum) instead of computed metrics, enabling custom metric calculations. 3. **Helper Function**: Added `compute_mean_iou()` function to compute mean IoU from confusion matrix values. **Note**: While [panoptica](https://github.com/BrainLesion/panoptica) exists as a standalone library, I feel this would still be a nice addition to MONAI. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 14f5b80 commit 7d3674e

File tree

2 files changed

+165
-7
lines changed

2 files changed

+165
-7
lines changed

monai/metrics/panoptic_quality.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
linear_sum_assignment, _ = optional_import("scipy.optimize", name="linear_sum_assignment")
2323

24-
__all__ = ["PanopticQualityMetric", "compute_panoptic_quality"]
24+
__all__ = ["PanopticQualityMetric", "compute_panoptic_quality", "compute_mean_iou"]
2525

2626

2727
class PanopticQualityMetric(CumulativeIterationMetric):
@@ -55,6 +55,8 @@ class PanopticQualityMetric(CumulativeIterationMetric):
5555
If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the
5656
maximal amount of unique pairing.
5757
smooth_numerator: a small constant added to the numerator to avoid zero.
58+
return_confusion_matrix: if True, returns raw confusion matrix values (tp, fp, fn, iou_sum)
59+
instead of computed metrics. Default is False.
5860
5961
"""
6062

@@ -65,19 +67,22 @@ def __init__(
6567
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
6668
match_iou_threshold: float = 0.5,
6769
smooth_numerator: float = 1e-6,
70+
return_confusion_matrix: bool = False,
6871
) -> None:
6972
super().__init__()
7073
self.num_classes = num_classes
7174
self.reduction = reduction
7275
self.match_iou_threshold = match_iou_threshold
7376
self.smooth_numerator = smooth_numerator
7477
self.metric_name = ensure_tuple(metric_name)
78+
self.return_confusion_matrix = return_confusion_matrix
7579

7680
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
7781
"""
7882
Args:
79-
y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the
80-
second channel represent the instance predictions and classification predictions respectively.
83+
y_pred: Predictions. It must be in the form of B2HW (2D) or B2HWD (3D) and have integer type.
84+
The first channel and the second channel represent the instance predictions and classification
85+
predictions respectively.
8186
y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the
8287
second channel represent the instance labels and classification labels respectively.
8388
Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,
@@ -86,7 +91,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
8691
Raises:
8792
ValueError: when `y_pred` and `y` have different shapes.
8893
ValueError: when `y_pred` and `y` have != 2 channels.
89-
ValueError: when `y_pred` and `y` have != 4 dimensions.
94+
ValueError: when `y_pred` and `y` have != 4 or 5 dimensions.
9095
9196
"""
9297
if y_pred.shape != y.shape:
@@ -98,8 +103,10 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
98103
)
99104

100105
dims = y_pred.ndimension()
101-
if dims != 4:
102-
raise ValueError(f"y_pred should have 4 dimensions (batch, 2, h, w), got {dims}.")
106+
if dims not in (4, 5):
107+
raise ValueError(
108+
f"y_pred should have 4 dimensions (batch, 2, h, w) or 5 dimensions (batch, 2, h, w, d), got {dims}."
109+
)
103110

104111
batch_size = y_pred.shape[0]
105112

@@ -131,13 +138,22 @@ def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Ten
131138
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
132139
``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
133140
141+
Returns:
142+
If `return_confusion_matrix` is True, returns the raw confusion matrix [tp, fp, fn, iou_sum].
143+
Otherwise, returns the computed metric(s) based on `metric_name`.
144+
134145
"""
135146
data = self.get_buffer()
136147
if not isinstance(data, torch.Tensor):
137148
raise ValueError("the data to aggregate must be PyTorch Tensor.")
138149

139150
# do metric reduction
140151
f, _ = do_metric_reduction(data, reduction or self.reduction)
152+
153+
if self.return_confusion_matrix:
154+
# Return raw confusion matrix values
155+
return f
156+
141157
tp, fp, fn, iou_sum = f[..., 0], f[..., 1], f[..., 2], f[..., 3]
142158
results = []
143159
for metric_name in self.metric_name:
@@ -169,7 +185,7 @@ def compute_panoptic_quality(
169185
calculate PQ, and returning them directly enables further calculation over all images.
170186
171187
Args:
172-
pred: input data to compute, it must be in the form of HW and have integer type.
188+
pred: input data to compute, it must be in the form of HW (2D) or HWD (3D) and have integer type.
173189
gt: ground truth. It must have the same shape as `pred` and have integer type.
174190
metric_name: output metric. The value can be "pq", "sq" or "rq".
175191
remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id.
@@ -294,3 +310,24 @@ def _check_panoptic_metric_name(metric_name: str) -> str:
294310
if metric_name in ["recognition_quality", "rq"]:
295311
return "rq"
296312
raise ValueError(f"metric name: {metric_name} is wrong, please use 'pq', 'sq' or 'rq'.")
313+
314+
315+
def compute_mean_iou(confusion_matrix: torch.Tensor, smooth_numerator: float = 1e-6) -> torch.Tensor:
316+
"""Compute mean IoU from confusion matrix values.
317+
318+
Args:
319+
confusion_matrix: tensor with shape (..., 4) where the last dimension contains
320+
[tp, fp, fn, iou_sum] as returned by `compute_panoptic_quality` with `output_confusion_matrix=True`.
321+
smooth_numerator: a small constant added to the numerator to avoid zero.
322+
323+
Returns:
324+
Mean IoU computed as iou_sum / (tp + smooth_numerator).
325+
326+
"""
327+
if confusion_matrix.shape[-1] != 4:
328+
raise ValueError(
329+
f"confusion_matrix should have shape (..., 4) with [tp, fp, fn, iou_sum], "
330+
f"got shape {confusion_matrix.shape}."
331+
)
332+
tp, iou_sum = confusion_matrix[..., 0], confusion_matrix[..., 3]
333+
return iou_sum / (tp + smooth_numerator)

tests/metrics/test_compute_panoptic_quality.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from parameterized import parameterized
1919

2020
from monai.metrics import PanopticQualityMetric, compute_panoptic_quality
21+
from monai.metrics.panoptic_quality import compute_mean_iou
2122
from tests.test_utils import SkipIfNoModule
2223

2324
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -88,6 +89,27 @@
8889
[torch.as_tensor([[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]]), torch.as_tensor([[0.0, 0.5, 0.0], [0.3333, 0.0, 0.4]])],
8990
]
9091

92+
# 3D test cases
93+
sample_3d_pred = torch.as_tensor(
94+
[[[[[2, 0], [1, 1]], [[0, 1], [2, 1]]], [[[0, 1], [3, 0]], [[1, 0], [1, 1]]]]], # instance channel # class channel
95+
device=_device,
96+
)
97+
98+
sample_3d_gt = torch.as_tensor(
99+
[[[[[2, 0], [0, 0]], [[2, 2], [2, 3]]], [[[3, 3], [3, 2]], [[2, 2], [3, 3]]]]], # instance channel # class channel
100+
device=_device,
101+
)
102+
103+
# test 3D sample, num_classes = 3, match_iou_threshold = 0.5
104+
TEST_3D_CASE_1 = [{"num_classes": 3, "match_iou_threshold": 0.5}, sample_3d_pred, sample_3d_gt]
105+
106+
# test confusion matrix return
107+
TEST_CM_CASE_1 = [
108+
{"num_classes": 3, "match_iou_threshold": 0.5, "return_confusion_matrix": True},
109+
sample_3_pred,
110+
sample_3_gt,
111+
]
112+
91113

92114
@SkipIfNoModule("scipy.optimize")
93115
class TestPanopticQualityMetric(unittest.TestCase):
@@ -108,6 +130,105 @@ def test_value_class(self, input_params, y_pred, y_gt, expected_value):
108130
else:
109131
np.testing.assert_allclose(outputs.cpu().numpy(), np.asarray(expected_value), atol=1e-4)
110132

133+
def test_3d_support(self):
134+
"""Test that 3D input is properly supported."""
135+
input_params, y_pred, y_gt = TEST_3D_CASE_1
136+
metric = PanopticQualityMetric(**input_params)
137+
# Should not raise an error for 3D input
138+
metric(y_pred, y_gt)
139+
outputs = metric.aggregate()
140+
# Check that output is a tensor
141+
self.assertIsInstance(outputs, torch.Tensor)
142+
# Check that output shape is correct (num_classes,)
143+
self.assertEqual(outputs.shape, torch.Size([3]))
144+
145+
def test_confusion_matrix_return(self):
146+
"""Test that confusion matrix can be returned instead of computed metrics."""
147+
input_params, y_pred, y_gt = TEST_CM_CASE_1
148+
metric = PanopticQualityMetric(**input_params)
149+
metric(y_pred, y_gt)
150+
outputs = metric.aggregate()
151+
# Check that output is a tensor with shape (batch_size, num_classes, 4)
152+
self.assertIsInstance(outputs, torch.Tensor)
153+
self.assertEqual(outputs.shape[-1], 4)
154+
# Verify that values correspond to [tp, fp, fn, iou_sum]
155+
tp, fp, fn, iou_sum = outputs[..., 0], outputs[..., 1], outputs[..., 2], outputs[..., 3]
156+
# tp, fp, fn should be non-negative integers
157+
self.assertTrue(torch.all(tp >= 0))
158+
self.assertTrue(torch.all(fp >= 0))
159+
self.assertTrue(torch.all(fn >= 0))
160+
# iou_sum should be non-negative float
161+
self.assertTrue(torch.all(iou_sum >= 0))
162+
163+
def test_compute_mean_iou(self):
164+
"""Test mean IoU computation from confusion matrix."""
165+
input_params, y_pred, y_gt = TEST_CM_CASE_1
166+
metric = PanopticQualityMetric(**input_params)
167+
metric(y_pred, y_gt)
168+
confusion_matrix = metric.aggregate()
169+
mean_iou = compute_mean_iou(confusion_matrix)
170+
171+
# Check shape is correct
172+
self.assertEqual(mean_iou.shape, confusion_matrix.shape[:-1])
173+
174+
# Check values are non-negative
175+
self.assertTrue(torch.all(mean_iou >= 0))
176+
177+
# Validate against expected values
178+
# mean_iou = iou_sum / (tp + smooth_numerator)
179+
tp = confusion_matrix[..., 0]
180+
iou_sum = confusion_matrix[..., 3]
181+
expected_mean_iou = iou_sum / (tp + 1e-6) # smooth_numerator=1e-6 is default
182+
np.testing.assert_allclose(mean_iou.cpu().numpy(), expected_mean_iou.cpu().numpy(), atol=1e-4)
183+
184+
def test_metric_name_filtering(self):
185+
"""Test that metric_name parameter properly filters output."""
186+
# Test single metric "sq"
187+
metric_sq = PanopticQualityMetric(num_classes=3, metric_name="sq", match_iou_threshold=0.5)
188+
metric_sq(sample_3_pred, sample_3_gt)
189+
result_sq = metric_sq.aggregate()
190+
self.assertIsInstance(result_sq, torch.Tensor)
191+
self.assertEqual(result_sq.shape, torch.Size([3]))
192+
193+
# Test single metric "rq"
194+
metric_rq = PanopticQualityMetric(num_classes=3, metric_name="rq", match_iou_threshold=0.5)
195+
metric_rq(sample_3_pred, sample_3_gt)
196+
result_rq = metric_rq.aggregate()
197+
self.assertIsInstance(result_rq, torch.Tensor)
198+
self.assertEqual(result_rq.shape, torch.Size([3]))
199+
200+
# Results should be different for different metrics
201+
self.assertFalse(torch.allclose(result_sq, result_rq, atol=1e-4))
202+
203+
def test_invalid_3d_shape(self):
204+
"""Test that invalid 3D shapes are rejected."""
205+
# Shape with 3 dimensions should fail
206+
invalid_pred = torch.randint(0, 5, (2, 2, 10))
207+
invalid_gt = torch.randint(0, 5, (2, 2, 10))
208+
metric = PanopticQualityMetric(num_classes=3)
209+
with self.assertRaises(ValueError):
210+
metric(invalid_pred, invalid_gt)
211+
212+
# Shape with 6 dimensions should fail
213+
invalid_pred = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))
214+
invalid_gt = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))
215+
with self.assertRaises(ValueError):
216+
metric(invalid_pred, invalid_gt)
217+
218+
def test_compute_mean_iou_invalid_shape(self):
219+
"""Test that compute_mean_iou raises ValueError for invalid shapes."""
220+
from monai.metrics.panoptic_quality import compute_mean_iou
221+
222+
# Shape (..., 3) instead of (..., 4) should fail
223+
invalid_confusion_matrix = torch.zeros(3, 3)
224+
with self.assertRaises(ValueError):
225+
compute_mean_iou(invalid_confusion_matrix)
226+
227+
# Shape (..., 5) should also fail
228+
invalid_confusion_matrix = torch.zeros(2, 5)
229+
with self.assertRaises(ValueError):
230+
compute_mean_iou(invalid_confusion_matrix)
231+
111232

112233
if __name__ == "__main__":
113234
unittest.main()

0 commit comments

Comments (0)