Skip to content

Commit 750620b

Browse files
authored
[benchmark_inference] Reshape the output from run_routed_experts (#2650)
1 parent e207bf6 commit 750620b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

thunder/benchmarks/layers_for_inference_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def run_routed_experts(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor,
605605
token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id)
606606
outs_sorted_by_token_id = outs_sorted_by_expert_id[token_ids_sorted_by_expert_inverse_id]
607607

608-
return outs_sorted_by_token_id, router_logits
608+
return outs_sorted_by_token_id.view(batch_size, seq_len, -1), router_logits.view(batch_size, seq_len, -1)
609609

610610
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
611611
outs_sorted_by_token_id, router_logits = self.run_routed_experts(hidden_states)

0 commit comments

Comments
 (0)