Skip to content

Commit ac4d41d

Browse files
rdyroOptaxDev
authored andcommitted
fix CI by fixing pylint errors
PiperOrigin-RevId: 799659506
1 parent f13212d commit ac4d41d

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

optax/losses/_segmentation.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -127,24 +127,16 @@ def dice_loss(
127127
)
128128

129129
# Convert logits to probabilities
130+
probs = predictions
130131
if apply_softmax:
131132
probs = (
132133
jax.nn.sigmoid(predictions)
133134
if predictions.shape[-1] == 1
134135
else jax.nn.softmax(predictions, axis=-1)
135136
)
136-
else:
137-
probs = predictions
138-
139-
# Determine which axes to sum over for computing the loss
140-
if axis is None:
141-
# Default behavior: sum over all spatial dimensions (except first/last)
142-
axis = tuple(range(1, probs.ndim - 1))
143-
elif isinstance(axis, int):
144-
axis = (axis,)
145137

146-
# Ensure axis is a tuple of non-negative integers
147-
axis = tuple(ax % probs.ndim for ax in axis)
138+
# Default behavior: sum over all spatial dimensions (except first/last)
139+
axis = tuple(range(1, probs.ndim - 1)) if axis is None else axis
148140

149141
# Compute intersection and sums over specified axes
150142
intersection = jnp.sum(probs * targets, axis=axis)

0 commit comments

Comments
 (0)