Conversation


@wenscarl wenscarl commented Oct 20, 2025

co-authored by @leejnau

cc. @bnellnm

Purpose

For the TP case, the nvfp4 quantization is fused into the flashinfer_cutlass_moe call. This fixes the accuracy issue for the nvidia/DeepSeek-R1-0528-FP4-v2 model.
For the DP case, the fix should rely on #26135.
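
Purely for illustration (this is not the PR's code; every name below except flashinfer_cutlass_moe_fp4 and a1_gscale is made up), a minimal self-contained sketch of the difference between quantizing activations up front and letting the fused CUTLASS kernel quantize them internally:

```python
import torch

BLOCK = 16  # nvfp4 scales activations in 16-element blocks

def toy_nvfp4_quant(x: torch.Tensor):
    """Toy stand-in for nvfp4 activation quantization: per-block amax scaling
    into the e2m1 range (max magnitude 6.0). Real kernels additionally fold in
    a global scale, round to 4-bit e2m1, and pack block scales as e4m3; all of
    that is omitted here."""
    blocks = x.float().reshape(-1, BLOCK)
    block_scale = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 6.0
    q = (blocks / block_scale).clamp(-6.0, 6.0)
    return q.reshape(x.shape), block_scale

x = torch.randn(8, 64, dtype=torch.bfloat16)  # hidden states entering the MoE layer

# Unfused path: quantize on the Python side, then hand packed tensors to the kernel.
x_q, x_scale = toy_nvfp4_quant(x)
# out = some_moe_kernel(x_q, x_scale, ...)   # hypothetical kernel call

# Fused path (the TP case in this PR): pass the bf16 activations plus the global
# input scale straight to the fused call and let the CUTLASS kernel quantize them.
# out = flashinfer_cutlass_moe_fp4(x, ..., a1_gscale=...)   # args elided; actual signature may differ
```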

Test Plan

VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND="throughput" \
/home/shuw/.local/bin/lm_eval --model vllm --model_args pretrained=nvidia/DeepSeek-R1-0528-FP4-v2,data_parallel_size=1,enable_expert_parallel=True,tensor_parallel_size=4,enforce_eager=True,max_model_len=2048 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Test Result

Previously:

| Tasks | Version | Filter           | n-shot | Metric      | Value  |   | Stderr |
|-------|--------:|------------------|-------:|-------------|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | 0.0106 | ± | 0.0028 |
|       |         | strict-match     |      5 | exact_match | 0.0000 | ± | 0.0000 |

With this PR:

| Tasks | Version | Filter           | n-shot | Metric      | Value  |   | Stderr |
|-------|--------:|------------------|-------:|-------------|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | 0.9522 | ± | 0.0059 |
|       |         | strict-match     |      5 | exact_match | 0.9469 | ± | 0.0062 |


@wenscarl wenscarl marked this pull request as ready for review October 20, 2025 20:26
@wenscarl wenscarl force-pushed the flashinfer_cutlass_moe_tp branch from 1057c27 to 4c37b7d on October 20, 2025 20:33

@bnellnm bnellnm left a comment


LGTM. Alternatively, you could prevent modular kernels from being created in this case and fall through to the direct call to flashinfer_cutlass_moe_fp4?

@wenscarl (Contributor Author)

> LGTM. Alternatively, you could prevent modular kernels from being created in this case and fall through to the direct call to flashinfer_cutlass_moe_fp4?

Yes. But do you prefer to have modular kernels anyways?
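
For context, a self-contained toy of the fall-through pattern being discussed here; the class and attribute names below are illustrative, not the actual vLLM code:

```python
from typing import Callable, Optional

class ToyMoEMethod:
    """Illustrative only: prefer a modular kernel when one was built for this
    layer (e.g. because all2all communication is needed); otherwise fall
    through to the direct fused call."""

    def __init__(self, modular_kernel: Optional[Callable[[str], str]] = None):
        self.fused_experts = modular_kernel  # None => no modular kernel was created

    def apply(self, hidden_states: str) -> str:
        if self.fused_experts is not None:
            # Modular-kernel path: prepare/finalize handles quantization.
            return self.fused_experts(hidden_states)
        # Direct path: stands in for flashinfer_cutlass_moe_fp4(...)
        return f"flashinfer_cutlass_moe_fp4({hidden_states})"

print(ToyMoEMethod().apply("x"))                            # direct call path
print(ToyMoEMethod(lambda h: f"modular({h})").apply("x"))   # modular kernel path
```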


    assert self.moe_quant_config is not None
    return flashinfer_cutlass_moe_fp4(
@mgoin (Member)


Can you update for compressed-tensors too?


@wenscarl @mgoin is my understanding correct that, going forward, direct calling into FlashInfer is going to be deprecated in favour of modular kernels, for all cases (e.g., regardless of EP/TP/DP choices)?

Contributor


> @wenscarl @mgoin is my understanding correct that, going forward, direct calling into FlashInfer is going to be deprecated in favour of modular kernels, for all cases (e.g., regardless of EP/TP/DP choices)?

There's no plan to force all MoE kernels to be called via modular kernels.


But by deleting this elif clause (and, per @mgoin's suggestion, applying this change to compressed-tensors), doesn't it force the FlashInfer CUTLASS implementation to go through the modular kernels?

I'm just trying to understand if this is the plan for all cases that use FlashInfer, regardless of distributed strategies, or whether self.flashinfer_moe_backend is FlashinferMoeBackend.TENSORRT_LLM or FlashinferMoeBackend.CUTLASS.

Contributor


I don't see any reason to use the modular kernels for cases that aren't using some kind of all2all communication. In this particular case I think @wenscarl and @leejnau figured out that this was dead code because the CUTLASS case always created a modular kernel. I'm not sure if the same holds true for compressed_tensors.

@wenscarl (Contributor Author)


@bnellnm this PR is updated with an additional quant_dtype: nvfp4_skip_quantization.
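
As a rough illustration only (the real FusedMoEQuantConfig fields and plumbing in the PR may differ), the idea of a skip-quantization sentinel is that prepare/finalize can see that the fused kernel will quantize the activations itself:

```python
from dataclasses import dataclass

@dataclass
class ToyMoEQuantConfig:
    # e.g. "nvfp4", "nvfp4_skip_quantization", or None; illustrative names only,
    # not the actual FusedMoEQuantConfig fields.
    quant_dtype: str | None = None

    def prepare_should_quantize_activations(self) -> bool:
        # With the sentinel, the fused CUTLASS kernel quantizes internally,
        # so the prepare step hands it high-precision activations unchanged.
        return self.quant_dtype == "nvfp4"

print(ToyMoEQuantConfig("nvfp4").prepare_should_quantize_activations())                    # True
print(ToyMoEQuantConfig("nvfp4_skip_quantization").prepare_should_quantize_activations())  # False
```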

@wenscarl (Contributor Author)


> I'm just trying to understand if this is the plan for all cases that use FlashInfer

I vote for that, since the FlashInfer CUTLASS MoE is at least a better option than the normal cutlass_moe, and the TRTLLM MoE can sometimes win even over that.

@bnellnm bnellnm commented Oct 20, 2025

> LGTM. Alternatively, you could prevent modular kernels from being created in this case and fall through to the direct call to flashinfer_cutlass_moe_fp4?
>
> Yes. But do you prefer to have modular kernels anyways?

I've no preference for this particular case since there's no communication going on. My only concern would be that the FusedMoEQuantConfig doesn't match up with what the prepare/finalize class is doing. This could probably be made consistent by making sure the quant_dtype is None, e.g. something like this

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        if (
            self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        # New case here for CUTLASS
        if (
            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            and layer.tp_size > 1
        ):
            return FusedMoEQuantConfig.make(
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
            )

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

@mgoin mgoin added the bug, quantization, and ready labels Oct 21, 2025
@wenscarl wenscarl requested review from bnellnm and mgoin October 22, 2025 03:40
@wenscarl wenscarl force-pushed the flashinfer_cutlass_moe_tp branch 2 times, most recently from c78959c to 4c37b7d on October 23, 2025 03:13