
Commit 00d0877

Runnable, incorrect output
1 parent: 1ab07eb


3 files changed: +374 −120 lines


vllm/model_executor/layers/fused_moe/config.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -192,7 +192,8 @@ def use_deepep_ll_kernels(self):
     @property
     def use_flashinfer_cutlass_kernels(self):
         return (envs.VLLM_USE_FLASHINFER_MOE_FP4
-                and has_flashinfer_cutlass_fused_moe())
+                and has_flashinfer_cutlass_fused_moe()
+                and envs.VLLM_FLASHINFER_MOE_BACKEND=="flashinfer_moe_high_throughput")

     @staticmethod
     def make(tp_size_: int, dp_size_: int,
```
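The net effect of this hunk: the CUTLASS MoE path now also requires the FlashInfer backend env var to name the high-throughput backend. A minimal, self-contained sketch of the new guard (the `_Envs` stand-in and the stubbed capability check are illustrative placeholders, not vLLM's actual definitions):

```python
# Illustrative sketch only: `_Envs` and the stubbed capability check
# stand in for vllm.envs and the real FlashInfer probe.
class _Envs:
    VLLM_USE_FLASHINFER_MOE_FP4 = True
    VLLM_FLASHINFER_MOE_BACKEND = "flashinfer_moe_high_throughput"

envs = _Envs()

def has_flashinfer_cutlass_fused_moe() -> bool:
    return True  # placeholder for the real installed-kernel check

def use_flashinfer_cutlass_kernels() -> bool:
    # After this commit all three conditions must hold; changing the
    # backend string disables the CUTLASS path even with FP4 enabled.
    return (envs.VLLM_USE_FLASHINFER_MOE_FP4
            and has_flashinfer_cutlass_fused_moe()
            and envs.VLLM_FLASHINFER_MOE_BACKEND
            == "flashinfer_moe_high_throughput")

assert use_flashinfer_cutlass_kernels()
envs.VLLM_FLASHINFER_MOE_BACKEND = "other_backend"
assert not use_flashinfer_cutlass_kernels()
```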

vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py

Lines changed: 47 additions & 45 deletions
```diff
@@ -10,8 +10,8 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import extract_required_args
-from vllm.utils.flashinfer import (has_flashinfer_trtllm_fused_moe,
-                                   trtllm_fp4_block_scale_moe)
+from vllm.utils.flashinfer import (has_flashinfer_trtllm_fused_moe)#,
+                                   # trtllm_fp4_block_scale_moe)

 logger = init_logger(__name__)
```
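This hunk disables the `trtllm_fp4_block_scale_moe` import by commenting it out inline; the trailing `)#,` keeps the parenthesized import syntactically valid while only the capability check survives. A tidier equivalent of the new import line (a sketch of the same intent, not the committed form):

```python
# Only the capability check remains importable; the kernel entry point
# trtllm_fp4_block_scale_moe is no longer brought into this module.
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
```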

```diff
@@ -180,47 +180,49 @@ def apply(
             w2_scale.view(torch.int32),
             g2_alphas,
         ]
-        out = trtllm_fp4_block_scale_moe(
-            routing_logits,
-            routing_bias,
-            hidden_states,
-            topk_ids.to(torch.int),
-            topk_weights,
-            # FlashInfer API requires weight to be long for nvfp4
-            w1.view(torch.long),
-            w2.view(torch.long),
-            output_dtype=out_dtype,
-            quant_scales=quant_scales,
-            input_sf=a1q_scale,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=self.ep_size,
-            ep_rank=self.ep_rank,
-            output=output,
-        )
+        # out = trtllm_fp4_block_scale_moe(
+        #     routing_logits,
+        #     routing_bias,
+        #     hidden_states,
+        #     topk_ids.to(torch.int),
+        #     topk_weights,
+        #     # FlashInfer API requires weight to be long for nvfp4
+        #     w1.view(torch.long),
+        #     w2.view(torch.long),
+        #     output_dtype=out_dtype,
+        #     quant_scales=quant_scales,
+        #     input_sf=a1q_scale,
+        #     tp_size=self.tp_size,
+        #     tp_rank=self.tp_rank,
+        #     ep_size=self.ep_size,
+        #     ep_rank=self.ep_rank,
+        #     output=output,
+        # )
+        out = output
         output.copy_(out)
-
-        return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
-            routing_logits,
-            routing_bias,
-            hidden_states,
-            hidden_states_scale,
-            gemm1_weights,
-            gemm1_weights_scale,
-            gemm2_weights,
-            gemm2_weights_scale,
-            output1_scale_scalar,
-            output1_scale_gate_scalar,
-            output2_scale_scalar,
-            num_experts,
-            top_k,
-            n_group,
-            topk_group,
-            intermediate_size,
-            local_expert_offset,
-            local_num_experts,
-            routed_scaling_factor,
-            tile_tokens_dim,
-            routing_method_type,
-            do_finalize,
-        )
+        return None
+
+        # return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
+        #     routing_logits,
+        #     routing_bias,
+        #     hidden_states,
+        #     hidden_states_scale,
+        #     gemm1_weights,
+        #     gemm1_weights_scale,
+        #     gemm2_weights,
+        #     gemm2_weights_scale,
+        #     output1_scale_scalar,
+        #     output1_scale_gate_scalar,
+        #     output2_scale_scalar,
+        #     num_experts,
+        #     top_k,
+        #     n_group,
+        #     topk_group,
+        #     intermediate_size,
+        #     local_expert_offset,
+        #     local_num_experts,
+        #     routed_scaling_factor,
+        #     tile_tokens_dim,
+        #     routing_method_type,
+        #     do_finalize,
+        # )
```
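With both kernel calls commented out, `apply` degenerates into a no-op that matches the commit message ("Runnable, incorrect output"): it executes without error, but `output` never receives fused-MoE results and the caller gets `None`. A hypothetical, self-contained reproduction of just that control flow (`apply_stub` is an illustrative name; the real method takes many more arguments):

```python
import torch

def apply_stub(output: torch.Tensor) -> None:
    # trtllm_fp4_block_scale_moe is commented out upstream, so no MoE
    # computation runs and `output` keeps whatever it already held.
    out = output        # alias of the preallocated buffer, not a copy
    output.copy_(out)   # self-copy: effectively a no-op
    return None         # the fused kernel's result is never produced

buf = torch.full((4,), 7.0)
assert apply_stub(buf) is None
assert torch.equal(buf, torch.full((4,), 7.0))  # buffer unchanged
```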
