Skip to content

Commit 4560708

Browse files
authored
Update dice.py
1 parent 85f0e8e commit 4560708

File tree

1 file changed

+6
-30
lines changed
  • src/torchmetrics/functional/segmentation

1 file changed

+6
-30
lines changed

src/torchmetrics/functional/segmentation/dice.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Optional, Union
14+
from typing import Optional
1515

1616
import torch
1717
from torch import Tensor
@@ -28,7 +28,6 @@ def _dice_score_validate_args(
2828
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
2929
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
3030
aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
31-
zero_division: Union[float, Literal["warn", "nan"]] = "nan",
3231
) -> None:
3332
"""Validate the arguments of the metric."""
3433
if not isinstance(num_classes, int) or num_classes <= 0:
@@ -46,10 +45,6 @@ def _dice_score_validate_args(
4645
raise ValueError(
4746
f"Expected argument `aggregation_level` to be one of `samplewise`, `global`, but got {aggregation_level}"
4847
)
49-
if zero_division not in (0.0, 1.0, "warn", "nan"):
50-
raise ValueError(
51-
f"Expected argument `zero_division` to be one of 0.0, 1.0, 'warn', or 'nan', but got {zero_division}."
52-
)
5348

5449

5550
def _dice_score_update(
@@ -79,34 +74,25 @@ def _dice_score_compute(
7974
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
8075
aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
8176
support: Optional[Tensor] = None,
82-
zero_division: Union[float, Literal["warn", "nan"]] = "nan",
8377
) -> Tensor:
8478
"""Compute the Dice score from the numerator and denominator."""
8579
if aggregation_level == "global":
8680
numerator = torch.sum(numerator, dim=0).unsqueeze(0)
8781
denominator = torch.sum(denominator, dim=0).unsqueeze(0)
8882
support = torch.sum(support, dim=0) if support is not None else None
8983

90-
# Determine the zero_division value to use
91-
if zero_division == "warn":
92-
zero_div_value = "warn"
93-
elif zero_division == "nan":
94-
zero_div_value = "nan"
95-
else:
96-
zero_div_value = float(zero_division)
97-
9884
if average == "micro":
9985
numerator = torch.sum(numerator, dim=-1)
10086
denominator = torch.sum(denominator, dim=-1)
101-
return _safe_divide(numerator, denominator, zero_division=zero_div_value)
87+
return _safe_divide(numerator, denominator, zero_division="nan")
10288

103-
dice = _safe_divide(numerator, denominator, zero_division=zero_div_value)
89+
dice = _safe_divide(numerator, denominator, zero_division="nan")
10490
if average == "macro":
10591
return torch.nanmean(dice, dim=-1)
10692
if average == "weighted":
10793
if not isinstance(support, torch.Tensor):
10894
raise ValueError(f"Expected argument `support` to be a tensor, got: {type(support)}.")
109-
weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=zero_div_value)
95+
weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division="nan")
11096
nan_mask = dice.isnan().all(dim=-1)
11197
dice = torch.nansum(dice * weights, dim=-1)
11298
dice[nan_mask] = torch.nan
@@ -124,7 +110,6 @@ def dice_score(
124110
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
125111
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
126112
aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
127-
zero_division: Union[float, Literal["warn", "nan"]] = "nan",
128113
) -> Tensor:
129114
"""Compute the Dice score for semantic segmentation.
130115
@@ -141,8 +126,6 @@ def dice_score(
141126
aggregation_level: The level at which to aggregate the dice score. Options are ``"samplewise"`` or ``"global"``.
142127
For ``"samplewise"`` the dice score is computed for each sample and then averaged. For ``"global"`` the dice
143128
score is computed globally over all samples.
144-
zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan".
145-
Setting it to "warn" behaves like 0.0 but will also create a warning.
146129
147130
Returns:
148131
The Dice score.
@@ -193,13 +176,6 @@ def dice_score(
193176
" If you've explicitly set this parameter, you can ignore this warning.",
194177
UserWarning,
195178
)
196-
_dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level, zero_division)
179+
_dice_score_validate_args(num_classes, include_background, average, input_format, aggregation_level)
197180
numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format)
198-
return _dice_score_compute(
199-
numerator,
200-
denominator,
201-
average,
202-
aggregation_level=aggregation_level,
203-
support=support,
204-
zero_division=zero_division,
205-
)
181+
return _dice_score_compute(numerator, denominator, average, aggregation_level=aggregation_level, support=support)

0 commit comments

Comments
 (0)