File tree Expand file tree Collapse file tree 1 file changed +3
-11
lines changed Expand file tree Collapse file tree 1 file changed +3
-11
lines changed Original file line number Diff line number Diff line change @@ -127,24 +127,16 @@ def dice_loss(
127127 )
128128
129129 # Convert logits to probabilities
130+ probs = predictions
130131 if apply_softmax :
131132 probs = (
132133 jax .nn .sigmoid (predictions )
133134 if predictions .shape [- 1 ] == 1
134135 else jax .nn .softmax (predictions , axis = - 1 )
135136 )
136- else :
137- probs = predictions
138-
139- # Determine which axes to sum over for computing the loss
140- if axis is None :
141- # Default behavior: sum over all spatial dimensions (except first/last)
142- axis = tuple (range (1 , probs .ndim - 1 ))
143- elif isinstance (axis , int ):
144- axis = (axis ,)
145137
146- # Ensure axis is a tuple of non-negative integers
147- axis = tuple (ax % probs .ndim for ax in axis )
138+ # Default behavior: sum over all spatial dimensions (except first/last)
139+ axis = tuple (range ( 1 , probs .ndim - 1 )) if axis is None else axis
148140
149141 # Compute intersection and sums over specified axes
150142 intersection = jnp .sum (probs * targets , axis = axis )
You can’t perform that action at this time.
0 commit comments