Support of fp8_scaled_mm() on XPU #34

base: main

Conversation
* Added support for sgl_kernel.fp8_scaled_mm op
* Input in dtype fp8 e4m3 or e5m2
* Output in dtype fp32, bf16, fp8 e4m3 or fp8 e5m2

Signed-off-by: Aditya Chatterjee <[email protected]>
Force-pushed from 6ae8ca6 to f3a0a83
airMeng left a comment
Have you compared with oneDNN's FP8 scaled_mm? I think we could reuse PyTorch's effort there.
```cmake
set(FETCHCONTENT_MAKEAVAILABLE_SERIAL FALSE)
FetchContent_MakeAvailable(repo-cutlass-sycl)
file(COPY ${repo-cutlass-sycl_SOURCE_DIR}/cmake/onemkl.cmake
     DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
set(FETCHCONTENT_MAKEAVAILABLE_SERIAL TRUE)
FetchContent_MakeAvailable(repo-cutlass-sycl)
```
MKL has been disabled in the latest cutlass-sycl; you can remove these lines.
remove this file
This PR would be quite slow on the current platforms of Intel GPUs (even for CRI, I believe it would require quite a lot of changes to be performant). Are you planning to provide functional support here? @adityachatter
@mingfeima, the target is mostly functional here. Yes, CRI will have the optimal solution for any fp8 support.
@kareemshaik80 OK, I see. Please put this on a development branch, maybe named after […]. Additionally, there are a few API mismatches with sglang:
Right, this is mainly for BMG; we will evaluate performance. By the way, per-block quantization/scaling is a different API and will have a different implementation.
```cpp
float beta = 0.0f;

// Create a dummy C tensor
cutlass::device_memory::allocation<ElementC> dummy_C(M * N);
```
Avoid direct memory allocation from the SYCL runtime; use a torch factory function instead.
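For illustration, a minimal sketch of that suggestion, assuming `mat_a`, `M`, `N`, and `ElementC` are in scope as in the diff above and that `ElementC` corresponds to the chosen output dtype (bfloat16 used here purely as an example):

```cpp
// Hypothetical replacement for cutlass::device_memory::allocation<ElementC>:
// allocate the dummy C buffer through the PyTorch factory so it is owned by
// the caching allocator and freed together with the tensor.
at::Tensor dummy_C = at::empty(
    {M, N},
    at::TensorOptions().dtype(at::kBFloat16).device(mat_a.device()));
auto* dummy_C_ptr = static_cast<ElementC*>(dummy_C.data_ptr());
```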
```cpp
{static_cast<ElementA*>(mat_a.data_ptr()),
 stride_A,
 static_cast<ElementB*>(mat_b.data_ptr()),
 stride_B,
 static_cast<ElementScale*>(scales_a.data_ptr()),
 stride_SA,
 static_cast<ElementScale*>(scales_b.data_ptr()),
 stride_SB,
 nullptr,
 stride_SA,  // No zero point for A
 nullptr,
 stride_SB,  // No zero point for B
 K},  // group_size = K for per-row/col scaling
```
lint
```cpp
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
```
As above: avoid the SYCL allocation and use a torch factory function.
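A sketch for the workspace under the same assumption (raw bytes allocated through the PyTorch allocator on the input tensor's device):

```cpp
// Hypothetical torch-allocated workspace; at::kByte gives a uint8 buffer of
// exactly workspace_size bytes on the same XPU device as mat_a.
size_t workspace_size = Gemm::get_workspace_size(arguments);
at::Tensor workspace = at::empty(
    {static_cast<int64_t>(workspace_size)},
    at::TensorOptions().dtype(at::kByte).device(mat_a.device()));
auto* workspace_ptr = workspace.data_ptr<uint8_t>();
```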
```cpp
static inline std::pair<float, float> get_fp8_range(at::ScalarType dtype) {
  if (dtype == at::ScalarType::Float8_e4m3fn) {
    // E4M3FN: max = 448, min = -448
    return {-448.0f, 448.0f};
  } else {
    // Float8_e5m2
    // E5M2: max = 57344, min = -57344
    return {-57344.0f, 57344.0f};
  }
}
```
This should already be covered in torch; ATen overloads std::numeric_limits for the FP8 types.
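A sketch of what that could look like, assuming the `std::numeric_limits` specializations that c10 ships for its FP8 types (`c10::Float8_e4m3fn`, `c10::Float8_e5m2`):

```cpp
#include <ATen/ATen.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>

#include <limits>
#include <utility>

// Derive the representable range from c10's numeric_limits specializations
// instead of hard-coding 448 / 57344.
static inline std::pair<float, float> get_fp8_range(at::ScalarType dtype) {
  if (dtype == at::ScalarType::Float8_e4m3fn) {
    const float max = static_cast<float>(std::numeric_limits<c10::Float8_e4m3fn>::max());
    return {-max, max};
  }
  const float max = static_cast<float>(std::numeric_limits<c10::Float8_e5m2>::max());
  return {-max, max};
}
```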
```cpp
if (out_dtype == at::ScalarType::BFloat16) {
  using Config = Fp8GemmConfig<ElementInputFp8, cutlass::bfloat16_t>;
  Fp8GemmRunner<typename Config::Gemm, cutlass::bfloat16_t> runner;
  status = runner.run(mat_a_contig, mat_b_contig, scales_a_half, scales_b_half, out, hw_info);
```
For sglang you only need to implement bfloat16; the out data type is bfloat16 or float16.
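A sketch of the narrowed dispatch, reusing the `Fp8GemmConfig` / `Fp8GemmRunner` names from the diff above and assuming a `cutlass::half_t` instantiation exists alongside the bfloat16 one:

```cpp
// Reject anything sglang will not pass; only bf16 and fp16 outputs remain.
TORCH_CHECK(
    out_dtype == at::ScalarType::BFloat16 || out_dtype == at::ScalarType::Half,
    "fp8_scaled_mm: out_dtype must be bfloat16 or float16");

if (out_dtype == at::ScalarType::BFloat16) {
  using Config = Fp8GemmConfig<ElementInputFp8, cutlass::bfloat16_t>;
  Fp8GemmRunner<typename Config::Gemm, cutlass::bfloat16_t> runner;
  status = runner.run(mat_a_contig, mat_b_contig, scales_a_half, scales_b_half, out, hw_info);
} else {
  using Config = Fp8GemmConfig<ElementInputFp8, cutlass::half_t>;
  Fp8GemmRunner<typename Config::Gemm, cutlass::half_t> runner;
  status = runner.run(mat_a_contig, mat_b_contig, scales_a_half, scales_b_half, out, hw_info);
}
```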
```cpp
at::ScalarType intermediate_dtype;
if (is_fp8_dtype(out_dtype)) {
  intermediate_dtype = at::ScalarType::Half;
} else {
  intermediate_dtype = out_dtype;
}
```
not needed.
```cpp
// Dispatch based on input FP8 type
if (input_dtype == at::ScalarType::Float8_e4m3fn) {
  fp8_scaled_mm_impl<cutlass::float_e4m3_t>(
      mat_a, mat_b, scales_a_half, scales_b_half, intermediate_dtype, out_intermediate, hw_info);
} else {
  fp8_scaled_mm_impl<cutlass::float_e5m2_t>(
      mat_a, mat_b, scales_a_half, scales_b_half, intermediate_dtype, out_intermediate, hw_info);
}
```
Make it PyTorch-like and use the AT_DISPATCH_xxx macros. If a suitable one is not available, write one of your own; you can also define other types in it, such as acc_scalar_t and so on.
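For illustration, a minimal sketch of a hand-rolled dispatch macro in the AT_DISPATCH spirit; the macro name and the `fp8_t` / `acc_scalar_t` aliases are hypothetical, not part of the PR:

```cpp
// Maps an at::ScalarType to the corresponding cutlass FP8 type and invokes
// the lambda; fp8_t (and an accumulator alias) are visible inside the lambda
// body because the macro expands it inside the matching case block.
#define DISPATCH_FP8_INPUT_TYPES(DTYPE, NAME, ...)                    \
  [&] {                                                               \
    switch (DTYPE) {                                                  \
      case at::ScalarType::Float8_e4m3fn: {                           \
        using fp8_t = cutlass::float_e4m3_t;                          \
        using acc_scalar_t = float;                                   \
        return __VA_ARGS__();                                         \
      }                                                               \
      case at::ScalarType::Float8_e5m2: {                             \
        using fp8_t = cutlass::float_e5m2_t;                          \
        using acc_scalar_t = float;                                   \
        return __VA_ARGS__();                                         \
      }                                                               \
      default:                                                        \
        TORCH_CHECK(false, NAME, ": unsupported FP8 dtype ", DTYPE);  \
    }                                                                 \
  }()

// Usage, replacing the if/else above:
DISPATCH_FP8_INPUT_TYPES(input_dtype, "fp8_scaled_mm", [&] {
  fp8_scaled_mm_impl<fp8_t>(
      mat_a, mat_b, scales_a_half, scales_b_half, intermediate_dtype, out_intermediate, hw_info);
});
```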
```cpp
TORCH_CHECK(bias_tensor.size(0) == N, "bias must have size N");
TORCH_CHECK(bias_tensor.is_contiguous(), "bias must be contiguous");

if (is_fp8_dtype(out_dtype)) {
```
don't need this.
```
@@ -0,0 +1,124 @@
/***************************************************************************************************
```
duplicated.
| """ | ||
| Test code for sgl_kernel.fp8_scaled_mm() | ||
| Run as: | ||
| python -m pytest -v -s test_fp8_scaled_mm_xpu.py | ||
| """ |
OK. Per-channel quantization is not favored for recently released LLMs. Anyway, please provide performance data on Battlemage.
airMeng left a comment
Make sure you update the CI and benchmarks
sgl-kernel-xpu/tests/run_suite.py, line 19 (commit 1bb6c78): `TestFile("test_flash_attention.py"),`
sgl-kernel-xpu/benchmark/bench_fp8_gemm.py, line 117 (commit 1bb6c78): `lambda: sgl_scaled_mm(`
`/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py"`
Add support for op `sgl_kernel.fp8_scaled_mm()` on XPU.

Supports:

* Input in dtype `FP8 E4M3` or `FP8 E5M2`
* Output in dtype `BF16`, `FP32`, `FP8 E4M3`, `FP8 E5M2`

Run the fp8_scaled_mm test code as: `python -m pytest -v -s test_fp8_scaled_mm_xpu.py`

Tested on BMG B580: `2000 passed`

`fp8_scaled_mm` is designed for the FP8 DeepSeek inference requirement.