From 28f99d304533e89e21862abab41f880d2e4d7618 Mon Sep 17 00:00:00 2001 From: Shawn Ray Date: Thu, 5 Mar 2026 15:03:23 -0800 Subject: [PATCH 1/2] Sparse Dice Loss Began implementation of functions for sparse dice loss, which is dice loss computed on sparse datasets. Initial commit is to track issue https://github.com/Project-MONAI/MONAI/issues/8731 --- monai/losses/dice.py | 200 +++++++++++++++++++++++ tests/losses/test_sparse_dice_loss.py | 227 ++++++++++++++++++++++++++ 2 files changed, 427 insertions(+) create mode 100644 tests/losses/test_sparse_dice_loss.py diff --git a/monai/losses/dice.py b/monai/losses/dice.py index cd76ec1323..71a4e12153 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -1111,8 +1111,208 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: focal_loss = self.focal(input, target) total_loss: torch.Tensor = self.lambda_gdl * gdl_loss + self.lambda_focal * focal_loss return total_loss + +class SparseDiceLoss(_Loss): + """ + Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. + The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). + + Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input, + must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target` + can be 1 or N (one-hot format). + The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of + the inter-over-union calculation to smooth results respectively, these values should be small. + + The original papers: + + Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation. 3DV 2016. + + Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with + Soft Labels. NeurIPS 2023. + + Wang, Z. et. al. 
(2023) Dice Semimetric Losses: Optimizing the Dice Score with + Soft Labels. MICCAI 2023. + + """ + + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Callable | None = None, + squared_pred: bool = False, + jaccard: bool = False, + reduction: LossReduction | str = LossReduction.MEAN, + smooth_nr: float = 1e-5, + smooth_dr: float = 1e-5, + batch: bool = False, + weight: Sequence[float] | float | int | torch.Tensor | None = None, + soft_label: bool = False, + ) -> None: + """ + Args: + include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get overwhelmed + by the signal from the background so excluding it in such cases helps convergence. + to_onehot_y: whether to convert the ``target`` into the one-hot format, + using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. + sigmoid: if True, apply a sigmoid function to the prediction. + softmax: if True, apply a softmax function to the prediction. + other_act: callable function to execute other activation layers, Defaults to ``None``. for example: + ``other_act = torch.tanh``. + squared_pred: use squared versions of targets and predictions in the denominator or not. + jaccard: compute Jaccard Index (soft IoU) instead of dice or not. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + + smooth_nr: a small constant added to the numerator to avoid zero. + smooth_dr: a small constant added to the denominator to avoid nan. 
+ batch: whether to sum the intersection and union areas over the batch dimension before the dividing. + Defaults to False, a Dice loss value is computed independently from each item in the batch + before any `reduction`. + weight: weights to apply to the voxels of each class. If None no weights are applied. + The input can be a single value (same weight for all classes), a sequence of values (the length + of the sequence should be the same as the number of classes. If not ``include_background``, + the number of classes should not include the background category class 0). + The value/values should be no less than 0. Defaults to None. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. + + Raises: + TypeError: When ``other_act`` is not an ``Optional[Callable]``. + ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. + Incompatible values. + + """ + super().__init__(reduction=LossReduction(reduction).value) + if other_act is not None and not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: + raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.sigmoid = sigmoid + self.softmax = softmax + self.other_act = other_act + self.squared_pred = squared_pred + self.jaccard = jaccard + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) + self.batch = batch + weight = torch.as_tensor(weight) if weight is not None else None + self.register_buffer("class_weight", weight) + self.class_weight: None | torch.Tensor + self.soft_label = soft_label + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should 
be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. + + Raises: + AssertionError: When input and target (after one hot transform if set) + have different shapes. + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + + Example: + >>> from monai.losses.dice import * # NOQA + >>> import torch + >>> from monai.losses.dice import DiceLoss + >>> B, C, H, W = 7, 5, 3, 2 + >>> input = torch.rand(B, C, H, W) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = DiceLoss(reduction='none') + >>> loss = self(input, target) + >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape + """ + if self.sigmoid: + input = torch.sigmoid(input) + + n_pred_ch = input.shape[1] + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + + if self.other_act is not None: + input = self.other_act(input) + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + # if skipping background, removing first channel + target = target[:, 1:] + input = input[:, 1:] + + if target.shape != input.shape: + raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + + # reducing only spatial dimensions (not batch nor channels) + reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() + if self.batch: + # reducing spatial dimensions and batch + reduce_axis = [0] + reduce_axis + + ord = 2 if self.squared_pred else 1 + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, 
self.soft_label) + if not self.jaccard: + fp *= 0.5 + fn *= 0.5 + numerator = 2 * tp + self.smooth_nr + denominator = 2 * (tp + fp + fn) + self.smooth_dr + + f: torch.Tensor = 1 - numerator / denominator + + num_of_classes = target.shape[1] + if self.class_weight is not None and num_of_classes != 1: + # make sure the lengths of weights are equal to the number of classes + if self.class_weight.ndim == 0: + self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes) + else: + if self.class_weight.shape[0] != num_of_classes: + raise ValueError( + """the length of the `weight` sequence should be the same as the number of classes. + If `include_background=False`, the weight should not include + the background category class 0.""" + ) + if self.class_weight.min() < 0: + raise ValueError("the value/values of the `weight` should be no less than 0.") + # apply class_weight to loss + f = f * self.class_weight.to(f) + + if self.reduction == LossReduction.MEAN.value: + f = torch.mean(f) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + f = torch.sum(f) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) + f = f.view(broadcast_shape) + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return f +sparse_dice_loss = SparseDiceLoss Dice = DiceLoss dice_ce = DiceCELoss dice_focal = DiceFocalLoss diff --git a/tests/losses/test_sparse_dice_loss.py b/tests/losses/test_sparse_dice_loss.py new file mode 100644 index 0000000000..79eb883252 --- /dev/null +++ b/tests/losses/test_sparse_dice_loss.py @@ -0,0 +1,227 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import DiceLoss +from tests.test_utils import test_script_save + +TEST_CASES = [ + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + 0.307576, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.416657, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307773, + ], + [ # shape: (2, 2, 3), (2, 1, 3) + {"include_background": False, "to_onehot_y": True, 
"smooth_nr": 0, "smooth_dr": 0}, + { + "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]), + "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]), + }, + 0.0, + ], + [ # shape: (2, 2, 3), (2, 1, 3) + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, + { + "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), + "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), + }, + 0.435050, + ], + [ # shape: (2, 2, 3), (2, 1, 3) + { + "include_background": True, + "to_onehot_y": True, + "sigmoid": True, + "reduction": "none", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, + { + "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), + "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), + }, + [[[0.296529], [0.415136]], [[0.599976], [0.428559]]], + ], + [ # shape: (2, 2, 3), (2, 1, 3) + {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, + { + "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), + "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), + }, + 0.383713, + ], + [ # shape: (2, 2, 3), (2, 1, 3) + { + "include_background": True, + "to_onehot_y": True, + "softmax": True, + "reduction": "sum", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, + { + "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), + "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), + }, + 1.534853, + ], + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + 0.307576, + ], + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": 
True, "sigmoid": True, "squared_pred": True}, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + 0.178337, + ], + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True, "jaccard": True}, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + 0.470451, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.999963, + ], + [ # shape: (2, 2, 3), (2, 1, 3) + { + "include_background": True, + "to_onehot_y": True, + "other_act": lambda x: torch.log_softmax(x, dim=1), + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, + { + "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), + "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), + }, + -8.522593, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.774718, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 0, "smooth_dr": 1e-4, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.774733, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 0, "smooth_dr": 1e-4, "batch": False}, + { + "input": 
torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.840058, + ], + [ # shape: (2, 2, 3), (2, 1, 3) weight + { + "include_background": True, + "to_onehot_y": True, + "other_act": lambda x: torch.log_softmax(x, dim=1), + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + "weight": (0, 1), + }, + { + "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), + "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), + }, + -8.268515, + ], +] + + +class TestSparseDiceLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = DiceLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = DiceLoss() + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 2, 3)), torch.ones((4, 5, 6))) + + def test_ill_opts(self): + with self.assertRaisesRegex(ValueError, ""): + DiceLoss(sigmoid=True, softmax=True) + chn_input = torch.ones((1, 1, 3)) + chn_target = torch.ones((1, 1, 3)) + with self.assertRaisesRegex(ValueError, ""): + DiceLoss(reduction="unknown")(chn_input, chn_target) + with self.assertRaisesRegex(ValueError, ""): + DiceLoss(reduction=None)(chn_input, chn_target) + + def test_input_warnings(self): + chn_input = torch.ones((1, 1, 3)) + chn_target = torch.ones((1, 1, 3)) + with self.assertWarns(Warning): + loss = DiceLoss(include_background=False) + loss.forward(chn_input, chn_target) + with self.assertWarns(Warning): + loss = DiceLoss(softmax=True) + loss.forward(chn_input, chn_target) + with self.assertWarns(Warning): + loss = DiceLoss(to_onehot_y=True) + loss.forward(chn_input, chn_target) + + def test_script(self): + loss = DiceLoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, 
test_input) + + +if __name__ == "__main__": + unittest.main() From 71f692cf316d7c09d2c8d38268cf538fa9ecaa2b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 23:07:55 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 71a4e12153..3406496564 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -1111,7 +1111,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: focal_loss = self.focal(input, target) total_loss: torch.Tensor = self.lambda_gdl * gdl_loss + self.lambda_focal * focal_loss return total_loss - + class SparseDiceLoss(_Loss): """ Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks.