
[cutlass] support fp8/int4/mxfp4 weights grouped gemm#88

Merged
mayuyuace merged 15 commits into vllm-project:main from mayuyuace:qiming/wint4a16
Dec 10, 2025

Conversation


@mayuyuace mayuyuace commented Dec 1, 2025

Based on the cutlass repo:
https://github.com/intel/sycl-tla/blob/main/examples/cute/tutorial/xe_gemm.cpp
Some tensor-casting code refers to examples/12_bmg_moe_gemm_cute_interface (e.g. tCrC_final and make_moe_tensor).

bf16/fp16/fp8 require [K, N] shaped weights.
mxfp4/int4 require [N, K] shaped weights.
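As a hedged sketch of the two layout rules above (the helper name is hypothetical, not part of the PR), the per-expert weight shape for each format, with int4/mxfp4 packing two 4-bit values per byte along K:

```python
# Hypothetical helper illustrating the layout rules stated above.
def expected_weight_shape(fmt: str, N: int, K: int) -> tuple:
    """Per-expert weight shape for each supported weight format."""
    if fmt in ("bf16", "fp16", "fp8"):
        return (K, N)            # [K, N] shaped weights
    if fmt in ("int4", "mxfp4"):
        return (N, K // 2)       # [N, K] 4-bit values, packed two per uint8
    raise ValueError(f"unsupported format: {fmt}")

print(expected_weight_shape("fp8", 5120, 8192))   # (8192, 5120)
print(expected_weight_shape("int4", 5120, 8192))  # (5120, 4096)
```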

Prefill stage:
With [E, M, N, K] = [16, 8192, 5120, 8192], topk=1, durations as below:
bf16: 8.470ms, 81.132 TFLOPS
fp16: 8.484ms, 80.998 TFLOPS
W fp8 e5m2 A bf16: 8.961ms, 76.687 TFLOPS
W fp8 e5m2 A fp16: 8.839ms, 77.745 TFLOPS
W fp8 e4m3 A bf16: 9.341ms, 73.567 TFLOPS
W fp8 e4m3 A fp16: 9.128ms, 75.284 TFLOPS
W int4 A bf16: 12.298ms, 55.878 TFLOPS
W int4 A fp16: 8.845ms, 77.693 TFLOPS
W fp4 A bf16: 13.516ms, 50.843 TFLOPS
W fp4 A fp16: 10.848ms, 63.347 TFLOPS
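The TFLOPS figures above are consistent with treating M = 8192 as the total token count across all experts, so total FLOPs = 2 * M * N * K. A minimal check (a sketch under that assumption, not code from the PR):

```python
# Sanity-check the prefill TFLOPS numbers, assuming M counts total tokens
# across experts (FLOPs = 2 * M * N * K: one multiply and one add per
# (m, n, k) triple).
M, N, K = 8192, 5120, 8192
flops = 2 * M * N * K

def tflops(ms: float) -> float:
    return flops / (ms * 1e-3) / 1e12

print(round(tflops(8.470), 1))  # bf16 row: 81.1
```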

Decode stage:
Note that the row count of any single GEMM may be 0, because num_rows_per_expert is generated by randn; in this run, 1 of the 16 GEMMs has gemm_m = 0, so the real bandwidth is about 6% lower than the figures below.
With [E, M, N, K] = [16, 32, 5120, 8192], topk=1, durations as below:
bf16: 2.8157ms, 476.978 GB/s
fp16: 2.8888ms, 464.909 GB/s
W fp8 e5m2 A bf16: 1.4123ms, 475.777 GB/s
W fp8 e5m2 A fp16: 1.4036ms, 478.726 GB/s
W fp8 e4m3 A bf16: 1.4136ms, 475.339 GB/s
W fp8 e4m3 A fp16: 1.4202ms, 473.130 GB/s
W int4 A bf16: 0.8267ms, 406.914 GB/s
W int4 A fp16: 0.7755ms, 433.779 GB/s
W fp4 A bf16: 0.8221ms, 409.191 GB/s
W fp4 A fp16: 0.7070ms, 475.808 GB/s
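The GB/s figures above are approximately weight traffic divided by time (at this small M, activation and scale bytes contribute little). A hedged sketch, counting all E experts' weights, which is why the real bandwidth drops ~6% when one expert has zero rows:

```python
# Approximate the decode bandwidth numbers from weight bytes alone.
# Assumption: the reported GB/s counts all E experts' weights; activation
# and scale traffic is ignored, so results are within ~1% of the table.
E, N, K = 16, 5120, 8192

def approx_gbps(ms: float, bytes_per_weight: float) -> float:
    weight_bytes = E * N * K * bytes_per_weight
    return weight_bytes / (ms * 1e-3) / 1e9

print(round(approx_gbps(2.8157, 2)))    # bf16: 477
print(round(approx_gbps(1.4123, 1)))    # fp8:  475
print(round(approx_gbps(0.8267, 0.5)))  # int4: 406
```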

BTW, with [E, M, N, K] = [32, 32, 3072 * 2, 3072], topk=4 (GPT-OSS shape):
W int4 A fp16 matches XeTLA's duration: 737.9 us vs 738.645 us (XeTLA)
W fp4 A bf16 outperforms XeTLA: 822.1 us vs 1401.354 us (XeTLA)

After this PR, I will continue optimizing the grouped GEMM kernel.
The next PR will refactor these files into collective format.

Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
Copilot AI review requested due to automatic review settings December 1, 2025 08:33

Copilot AI left a comment


Pull request overview

This PR adds support for fp8/int4/mxfp4 weight quantization formats to grouped GEMM operations, along with a new cutlass_xe_grouped_gemm operation that provides better performance than the existing implementation for bf16/fp16 data types. The implementation is based on the Intel CUTLASS SYCL-TLA repository.

Key changes:

  • Added new cutlass_xe_grouped_gemm function supporting multiple weight quantization formats
  • Implemented fp8, int4, and mxfp4 weight formats with appropriate dequantization and scaling
  • Added comprehensive test coverage for all new data type combinations

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Changed files:
  • vllm_xpu_kernels/fused_moe_interface.py: Adds the Python wrapper for the new cutlass_xe_grouped_gemm operation
  • tests/fused_moe/test_fused_moe.py: Adds test cases for bf16/fp16, fp8, int4, and mxfp4 weight formats
  • csrc/xpu/torch_bindings.cpp: Registers the new cutlass_xe_grouped_gemm operation with PyTorch
  • csrc/xpu/ops.h: Declares the function signature for cutlass_xe_grouped_gemm
  • csrc/xpu/cutlass_kernels/grouped_gemm/xe_grouped_gemm.hpp: Implements the MoE grouped GEMM kernel with support for quantized weights
  • csrc/xpu/cutlass_kernels/grouped_gemm/xe_gemm.hpp: Implements core GEMM functions for standard and 4-bit quantized operations
  • csrc/xpu/cutlass_kernels/grouped_gemm/grouped_gemm_interface.hpp: Provides the interface and dispatch logic for different data type combinations


@Liangliang-Ma

Would you refactor this one to collective style? If the current tile_scheduler and epilogue can be improved, we'd better modify them or reduce redundant logic.

@pengzhao-intel

Please also try memory-bound input shapes for low precision (fp8, fp4, etc.).


mayuyuace commented Dec 2, 2025

> Would you refactor this one to collective style? If the current tile_scheduler and epilogue can be improved, we'd better modify them or reduce redundant logic.

I think it is unnecessary to use only the collective style. In fact, some examples in sycl-tla also do not use collective code.
Collective is just a template for users who don't want to focus on the details; I think using the cute interface alone is more convenient.

@pengzhao-intel

> Would you refactor this one to collective style? If the current tile_scheduler and epilogue can be improved, we'd better modify them or reduce redundant logic.

> I think it is unnecessary to use only the collective style. In fact, some examples in sycl-tla also do not use collective code. Collective is just a template for users who don't want to focus on the details; I think using the cute interface alone is more convenient.

@Liangliang-Ma I think we can fine-tune performance first and organize the code into collective style afterwards.

    if (is_B_int4) {                                                          \
      if (A_dtype == at::kBFloat16) {                                         \
        using scalar_t = bfloat16_t;                                          \
        MoEGEMMLauncherCallER('R', 'C', policy, scalar_t, uint8_t, scalar_t); \
A collaborator commented:

Why does the 4-bit dtype require column-major weights?

The author replied:

Because the raw data provided by vllm is k-major, and it is hard to convert a 4-bit k-major matrix to n-major:
a u8 stores two 4-bit values, so transposing the 8-bit [n, k / 2] matrix to [k / 2, n] is equivalent to transposing the 4-bit [n, k] matrix to [k / 2, n, 2], not to [k, n].
If needed, we can try to implement a real 4-bit transpose and an R-major MoE GEMM.
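The equivalence described above can be checked with a small NumPy sketch (the low-nibble-first packing order is an assumption for illustration; the kernel's actual nibble order may differ):

```python
import numpy as np

n, k = 4, 8
rng = np.random.default_rng(0)
w4 = rng.integers(0, 16, size=(n, k), dtype=np.uint8)  # 4-bit values, k-major

# Pack two adjacent k-values into one byte (low nibble holds the even index).
packed = (w4[:, 0::2] | (w4[:, 1::2] << 4)).astype(np.uint8)  # [n, k//2]

# A plain byte-level transpose of the packed matrix...
packed_t = packed.T  # [k//2, n]

# ...unpacks to shape [k//2, n, 2]: each k-pair stays glued together.
unpacked = np.stack([packed_t & 0xF, packed_t >> 4], axis=-1)

# That equals the 4-bit [n, k] matrix rearranged to [k//2, n, 2],
# not a true elementwise transpose to [k, n].
ref = w4.reshape(n, k // 2, 2).transpose(1, 0, 2)
assert np.array_equal(unpacked, ref)
print("byte transpose == 4-bit [n, k] -> [k//2, n, 2]")
```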

    reinterpret_cast<ElementB*>(ptr_B.data_ptr()),            \
    ptr_scales.has_value()                                    \
        ? reinterpret_cast<ElementS*>(ptr_scales->data_ptr()) \
        : static_cast<ElementS*>(nullptr),                    \
@Liangliang-Ma commented Dec 9, 2025:

We may not need to cast the nullptr here. Will it report a compilation error?

The author replied:

Sometimes it reports an error here; I can try to remove the static_cast.

@Liangliang-Ma left a comment:

Overall LGTM. Afterwards, please also modify fused_moe_interface to enable the e2e run.

@pengzhao-intel left a comment:

Good work. Please also update the latest performance numbers in the PR, and let's look at the e2e performance compared with the previous XeTLA results as well.

@mayuyuace mayuyuace merged commit 29c5ee0 into vllm-project:main Dec 10, 2025
4 checks passed
@mayuyuace

> Overall LGTM. Afterwards, please also modify fused_moe_interface to enable the e2e run.

Will be completed together with collective format refactor.
