Skip to content

Commit 3585f64

Browse files
committed
final fix
1 parent 8697207 commit 3585f64

File tree

2 files changed

+1
-2
lines changed

2 files changed

+1
-2
lines changed

MaxText/convert_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,4 +272,4 @@ def main(args):
272 272
parser.add_argument("--use-zarr3", type=str2bool, default=True, help="Use Zarr3 format for saving.")
273 273

274 274
parsed_args = parser.parse_args()
275 -
main(parsed_args)
275 +
main(parsed_args)

MaxText/layers/moe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,6 @@ def expert_group_mask(self, gate_logits: jax.Array) -> jax.Array:
382 382

383 383
# Mask selected groups so that only those experts are considered.
384 384
group_mask = jax.nn.one_hot(group_idx, num_classes=self.config.n_routing_groups, dtype=jnp.float32)
385 -
group_mask = jax.nn.one_hot(group_idx, num_classes=self.config.n_routing_groups, dtype=jnp.float32)
386 385
group_mask = jnp.sum(group_mask, axis=-2)
387 386

388 387
# Apply masks and get top-k indices.

0 commit comments

Comments
 (0)