@@ -10,8 +10,8 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import extract_required_args
-from vllm.utils.flashinfer import (has_flashinfer_trtllm_fused_moe,
-                                   trtllm_fp4_block_scale_moe)
+from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
+# trtllm_fp4_block_scale_moe import disabled; its call sites below are stubbed out
 
 logger = init_logger(__name__)
 
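This hunk drops `trtllm_fp4_block_scale_moe` from the import by commenting it out. A minimal sketch of an alternative that keeps the module importable without a dangling comment, assuming `vllm.utils.flashinfer` may or may not export the symbol; the `_resolve_fp4_moe_kernel` helper name is hypothetical, not part of vLLM:

    # Sketch only: lazy, guarded resolution of the FlashInfer kernel instead of
    # a commented-out import. `_resolve_fp4_moe_kernel` is a made-up helper;
    # `trtllm_fp4_block_scale_moe` is the symbol removed in the hunk above.
    from typing import Callable, Optional

    def _resolve_fp4_moe_kernel() -> Optional[Callable]:
        try:
            from vllm.utils.flashinfer import trtllm_fp4_block_scale_moe
        except ImportError:
            return None  # FlashInfer absent or too old: caller must fall back
        return trtllm_fp4_block_scale_moe

With this pattern the import hunk above collapses to a one-line change, and the disabled path is decided at call time rather than by editing the import list.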
@@ -180,47 +180,49 @@ def apply(
             w2_scale.view(torch.int32),
             g2_alphas,
         ]
-        out = trtllm_fp4_block_scale_moe(
-            routing_logits,
-            routing_bias,
-            hidden_states,
-            topk_ids.to(torch.int),
-            topk_weights,
-            # FlashInfer API requires weight to be long for nvfp4
-            w1.view(torch.long),
-            w2.view(torch.long),
-            output_dtype=out_dtype,
-            quant_scales=quant_scales,
-            input_sf=a1q_scale,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=self.ep_size,
-            ep_rank=self.ep_rank,
-            output=output,
-        )
+        # out = trtllm_fp4_block_scale_moe(
+        #     routing_logits,
+        #     routing_bias,
+        #     hidden_states,
+        #     topk_ids.to(torch.int),
+        #     topk_weights,
+        #     # FlashInfer API requires weight to be long for nvfp4
+        #     w1.view(torch.long),
+        #     w2.view(torch.long),
+        #     output_dtype=out_dtype,
+        #     quant_scales=quant_scales,
+        #     input_sf=a1q_scale,
+        #     tp_size=self.tp_size,
+        #     tp_rank=self.tp_rank,
+        #     ep_size=self.ep_size,
+        #     ep_rank=self.ep_rank,
+        #     output=output,
+        # )
+        out = output
         output.copy_(out)
-
-        return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
-            routing_logits,
-            routing_bias,
-            hidden_states,
-            hidden_states_scale,
-            gemm1_weights,
-            gemm1_weights_scale,
-            gemm2_weights,
-            gemm2_weights_scale,
-            output1_scale_scalar,
-            output1_scale_gate_scalar,
-            output2_scale_scalar,
-            num_experts,
-            top_k,
-            n_group,
-            topk_group,
-            intermediate_size,
-            local_expert_offset,
-            local_num_experts,
-            routed_scaling_factor,
-            tile_tokens_dim,
-            routing_method_type,
-            do_finalize,
-        )
+        return None
+
+        # return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
+        #     routing_logits,
+        #     routing_bias,
+        #     hidden_states,
+        #     hidden_states_scale,
+        #     gemm1_weights,
+        #     gemm1_weights_scale,
+        #     gemm2_weights,
+        #     gemm2_weights_scale,
+        #     output1_scale_scalar,
+        #     output1_scale_gate_scalar,
+        #     output2_scale_scalar,
+        #     num_experts,
+        #     top_k,
+        #     n_group,
+        #     topk_group,
+        #     intermediate_size,
+        #     local_expert_offset,
+        #     local_num_experts,
+        #     routed_scaling_factor,
+        #     tile_tokens_dim,
+        #     routing_method_type,
+        #     do_finalize,
+        # )
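After this hunk, `apply` copies `output` onto itself and returns `None`, so the FlashInfer TRT-LLM FP4 MoE path is fully stubbed out while the call sites are preserved as comments. A minimal sketch of expressing the same on/off switch behind an environment flag instead of comment blocks; the `VLLM_DISABLE_TRTLLM_FP4_MOE` variable is hypothetical, not a real vLLM knob:

    # Sketch only: gate the kernel call on a (hypothetical) env flag rather
    # than commenting it out, so the change is a few lines and easy to revert.
    import os

    _FP4_MOE_DISABLED = os.environ.get("VLLM_DISABLE_TRTLLM_FP4_MOE", "0") == "1"

    # Inside apply(), in place of the commented-out block above:
    #     if _FP4_MOE_DISABLED:
    #         return None  # leave `output` untouched, as the stub does
    #     trtllm_fp4_block_scale_moe(...)  # kwargs exactly as in the old hunk,
    #     return None                      # writing in place via output=output

This keeps the diff reviewable and makes the disabled state visible at runtime instead of only in the source.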