[cutlass] support fp8/int4/mxfp4 weights grouped gemm #88

mayuyuace merged 15 commits into vllm-project:main
Conversation
Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
Pull request overview
This PR adds support for fp8/int4/mxfp4 weight quantization formats to grouped GEMM operations, along with a new cutlass_xe_grouped_gemm operation that provides better performance than the existing implementation for bf16/fp16 data types. The implementation is based on the Intel CUTLASS SYCL-TLA repository.
Key changes:
- Added new `cutlass_xe_grouped_gemm` function supporting multiple weight quantization formats
- Implemented fp8, int4, and mxfp4 weight formats with appropriate dequantization and scaling
- Added comprehensive test coverage for all new data type combinations
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| vllm_xpu_kernels/fused_moe_interface.py | Adds Python wrapper for new cutlass_xe_grouped_gemm operation |
| tests/fused_moe/test_fused_moe.py | Adds test cases for bf16/fp16, fp8, int4, and mxfp4 weight formats |
| csrc/xpu/torch_bindings.cpp | Registers new cutlass_xe_grouped_gemm operation with PyTorch |
| csrc/xpu/ops.h | Declares function signature for cutlass_xe_grouped_gemm |
| csrc/xpu/cutlass_kernels/grouped_gemm/xe_grouped_gemm.hpp | Implements MoE grouped GEMM kernel with support for quantized weights |
| csrc/xpu/cutlass_kernels/grouped_gemm/xe_gemm.hpp | Implements core GEMM functions for standard and 4-bit quantized operations |
| csrc/xpu/cutlass_kernels/grouped_gemm/grouped_gemm_interface.hpp | Provides interface and dispatch logic for different data type combinations |
Would you refactor this one to collective style? If the current tile_scheduler and epilogue can be improved, we'd better modify them or reduce redundant logic in them.

Please also try a memory-bound input shape for the low-precision types (fp8, fp4, etc.).
I think it is unnecessary to use only the collective style. In fact, some examples in sycl-tla also do not use collective code to implement their functionality.

@Liangliang-Ma I think we can fine-tune performance first and reorganize the code into collective style afterwards.
if (is_B_int4) {                                                          \
  if (A_dtype == at::kBFloat16) {                                         \
    using scalar_t = bfloat16_t;                                          \
    MoEGEMMLauncherCallER('R', 'C', policy, scalar_t, uint8_t, scalar_t); \
Why does the 4-bit dtype require column-major weights?

Because the raw data provided by vLLM is k-major, and it is hard to convert a 4-bit k-major matrix to n-major. Imagine a u8 storing 2x 4-bit values: transposing the 8-bit [n, k / 2] matrix to [k / 2, n] amounts to transposing the 4-bit [n, k] matrix to [k / 2, n, 2], not to [k, n]. If needed, we can try to implement a real 4-bit transpose and an R-major MoE GEMM.
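The byte-level equivalence described above can be checked with a small standalone sketch (numpy; packing order here is an assumption, with the even column in the low nibble):

```python
import numpy as np

n, k = 2, 4
# Logical int4 matrix [n, k] with values 0..15.
logical = np.arange(n * k, dtype=np.uint8).reshape(n, k) % 16

# Pack pairs along k: each uint8 holds columns 2j (low nibble) and 2j+1
# (high nibble), giving a [n, k // 2] byte matrix.
packed = logical[:, 0::2] | (logical[:, 1::2] << 4)

# A byte-level transpose [n, k // 2] -> [k // 2, n] ...
t = packed.T

# ... unpacks to shape [k // 2, n, 2]: element [j, i, p] is logical[i, 2j + p],
# which is NOT the true 4-bit transpose logical.T of shape [k, n].
unpacked = np.stack([t & 0xF, t >> 4], axis=-1)
assert unpacked.shape == (k // 2, n, 2)
assert np.array_equal(unpacked[1, 0], logical[0, 2:4])
assert not np.array_equal(unpacked.reshape(k, n), logical.T)
```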
reinterpret_cast<ElementB*>(ptr_B.data_ptr()),              \
ptr_scales.has_value()                                      \
    ? reinterpret_cast<ElementS*>(ptr_scales->data_ptr())   \
    : static_cast<ElementS*>(nullptr),                      \
We may not need to cast the nullptr here. Will it report a compilation error?

Sometimes it reports an error here; I can try to remove the static_cast then.
Liangliang-Ma left a comment:

Overall LGTM. Afterwards, please modify fused_moe_interface too to enable an e2e run.
pengzhao-intel left a comment:

Good work. Please also update the latest performance numbers in the PR; let's see the e2e performance compared with the previous XeTLA perf as well.
Will be completed together with the collective-format refactor.
Based on the cutlass repo:
https://github.com/intel/sycl-tla/blob/main/examples/cute/tutorial/xe_gemm.cpp
Some tensor-casting code follows examples/12_bmg_moe_gemm_cute_interface, e.g. tCrC_final and make_moe_tensor.
bf16/fp16/fp8 require [K, N] shaped weights.
mxfp4/int4 require [N, K] shaped weights.
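A small hypothetical helper (names and dtype strings are mine, not from this PR) makes the layout rule explicit, including the byte-level packing for the 4-bit formats:

```python
def expected_weight_shape(w_dtype: str, n: int, k: int) -> tuple:
    """Hypothetical helper: expected stored weight shape per format.

    bf16/fp16/fp8 weights are stored [K, N]; int4/mxfp4 are stored
    [N, K] logically, i.e. [N, K // 2] bytes with two 4-bit values
    packed per uint8 along K.
    """
    if w_dtype in ("bf16", "fp16", "fp8_e5m2", "fp8_e4m3"):
        return (k, n)
    if w_dtype in ("int4", "mxfp4"):
        return (n, k // 2)
    raise ValueError(f"unsupported weight dtype: {w_dtype}")

# e.g. for [N, K] = [6144, 3072]:
# bf16 -> (3072, 6144); int4 -> (6144, 1536)
```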
Prefill stage, with [E, M, N, K] = [16, 8192, 5120, 8192], topk=1; durations as below:

| Config | Duration | Throughput |
|---|---|---|
| bf16 | 8.470 ms | 81.132 TFLOPS |
| fp16 | 8.484 ms | 80.998 TFLOPS |
| W fp8 e5m2, A bf16 | 8.961 ms | 76.687 TFLOPS |
| W fp8 e5m2, A fp16 | 8.839 ms | 77.745 TFLOPS |
| W fp8 e4m3, A bf16 | 9.341 ms | 73.567 TFLOPS |
| W fp8 e4m3, A fp16 | 9.128 ms | 75.284 TFLOPS |
| W int4, A bf16 | 12.298 ms | 55.878 TFLOPS |
| W int4, A fp16 | 8.845 ms | 77.693 TFLOPS |
| W fp4, A bf16 | 13.516 ms | 50.843 TFLOPS |
| W fp4, A fp16 | 10.848 ms | 63.347 TFLOPS |
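The prefill throughput figures can be reproduced from the shape, assuming (my reading, not stated explicitly above) that M = 8192 is the total token count across all 16 experts, so one grouped-GEMM pass performs 2*M*N*K FLOPs in total:

```python
# Prefill shape from the table above.
M, N, K = 8192, 5120, 8192
flops = 2 * M * N * K  # one grouped-GEMM pass over all experts

def tflops(time_ms: float) -> float:
    return flops / (time_ms * 1e-3) / 1e12

# bf16 at 8.470 ms -> ~81.13 TFLOPS, matching the reported 81.132
# fp8 e4m3 / bf16 at 9.341 ms -> ~73.57 TFLOPS, matching 73.567
```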
Decode stage:
Note that the row count of an individual GEMM may be 0, because num_rows_per_expert is generated by randn; in this run, 1 of the 16 gemm_m values is 0, so the real bandwidth is about 6% lower than the numbers below.
With [E, M, N, K] = [16, 32, 5120, 8192], topk=1; durations as below:

| Config | Duration | Bandwidth |
|---|---|---|
| bf16 | 2.8157 ms | 476.978 GB/s |
| fp16 | 2.8888 ms | 464.909 GB/s |
| W fp8 e5m2, A bf16 | 1.4123 ms | 475.777 GB/s |
| W fp8 e5m2, A fp16 | 1.4036 ms | 478.726 GB/s |
| W fp8 e4m3, A bf16 | 1.4136 ms | 475.339 GB/s |
| W fp8 e4m3, A fp16 | 1.4202 ms | 473.130 GB/s |
| W int4, A bf16 | 0.8267 ms | 406.914 GB/s |
| W int4, A fp16 | 0.7755 ms | 433.779 GB/s |
| W fp4, A bf16 | 0.8221 ms | 409.191 GB/s |
| W fp4, A fp16 | 0.7070 ms | 475.808 GB/s |
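Since decode is memory-bound, the bandwidth figures are dominated by reading the expert weights once. A sketch of the estimate (weight traffic only, so it lands slightly under the reported figures, which presumably also count activation and output traffic):

```python
# Decode shape from the table above.
E, N, K = 16, 5120, 8192

def weight_bw_gbs(bytes_per_elem: float, time_ms: float) -> float:
    """Lower-bound bandwidth estimate from expert-weight reads only."""
    return E * N * K * bytes_per_elem / (time_ms * 1e-3) / 1e9

# bf16 (2 B/elem)   at 2.8157 ms -> ~476.7 GB/s (reported: 476.978)
# fp8  (1 B/elem)   at 1.4123 ms -> ~475.2 GB/s (reported: 475.777)
# int4 (0.5 B/elem) at 0.8267 ms -> ~405.9 GB/s (reported: 406.914)
```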
BTW, with [E, M, N, K] = [32, 32, 3072 * 2, 3072], topk=4 (the GPT-OSS shape):
W int4, A fp16 has the same duration as XeTLA: 737.9 us vs 738.645 us (XeTLA).
W fp4, A bf16 performs better than XeTLA: 822.1 us vs 1401.354 us (XeTLA).
After this PR, I will continue to optimize the grouped GEMM kernel, and the next PR will refactor these files into the collective format.