
[FusedKernel] add layer norm quant fusion kernel #180

Open
jikunshang wants to merge 4 commits into vllm-project:main from jikunshang:kunshang/fuse_rms_norm_quant

Conversation

Collaborator

@jikunshang jikunshang commented Mar 6, 2026

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.


Purpose

Add a fused layer norm + quantization kernel, following https://github.com/vllm-project/vllm/blob/main/csrc/layernorm_quant_kernels.cu.
The implementation was generated with Copilot.
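For reference, the semantics of the fused op can be sketched in pure Python. This is a hypothetical illustration, not the PR's implementation: the FP8 e4m3 cast is approximated by clamping to its finite range of ±448, and mantissa rounding is ignored.

```python
import math

def rms_norm_static_fp8_quant_ref(x, weight, scale, epsilon=1e-6):
    """Reference semantics: RMSNorm over the last dim, then static
    per-tensor quantization by 1/scale.  The FP8 e4m3 cast is
    approximated by clamping to +/-448 (no mantissa rounding)."""
    hidden = len(x)
    variance = sum(v * v for v in x) / hidden
    inv_rms = 1.0 / math.sqrt(variance + epsilon)
    out = []
    for v, w in zip(x, weight):
        q = (v * inv_rms) * w / scale
        out.append(max(-448.0, min(448.0, q)))
    return out
```

The real kernel additionally fuses the residual add in the `fused_add_` variant and does the reduction per work-group.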

Test Plan

pytest -v -s tests/test_layernorm_quant.py

Test Result

pass

(Optional) Documentation Update


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Copilot AI review requested due to automatic review settings March 6, 2026 02:05
Contributor

Copilot AI left a comment


Pull request overview

Adds XPU/SYCL fused RMSNorm + static FP8 quantization ops and corresponding unit tests, wiring the new kernels into the C++ extension build and Torch op registry.

Changes:

  • Registers two new Torch ops: rms_norm_static_fp8_quant and fused_add_rms_norm_static_fp8_quant.
  • Adds unit tests validating correctness, strided inputs, and opcheck compliance.
  • Updates the extension build sources to include a new layernorm_quant.cpp compilation unit.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File Description
tests/test_layernorm_quant.py New test suite for the fused RMSNorm + static FP8 quant ops (incl. opcheck + strided input coverage).
csrc/torch_bindings.cpp Registers the two new ops in the _C Torch library for XPU.
csrc/ops.h Declares the new kernel entry points.
CMakeLists.txt Adds csrc/layernorm_quant.cpp to the _C extension sources.


Comment on lines +40 to +48
"rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
"Tensor! weight, Tensor! scale, float epsilon) -> ()");
ops.impl(
"rms_norm_static_fp8_quant", torch::kXPU, &rms_norm_static_fp8_quant);

// Fused Add + RMS Normalization + static FP8 quantization.
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
"Tensor! residual, Tensor! weight, Tensor! scale, float epsilon) -> ()");

Copilot AI Mar 6, 2026


The op schema marks input, weight, and scale as mutable (Tensor!). Unless the kernel actually mutates these tensors, they should be registered as immutable (Tensor) to avoid incorrect aliasing/mutation semantics (and to align with nearby FP8 quant ops like static_scaled_fp8_quant(Tensor input, Tensor scale, ...)). Consider also updating the C++ signature to take const Tensor& for read-only args.

Suggested change

Before:
  "rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
  "Tensor! weight, Tensor! scale, float epsilon) -> ()");
  ops.impl(
      "rms_norm_static_fp8_quant", torch::kXPU, &rms_norm_static_fp8_quant);
  // Fused Add + RMS Normalization + static FP8 quantization.
  ops.def(
      "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
      "Tensor! residual, Tensor! weight, Tensor! scale, float epsilon) -> ()");

After:
  "rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
  "Tensor weight, Tensor scale, float epsilon) -> ()");
  ops.impl(
      "rms_norm_static_fp8_quant", torch::kXPU, &rms_norm_static_fp8_quant);
  // Fused Add + RMS Normalization + static FP8 quantization.
  ops.def(
      "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
      "Tensor residual, Tensor weight, Tensor scale, float epsilon) -> ()");

Comment on lines +46 to +52
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor! input, "
"Tensor! residual, Tensor! weight, Tensor! scale, float epsilon) -> ()");
ops.impl(
"fused_add_rms_norm_static_fp8_quant",
torch::kXPU,
&fused_add_rms_norm_static_fp8_quant);

Copilot AI Mar 6, 2026


Similarly here, input, weight, and scale are registered as mutable (Tensor!). If only residual (and result) are written, the schema should reflect that (e.g., Tensor input, Tensor! residual, Tensor weight, Tensor scale, ...) so functionalization/opcheck and future compiler passes see accurate mutation behavior.

Comment on lines +21 to +31
torch::Tensor& input,
torch::Tensor& weight,
torch::Tensor& scale,
double epsilon);

void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
torch::Tensor& scale,

Copilot AI Mar 6, 2026


These new APIs take all tensors by non-const reference (torch::Tensor&), which implies they may be mutated. If input, weight, and scale are read-only, prefer const torch::Tensor& for those parameters (matching existing quantization APIs like static_scaled_fp8_quant below in this header) and keep the mutability consistent with the Torch schema.

Suggested change

Before:
  torch::Tensor& input,
  torch::Tensor& weight,
  torch::Tensor& scale,
  double epsilon);

  void fused_add_rms_norm_static_fp8_quant(
      torch::Tensor& out,
      torch::Tensor& input,
      torch::Tensor& residual,
      torch::Tensor& weight,
      torch::Tensor& scale,

After:
  const torch::Tensor& input,
  const torch::Tensor& weight,
  const torch::Tensor& scale,
  double epsilon);

  void fused_add_rms_norm_static_fp8_quant(
      torch::Tensor& out,
      const torch::Tensor& input,
      const torch::Tensor& residual,
      const torch::Tensor& weight,
      const torch::Tensor& scale,

set(VLLM_EXT_SRC
"csrc/cache.cpp"
"csrc/layernorm.cpp"
"csrc/layernorm_quant.cpp"

Copilot AI Mar 6, 2026


CMakeLists.txt adds csrc/layernorm_quant.cpp to VLLM_EXT_SRC, but that file does not exist in the repository (and rms_norm_static_fp8_quant / fused_add_rms_norm_static_fp8_quant have no definitions elsewhere under csrc/). This will break the build/link step; add the missing source file or remove the entry until the implementation is included.

Suggested change (remove the line):
  "csrc/layernorm_quant.cpp"
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.



NUM_TOKENS = [1, 7, 83, 256]
HIDDEN_SIZES = [8, 64, 768, 5120, 8192]
SEEDS = [0]
EPSILONS = [1e-5, 1e-6]

Copilot AI Mar 6, 2026


EPSILONS is defined but not used anywhere in this test module. Either parameterize tests over it or remove it to keep the test inputs clear.

Suggested change (remove the unused line):
  EPSILONS = [1e-5, 1e-6]
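If the constants are kept, they could feed the test parameter grid instead of being removed. A minimal sketch of how the four lists combine (hypothetical; the actual test file would typically use stacked `pytest.mark.parametrize` decorators):

```python
import itertools

NUM_TOKENS = [1, 7, 83, 256]
HIDDEN_SIZES = [8, 64, 768, 5120, 8192]
SEEDS = [0]
EPSILONS = [1e-5, 1e-6]

# Cartesian product of all test dimensions, equivalent to what stacked
# @pytest.mark.parametrize decorators over the same lists would generate.
PARAM_GRID = list(itertools.product(NUM_TOKENS, HIDDEN_SIZES, SEEDS, EPSILONS))
```

Parameterizing over `EPSILONS` would also exercise the `epsilon` argument of both ops, which the current suite leaves at a single value.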
Comment on lines +65 to +106
bool can_vec =
((addr_in & (WIDTH - 1)) == 0) && ((hidden_size & (VEC_SIZE - 1)) == 0);

if (can_vec) {
int64_t const num_vec_elems = hidden_size / VEC_SIZE;
auto const* vec_in = reinterpret_cast<const vec_t*>(input_row);
for (int i = item_ct1.get_local_id(2); i < num_vec_elems;
i += item_ct1.get_local_range(2)) {
vec_t tmp = vec_in[i];
for (int j = 0; j < VEC_SIZE; ++j) {
float x = static_cast<float>(tmp.val[j]);
variance += x * x;
}
}
} else {
for (int i = item_ct1.get_local_id(2); i < hidden_size;
i += item_ct1.get_local_range(2)) {
float x = static_cast<float>(input_row[i]);
variance += x * x;
}
}

// ---------- Phase 2: work-group reduction ----------
variance = sycl::reduce_over_group(
sycl::ext::oneapi::this_work_item::get_work_group<3>(),
variance,
sycl::plus<>());
if (item_ct1.get_local_id(2) == 0) {
*s_variance_ptr = sycl::rsqrt(variance / hidden_size + epsilon);
}
item_ct1.barrier(sycl::access::fence_space::local_space);

// ---------- Phase 3: normalize + quantize ----------
// invert scale to avoid division
float const scale_inv = 1.0f / (*scale);
fp8::ConvertWithScaleOp<true, fp8_type> convert_op{scale_inv};

if (can_vec) {
auto* v_in = reinterpret_cast<const vec_t*>(input_row);
auto* v_w = reinterpret_cast<const vec_t*>(weight);
int64_t const num_vec_elems = hidden_size / VEC_SIZE;
float s_var = *s_variance_ptr;

Copilot AI Mar 6, 2026


can_vec only checks alignment of input_row, but Phase 3 unconditionally reinterprets weight as aligned_vec when can_vec is true. If weight is not sufficiently aligned, this reinterpret_cast path can perform misaligned vector loads (undefined behavior). Consider including weight (and optionally out) alignment in the vectorization predicate, or falling back to scalar loads when weight is misaligned.

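The stricter predicate the comment asks for can be modeled as follows. This is a hypothetical Python sketch of the check (the real kernel computes it from raw SYCL pointers); `width_bytes` and `vec_size` are assumed defaults, not values taken from the PR:

```python
def can_vectorize(input_addr, weight_addr, out_addr, hidden_size,
                  width_bytes=16, vec_size=8):
    """Vectorization predicate that checks *all* pointers touched by the
    vector path (input row, weight, and output), not just the input row,
    plus divisibility of the hidden size by the vector width."""
    aligned = all(addr % width_bytes == 0
                  for addr in (input_addr, weight_addr, out_addr))
    return aligned and hidden_size % vec_size == 0
```

With this predicate, a misaligned `weight` pointer forces the scalar fallback instead of performing a misaligned vector load.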
Comment on lines +112 to +129
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
float out_norm =
static_cast<float>(static_cast<scalar_t>(x * s_var)) *
static_cast<float>(src2.val[j]);
fp8_type dst;
convert_op(dst, out_norm);
out[item_ct1.get_group(2) * hidden_size + idx * VEC_SIZE + j] = dst;
}
}
} else {
float s_var = *s_variance_ptr;
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
idx += item_ct1.get_local_range(2)) {
float x = static_cast<float>(input_row[idx]);
float out_norm = static_cast<float>(static_cast<scalar_t>(x * s_var)) *
static_cast<float>(weight[idx]);
fp8_type dst;

Copilot AI Mar 6, 2026


The normalization path computes out_norm in float after only casting the normalized value back to scalar_t (but not the subsequent multiply by weight). The test reference explicitly rounds at input.dtype before applying FP8 quantization; to match that behavior (and the cast order used in csrc/layernorm.cpp), the multiply by weight should be performed in scalar_t and only then converted to float for quantization.

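The effect of the cast order can be demonstrated numerically. The sketch below uses a bit-level bfloat16 round-trip as a stand-in for the kernel's `scalar_t` cast; all names are illustrative, not from the PR:

```python
import struct

def to_bf16(x):
    """Round-to-nearest-even float32 -> bfloat16 -> float (7 explicit
    mantissa bits), standing in for a cast back to scalar_t."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

x_times_svar = 1.0078125  # normalized value, exactly representable in bf16
w = 1.0078125             # weight value, exactly representable in bf16

# Kernel as written: widen both operands and multiply in float.
path_a = to_bf16(x_times_svar) * w
# Reference behavior: multiply in scalar_t (result rounded to bf16),
# then widen to float for quantization.
path_b = to_bf16(to_bf16(x_times_svar) * w)
```

Because the product `1.0078125 * 1.0078125` is not representable in bfloat16, the two paths diverge, which is exactly the mismatch the test reference would surface.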
Comment on lines +310 to +312
double epsilon) {
TORCH_CHECK(out.is_contiguous());
int hidden_size = input.size(-1);

Copilot AI Mar 6, 2026


The host entry points dereference scale.data_ptr<float>() and compute input_stride = input.stride(-2) without validating assumptions. Please add input validation similar to rms_norm (e.g., ensure input.dim() >= 2, input.stride(-1) == 1 or make input = input.contiguous(), weight.is_contiguous(), and scale is contiguous float with numel() == 1). Without these checks, passing a transposed/non-contiguous-last-dim input or a non-float/incorrectly-shaped scale can lead to incorrect results or crashes.

Suggested change

Before:
  double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  int hidden_size = input.size(-1);

After:
  double epsilon) {
  // Validate tensor layouts and shapes to match kernel assumptions.
  TORCH_CHECK(
      input.dim() >= 2,
      "rms_norm_static_fp8_quant: input must have at least 2 dimensions, but got ",
      input.dim());
  TORCH_CHECK(
      input.stride(-1) == 1,
      "rms_norm_static_fp8_quant: last dimension of input must be contiguous "
      "(stride(-1) == 1), but got stride(-1) = ",
      input.stride(-1));
  TORCH_CHECK(
      weight.is_contiguous(),
      "rms_norm_static_fp8_quant: weight must be contiguous.");
  TORCH_CHECK(
      weight.dim() == 1,
      "rms_norm_static_fp8_quant: weight must be 1D, but got dim = ",
      weight.dim());
  TORCH_CHECK(
      scale.is_contiguous(),
      "rms_norm_static_fp8_quant: scale must be contiguous.");
  TORCH_CHECK(
      scale.scalar_type() == torch::kFloat32,
      "rms_norm_static_fp8_quant: scale must have dtype float32, but got ",
      scale.scalar_type());
  TORCH_CHECK(
      scale.numel() == 1,
      "rms_norm_static_fp8_quant: scale must contain exactly one element, but got ",
      scale.numel());
  TORCH_CHECK(out.is_contiguous());
  int hidden_size = input.size(-1);
  TORCH_CHECK(
      hidden_size > 0,
      "rms_norm_static_fp8_quant: hidden_size (last dimension of input) must be > 0.");
  TORCH_CHECK(
      weight.numel() == hidden_size,
      "rms_norm_static_fp8_quant: weight.numel() (",
      weight.numel(),
      ") must match hidden_size (",
      hidden_size,
      ").");
  TORCH_CHECK(
      input.numel() % hidden_size == 0,
      "rms_norm_static_fp8_quant: input.numel() must be divisible by hidden_size. "
      "Got input.numel() = ",
      input.numel(),
      ", hidden_size = ",
      hidden_size);
Comment on lines +358 to +365
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
TORCH_CHECK(weight.scalar_type() == input.scalar_type());

int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;

Copilot AI Mar 6, 2026


Same as rms_norm_static_fp8_quant: input/scale/weight assumptions aren’t validated here (e.g., input.stride(-1) == 1, scale is float contiguous with 1 element, and input.dim() >= 2). Adding these checks (or making input contiguous when needed) will prevent misinterpreting scale memory and out-of-bounds/incorrect indexing for unsupported strides.


if (can_vec) {
int64_t const vec_hidden = hidden_size / VEC_SIZE;
int64_t const vec_stride = input_stride / VEC_SIZE;

Copilot AI Mar 6, 2026


vec_stride is computed but never used, which will trigger an unused-variable warning in some builds. It should either be used (if intended for strided row access) or removed.

Suggested change (remove the unused line):
  int64_t const vec_stride = input_stride / VEC_SIZE;