[FusedKernel] add layer norm quant fusion kernel #180
jikunshang wants to merge 4 commits into vllm-project:main
Conversation
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Pull request overview
Adds XPU/SYCL fused RMSNorm + static FP8 quantization ops and corresponding unit tests, wiring the new kernels into the C++ extension build and Torch op registry.
Changes:
- Registers two new Torch ops: `rms_norm_static_fp8_quant` and `fused_add_rms_norm_static_fp8_quant`.
- Adds unit tests validating correctness, strided inputs, and `opcheck` compliance.
- Updates the extension build sources to include a new `layernorm_quant.cpp` compilation unit.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `tests/test_layernorm_quant.py` | New test suite for the fused RMSNorm + static FP8 quant ops (incl. opcheck + strided input coverage). |
| `csrc/torch_bindings.cpp` | Registers the two new ops in the `_C` Torch library for XPU. |
| `csrc/ops.h` | Declares the new kernel entry points. |
| `CMakeLists.txt` | Adds `csrc/layernorm_quant.cpp` to the `_C` extension sources. |
```cpp
      "rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
      "Tensor! weight, Tensor! scale, float epsilon) -> ()");
  ops.impl(
      "rms_norm_static_fp8_quant", torch::kXPU, &rms_norm_static_fp8_quant);

  // Fused Add + RMS Normalization + static FP8 quantization.
  ops.def(
      "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
      "Tensor! residual, Tensor! weight, Tensor! scale, float epsilon) -> ()");
```
The op schema marks input, weight, and scale as mutable (Tensor!). Unless the kernel actually mutates these tensors, they should be registered as immutable (Tensor) to avoid incorrect aliasing/mutation semantics (and to align with nearby FP8 quant ops like static_scaled_fp8_quant(Tensor input, Tensor scale, ...)). Consider also updating the C++ signature to take const Tensor& for read-only args.
Suggested change:
```diff
-      "rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
-      "Tensor! weight, Tensor! scale, float epsilon) -> ()");
+      "rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
+      "Tensor weight, Tensor scale, float epsilon) -> ()");
   ops.impl(
       "rms_norm_static_fp8_quant", torch::kXPU, &rms_norm_static_fp8_quant);
   // Fused Add + RMS Normalization + static FP8 quantization.
   ops.def(
-      "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
-      "Tensor! residual, Tensor! weight, Tensor! scale, float epsilon) -> ()");
+      "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
+      "Tensor residual, Tensor weight, Tensor scale, float epsilon) -> ()");
```
```cpp
  ops.def(
      "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
      "Tensor! residual, Tensor! weight, Tensor! scale, float epsilon) -> ()");
  ops.impl(
      "fused_add_rms_norm_static_fp8_quant",
      torch::kXPU,
      &fused_add_rms_norm_static_fp8_quant);
```
Similarly here, input, weight, and scale are registered as mutable (Tensor!). If only residual (and result) are written, the schema should reflect that (e.g., Tensor input, Tensor! residual, Tensor weight, Tensor scale, ...) so functionalization/opcheck and future compiler passes see accurate mutation behavior.
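The mutation contract described above can be sketched as a NumPy reference (a hypothetical helper modeled on the matching CUDA op in vllm, not this PR's code; the ±448 e4m3 clamp range is an assumption): only `residual` among the inputs is updated in place.

```python
import numpy as np

FP8_MAX = 448.0  # assumed e4m3 clamp range

def fused_add_rms_norm_static_fp8_quant_ref(input_, residual, weight, scale, epsilon):
    residual += input_  # in-place residual update: the only mutated input
    var = np.mean(residual.astype(np.float32) ** 2, axis=-1, keepdims=True)
    y = residual * (1.0 / np.sqrt(var + epsilon)) * weight
    # input_, weight, and scale are only read; the result is returned
    return np.clip(y / scale, -FP8_MAX, FP8_MAX)

rng = np.random.default_rng(1)
inp = rng.standard_normal((2, 8)).astype(np.float32)
res = rng.standard_normal((2, 8)).astype(np.float32)
expected_res = inp + res  # what residual should hold after the call
out = fused_add_rms_norm_static_fp8_quant_ref(inp, res, np.ones(8, np.float32), 1.0, 1e-6)
```

This is why a schema like `Tensor input, Tensor! residual, Tensor weight, Tensor scale` would describe the behavior accurately.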
```cpp
    torch::Tensor& input,
    torch::Tensor& weight,
    torch::Tensor& scale,
    double epsilon);

void fused_add_rms_norm_static_fp8_quant(
    torch::Tensor& out,
    torch::Tensor& input,
    torch::Tensor& residual,
    torch::Tensor& weight,
    torch::Tensor& scale,
```
These new APIs take all tensors by non-const reference (torch::Tensor&), which implies they may be mutated. If input, weight, and scale are read-only, prefer const torch::Tensor& for those parameters (matching existing quantization APIs like static_scaled_fp8_quant below in this header) and keep the mutability consistent with the Torch schema.
Suggested change:
```diff
-    torch::Tensor& input,
-    torch::Tensor& weight,
-    torch::Tensor& scale,
+    const torch::Tensor& input,
+    const torch::Tensor& weight,
+    const torch::Tensor& scale,
     double epsilon);
 void fused_add_rms_norm_static_fp8_quant(
     torch::Tensor& out,
-    torch::Tensor& input,
-    torch::Tensor& residual,
-    torch::Tensor& weight,
-    torch::Tensor& scale,
+    const torch::Tensor& input,
+    const torch::Tensor& residual,
+    const torch::Tensor& weight,
+    const torch::Tensor& scale,
     double epsilon);
```
```cmake
set(VLLM_EXT_SRC
  "csrc/cache.cpp"
  "csrc/layernorm.cpp"
  "csrc/layernorm_quant.cpp"
```
CMakeLists.txt adds csrc/layernorm_quant.cpp to VLLM_EXT_SRC, but that file does not exist in the repository (and rms_norm_static_fp8_quant / fused_add_rms_norm_static_fp8_quant have no definitions elsewhere under csrc/). This will break the build/link step; add the missing source file or remove the entry until the implementation is included.
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
```python
NUM_TOKENS = [1, 7, 83, 256]
HIDDEN_SIZES = [8, 64, 768, 5120, 8192]
SEEDS = [0]
EPSILONS = [1e-5, 1e-6]
```
EPSILONS is defined but not used anywhere in this test module. Either parameterize tests over it or remove it to keep the test inputs clear.
Suggested change:
```diff
-EPSILONS = [1e-5, 1e-6]
```
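If the intent was to keep the constant, sweeping it is meaningful coverage: epsilon visibly changes the normalized output whenever the row magnitude is small. A minimal sketch (plain Python; in the test module this would become a `@pytest.mark.parametrize("epsilon", EPSILONS)` on the existing tests):

```python
import numpy as np

EPSILONS = [1e-5, 1e-6]

def rms_norm(x, eps):
    # reference RMSNorm without the weight multiply
    var = np.mean(x ** 2, axis=-1, keepdims=True)
    return x / np.sqrt(var + eps)

# For a small-magnitude row (per-row variance ~1e-6), eps dominates the
# denominator, so the two EPSILONS values give clearly different outputs.
x = np.full((1, 8), 1e-3, dtype=np.float64)
outs = [rms_norm(x, eps) for eps in EPSILONS]
```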
```cpp
  bool can_vec =
      ((addr_in & (WIDTH - 1)) == 0) && ((hidden_size & (VEC_SIZE - 1)) == 0);

  if (can_vec) {
    int64_t const num_vec_elems = hidden_size / VEC_SIZE;
    auto const* vec_in = reinterpret_cast<const vec_t*>(input_row);
    for (int i = item_ct1.get_local_id(2); i < num_vec_elems;
         i += item_ct1.get_local_range(2)) {
      vec_t tmp = vec_in[i];
      for (int j = 0; j < VEC_SIZE; ++j) {
        float x = static_cast<float>(tmp.val[j]);
        variance += x * x;
      }
    }
  } else {
    for (int i = item_ct1.get_local_id(2); i < hidden_size;
         i += item_ct1.get_local_range(2)) {
      float x = static_cast<float>(input_row[i]);
      variance += x * x;
    }
  }

  // ---------- Phase 2: work-group reduction ----------
  variance = sycl::reduce_over_group(
      sycl::ext::oneapi::this_work_item::get_work_group<3>(),
      variance,
      sycl::plus<>());
  if (item_ct1.get_local_id(2) == 0) {
    *s_variance_ptr = sycl::rsqrt(variance / hidden_size + epsilon);
  }
  item_ct1.barrier(sycl::access::fence_space::local_space);

  // ---------- Phase 3: normalize + quantize ----------
  // invert scale to avoid division
  float const scale_inv = 1.0f / (*scale);
  fp8::ConvertWithScaleOp<true, fp8_type> convert_op{scale_inv};

  if (can_vec) {
    auto* v_in = reinterpret_cast<const vec_t*>(input_row);
    auto* v_w = reinterpret_cast<const vec_t*>(weight);
    int64_t const num_vec_elems = hidden_size / VEC_SIZE;
    float s_var = *s_variance_ptr;
```
can_vec only checks alignment of input_row, but Phase 3 unconditionally reinterprets weight as aligned_vec when can_vec is true. If weight is not sufficiently aligned, this reinterpret_cast path can perform misaligned vector loads (undefined behavior). Consider including weight (and optionally out) alignment in the vectorization predicate, or falling back to scalar loads when weight is misaligned.
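The extended predicate the comment asks for can be modeled in a few lines (a NumPy sketch with a hypothetical `can_vectorize` helper, not the kernel's actual code): every pointer that will be reinterpreted as a vector type must satisfy the alignment check, not just `input_row`.

```python
import numpy as np

WIDTH = 16     # assumed vector width in bytes
VEC_SIZE = 4   # assumed elements per vector

def can_vectorize(hidden_size, *arrays):
    # All candidate pointers (input row, weight, optionally out) must be
    # WIDTH-aligned, and the row length must be a multiple of VEC_SIZE.
    ptrs_ok = all(a.ctypes.data % WIDTH == 0 for a in arrays)
    return ptrs_ok and hidden_size % VEC_SIZE == 0

# Build one 16-byte-aligned and one deliberately misaligned float32 view.
buf = np.zeros(96, dtype=np.uint8)
off = (-buf.ctypes.data) % WIDTH
aligned = buf[off:off + 64].view(np.float32)         # address % 16 == 0
misaligned = buf[off + 4:off + 68].view(np.float32)  # address % 16 == 4
```

A misaligned `weight` then correctly forces the scalar fallback instead of a misaligned vector load.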
```cpp
      for (int j = 0; j < VEC_SIZE; j++) {
        float x = static_cast<float>(src1.val[j]);
        float out_norm =
            static_cast<float>(static_cast<scalar_t>(x * s_var)) *
            static_cast<float>(src2.val[j]);
        fp8_type dst;
        convert_op(dst, out_norm);
        out[item_ct1.get_group(2) * hidden_size + idx * VEC_SIZE + j] = dst;
      }
    }
  } else {
    float s_var = *s_variance_ptr;
    for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
         idx += item_ct1.get_local_range(2)) {
      float x = static_cast<float>(input_row[idx]);
      float out_norm = static_cast<float>(static_cast<scalar_t>(x * s_var)) *
                       static_cast<float>(weight[idx]);
      fp8_type dst;
```
The normalization path computes out_norm in float after only casting the normalized value back to scalar_t (but not the subsequent multiply by weight). The test reference explicitly rounds at input.dtype before applying FP8 quantization; to match that behavior (and the cast order used in csrc/layernorm.cpp), the multiply by weight should be performed in scalar_t and only then converted to float for quantization.
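The cast-order difference can be demonstrated numerically (a NumPy sketch using `float16` as a stand-in for `scalar_t`; the helper names are hypothetical): rounding only the normalized value versus also performing the weight multiply at the model dtype produces different results.

```python
import numpy as np

def norm_mul_weight_float(x, s_var, w):
    # current kernel order: round the normalized value to the model dtype,
    # but do the weight multiply in float32
    return np.float32(np.float16(x * s_var)) * np.float32(w)

def norm_mul_weight_scalar_t(x, s_var, w):
    # reference order: round to the model dtype, multiply in that dtype,
    # then widen to float32 for FP8 quantization
    return np.float32(np.float16(x * s_var) * np.float16(w))

# 1 + 2^-10 is exactly representable in float16, but its square is not,
# so the two orderings round differently.
x, s_var, w = np.float32(1.0009765625), np.float32(1.0), np.float32(1.0009765625)
a = norm_mul_weight_float(x, s_var, w)
b = norm_mul_weight_scalar_t(x, s_var, w)
```

The discrepancy is tiny per element, but it is exactly the kind of mismatch that makes bit-exact comparison against the test reference fail.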
```cpp
    double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  int hidden_size = input.size(-1);
```
The host entry points dereference scale.data_ptr<float>() and compute input_stride = input.stride(-2) without validating assumptions. Please add input validation similar to rms_norm (e.g., ensure input.dim() >= 2, input.stride(-1) == 1 or make input = input.contiguous(), weight.is_contiguous(), and scale is contiguous float with numel() == 1). Without these checks, passing a transposed/non-contiguous-last-dim input or a non-float/incorrectly-shaped scale can lead to incorrect results or crashes.
Suggested change:
```diff
     double epsilon) {
+  // Validate tensor layouts and shapes to match kernel assumptions.
+  TORCH_CHECK(
+      input.dim() >= 2,
+      "rms_norm_static_fp8_quant: input must have at least 2 dimensions, but got ",
+      input.dim());
+  TORCH_CHECK(
+      input.stride(-1) == 1,
+      "rms_norm_static_fp8_quant: last dimension of input must be contiguous "
+      "(stride(-1) == 1), but got stride(-1) = ",
+      input.stride(-1));
+  TORCH_CHECK(
+      weight.is_contiguous(),
+      "rms_norm_static_fp8_quant: weight must be contiguous.");
+  TORCH_CHECK(
+      weight.dim() == 1,
+      "rms_norm_static_fp8_quant: weight must be 1D, but got dim = ",
+      weight.dim());
+  TORCH_CHECK(
+      scale.is_contiguous(),
+      "rms_norm_static_fp8_quant: scale must be contiguous.");
+  TORCH_CHECK(
+      scale.scalar_type() == torch::kFloat32,
+      "rms_norm_static_fp8_quant: scale must have dtype float32, but got ",
+      scale.scalar_type());
+  TORCH_CHECK(
+      scale.numel() == 1,
+      "rms_norm_static_fp8_quant: scale must contain exactly one element, but got ",
+      scale.numel());
   TORCH_CHECK(out.is_contiguous());
   int hidden_size = input.size(-1);
+  TORCH_CHECK(
+      hidden_size > 0,
+      "rms_norm_static_fp8_quant: hidden_size (last dimension of input) must be > 0.");
+  TORCH_CHECK(
+      weight.numel() == hidden_size,
+      "rms_norm_static_fp8_quant: weight.numel() (",
+      weight.numel(),
+      ") must match hidden_size (",
+      hidden_size,
+      ").");
+  TORCH_CHECK(
+      input.numel() % hidden_size == 0,
+      "rms_norm_static_fp8_quant: input.numel() must be divisible by hidden_size. "
+      "Got input.numel() = ",
+      input.numel(),
+      ", hidden_size = ",
+      hidden_size);
```
```cpp
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(residual.is_contiguous());
  TORCH_CHECK(residual.scalar_type() == input.scalar_type());
  TORCH_CHECK(weight.scalar_type() == input.scalar_type());

  int hidden_size = input.size(-1);
  int input_stride = input.stride(-2);
  int num_tokens = input.numel() / hidden_size;
```
Same as rms_norm_static_fp8_quant: input/scale/weight assumptions aren’t validated here (e.g., input.stride(-1) == 1, scale is float contiguous with 1 element, and input.dim() >= 2). Adding these checks (or making input contiguous when needed) will prevent misinterpreting scale memory and out-of-bounds/incorrect indexing for unsupported strides.
```cpp
  if (can_vec) {
    int64_t const vec_hidden = hidden_size / VEC_SIZE;
    int64_t const vec_stride = input_stride / VEC_SIZE;
```
vec_stride is computed but never used, which will trigger an unused-variable warning in some builds. It should either be used (if intended for strided row access) or removed.
Suggested change:
```diff
-    int64_t const vec_stride = input_stride / VEC_SIZE;
```
Essential Elements of an Effective PR Description Checklist
- `supported_models.md` and `examples` for a new model.

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.
Purpose
Add a layer norm quant fusion kernel.
Refer to https://github.com/vllm-project/vllm/blob/main/csrc/layernorm_quant_kernels.cu
It was done by Copilot.
Test Plan
```shell
pytest -v -s tests/test_layernorm_quant.py
```
Test Result
pass
(Optional) Documentation Update
BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing (anything written below this line will be removed by GitHub Actions)