1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- from typing import Optional , Union
14+ from typing import Optional
1515
1616import torch
1717from torch import Tensor
@@ -28,7 +28,6 @@ def _dice_score_validate_args(
2828 average : Optional [Literal ["micro" , "macro" , "weighted" , "none" ]] = "micro" ,
2929 input_format : Literal ["one-hot" , "index" , "mixed" ] = "one-hot" ,
3030 aggregation_level : Optional [Literal ["samplewise" , "global" ]] = "samplewise" ,
31- zero_division : Union [float , Literal ["warn" , "nan" ]] = "nan" ,
3231) -> None :
3332 """Validate the arguments of the metric."""
3433 if not isinstance (num_classes , int ) or num_classes <= 0 :
@@ -46,10 +45,6 @@ def _dice_score_validate_args(
4645 raise ValueError (
4746 f"Expected argument `aggregation_level` to be one of `samplewise`, `global`, but got { aggregation_level } "
4847 )
49- if zero_division not in (0.0 , 1.0 , "warn" , "nan" ):
50- raise ValueError (
51- f"Expected argument `zero_division` to be one of 0.0, 1.0, 'warn', or 'nan', but got { zero_division } ."
52- )
5348
5449
5550def _dice_score_update (
@@ -79,34 +74,25 @@ def _dice_score_compute(
7974 average : Optional [Literal ["micro" , "macro" , "weighted" , "none" ]] = "micro" ,
8075 aggregation_level : Optional [Literal ["samplewise" , "global" ]] = "samplewise" ,
8176 support : Optional [Tensor ] = None ,
82- zero_division : Union [float , Literal ["warn" , "nan" ]] = "nan" ,
8377) -> Tensor :
8478 """Compute the Dice score from the numerator and denominator."""
8579 if aggregation_level == "global" :
8680 numerator = torch .sum (numerator , dim = 0 ).unsqueeze (0 )
8781 denominator = torch .sum (denominator , dim = 0 ).unsqueeze (0 )
8882 support = torch .sum (support , dim = 0 ) if support is not None else None
8983
90- # Determine the zero_division value to use
91- if zero_division == "warn" :
92- zero_div_value = "warn"
93- elif zero_division == "nan" :
94- zero_div_value = "nan"
95- else :
96- zero_div_value = float (zero_division )
97-
9884 if average == "micro" :
9985 numerator = torch .sum (numerator , dim = - 1 )
10086 denominator = torch .sum (denominator , dim = - 1 )
101- return _safe_divide (numerator , denominator , zero_division = zero_div_value )
87+ return _safe_divide (numerator , denominator , zero_division = "nan" )
10288
103- dice = _safe_divide (numerator , denominator , zero_division = zero_div_value )
89+ dice = _safe_divide (numerator , denominator , zero_division = "nan" )
10490 if average == "macro" :
10591 return torch .nanmean (dice , dim = - 1 )
10692 if average == "weighted" :
10793 if not isinstance (support , torch .Tensor ):
10894 raise ValueError (f"Expected argument `support` to be a tensor, got: { type (support )} ." )
109- weights = _safe_divide (support , torch .sum (support , dim = - 1 , keepdim = True ), zero_division = zero_div_value )
95+ weights = _safe_divide (support , torch .sum (support , dim = - 1 , keepdim = True ), zero_division = "nan" )
11096 nan_mask = dice .isnan ().all (dim = - 1 )
11197 dice = torch .nansum (dice * weights , dim = - 1 )
11298 dice [nan_mask ] = torch .nan
@@ -124,7 +110,6 @@ def dice_score(
124110 average : Optional [Literal ["micro" , "macro" , "weighted" , "none" ]] = "macro" ,
125111 input_format : Literal ["one-hot" , "index" , "mixed" ] = "one-hot" ,
126112 aggregation_level : Optional [Literal ["samplewise" , "global" ]] = "samplewise" ,
127- zero_division : Union [float , Literal ["warn" , "nan" ]] = "nan" ,
128113) -> Tensor :
129114 """Compute the Dice score for semantic segmentation.
130115
@@ -141,8 +126,6 @@ def dice_score(
141126 aggregation_level: The level at which to aggregate the dice score. Options are ``"samplewise"`` or ``"global"``.
142127 For ``"samplewise"`` the dice score is computed for each sample and then averaged. For ``"global"`` the dice
143128 score is computed globally over all samples.
144- zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan".
145- Setting it to "warn" behaves like 0.0 but will also create a warning.
146129
147130 Returns:
148131 The Dice score.
@@ -193,13 +176,6 @@ def dice_score(
193176 " If you've explicitly set this parameter, you can ignore this warning." ,
194177 UserWarning ,
195178 )
196- _dice_score_validate_args (num_classes , include_background , average , input_format , aggregation_level , zero_division )
179+ _dice_score_validate_args (num_classes , include_background , average , input_format , aggregation_level )
197180 numerator , denominator , support = _dice_score_update (preds , target , num_classes , include_background , input_format )
198- return _dice_score_compute (
199- numerator ,
200- denominator ,
201- average ,
202- aggregation_level = aggregation_level ,
203- support = support ,
204- zero_division = zero_division ,
205- )
181+ return _dice_score_compute (numerator , denominator , average , aggregation_level = aggregation_level , support = support )
0 commit comments