Skip to content

Commit 2652cb1

Browse files
committed
fix more transposes on the weight
1 parent 01635f3 commit 2652cb1

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

thunder/benchmarks/layers_for_inference_benchmark.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,10 @@ def __init__(self, groups: int, in_features: int, out_features: int, dtype: torc
387387
self.weight = nn.Parameter(torch.empty(groups, out_features, in_features, dtype=dtype, device=device))
388388
# Initialize the weight in the same way as nn.Linear
389389
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
390+
self.weight.data = self.weight.transpose(-1, -2)
390391

391392
def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
392-
return grouped_mm(hidden_states, self.weight.transpose(-1, -2), offsets)
393+
return grouped_mm(hidden_states, self.weight, offsets)
393394

394395

395396
@torch.inference_mode()
@@ -631,13 +632,13 @@ def from_transformers_llama4textmoe(moe: Llama4TextMoe) -> Llama4MoE:
631632
# Split into gate and up projections
632633
gate_proj_w, up_proj_w = moe.experts.gate_up_proj.chunk(2, dim=2)
633634

634-
new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w.transpose(-1, -2))
635-
new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w.transpose(-1, -2))
635+
new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w)
636+
new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w)
636637

637638
# Handle down_proj
638639
# HF format: (groups, intermediate_size, hidden_size)
639640
# Our format: (groups, hidden, intermediate_size)
640-
new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj.transpose(-1, -2))
641+
new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj)
641642

642643
return new_moe
643644

0 commit comments

Comments (0)