@@ -387,9 +387,10 @@ def __init__(self, groups: int, in_features: int, out_features: int, dtype: torc
         self.weight = nn.Parameter(torch.empty(groups, out_features, in_features, dtype=dtype, device=device))
         # Initialize the weight in the same way as nn.Linear
         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.weight.data = self.weight.transpose(-1, -2)
 
     def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
-        return grouped_mm(hidden_states, self.weight.transpose(-1, -2), offsets)
+        return grouped_mm(hidden_states, self.weight, offsets)
 
 
 @torch.inference_mode()
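This hunk initializes the weight in the original `(groups, out_features, in_features)` layout and then transposes it once, so `forward` can pass `self.weight` to `grouped_mm` without a `transpose(-1, -2)` on every call. A minimal sketch of that pattern (shape values are illustrative; note the stored parameter becomes a non-contiguous view, which the real `grouped_mm` kernel may or may not require to be made contiguous — an assumption, since its requirements aren't shown here):

```python
import math

import torch
from torch import nn

groups, in_features, out_features = 4, 8, 16

# Allocate in the (groups, out_features, in_features) layout so
# kaiming_uniform_ sees the same fan sizes as before the change ...
weight = nn.Parameter(torch.empty(groups, out_features, in_features))
nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

# ... then transpose once at init instead of on every forward call.
weight.data = weight.transpose(-1, -2)

assert weight.shape == (groups, in_features, out_features)
assert not weight.is_contiguous()  # the transpose yields a strided view
```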
@@ -631,13 +632,13 @@ def from_transformers_llama4textmoe(moe: Llama4TextMoe) -> Llama4MoE:
     # Split into gate and up projections
     gate_proj_w, up_proj_w = moe.experts.gate_up_proj.chunk(2, dim=2)
 
-    new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w.transpose(-1, -2))
-    new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w.transpose(-1, -2))
+    new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w)
+    new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w)
 
     # Handle down_proj
     # HF format: (groups, intermediate_size, hidden_size)
     # Our format: (groups, hidden, intermediate_size)
-    new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj.transpose(-1, -2))
+    new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj)
 
     return new_moe
 
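The second hunk drops the transposes in the checkpoint conversion because, per the comments in the diff, the new `(groups, in_features, out_features)` storage layout already matches the Hugging Face expert layout, so the tensors copy over directly. A quick equivalence check that the pre-transposed layout produces the same output as transposing inside `forward`, with the grouped matmul emulated by a hypothetical per-group loop (`grouped_mm_ref` and its offsets semantics are assumptions for illustration, not the repo's actual kernel):

```python
import torch

def grouped_mm_ref(hidden_states, weight, offsets):
    # Assumed semantics: for each group g, multiply the row slice
    # offsets[g-1]:offsets[g] by that group's (in_features, out_features) weight.
    out = hidden_states.new_empty(hidden_states.shape[0], weight.shape[-1])
    start = 0
    for g, end in enumerate(offsets.tolist()):
        out[start:end] = hidden_states[start:end] @ weight[g]
        start = end
    return out

groups, in_features, out_features = 4, 8, 16
offsets = torch.cumsum(torch.tensor([3, 0, 5, 2]), dim=0)
x = torch.randn(int(offsets[-1]), in_features)

# Old layout: (groups, out_features, in_features), transposed on every forward.
w_old = torch.randn(groups, out_features, in_features)
# New layout: transposed once up front to (groups, in_features, out_features).
w_new = w_old.transpose(-1, -2).contiguous()

y_old = grouped_mm_ref(x, w_old.transpose(-1, -2), offsets)
y_new = grouped_mm_ref(x, w_new, offsets)
torch.testing.assert_close(y_old, y_new)
```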