Skip to content

Conversation

xiangze-arm
Copy link

@xiangze-arm xiangze-arm commented Oct 21, 2025

  • Avoid tensor concat
  • Use silu_and_mul kernel

Purpose

Improve dynamic 4bit moe performance on Arm CPU

Test Plan

Tested locally with Qwen3-30B-A3B w4a8 model quantized with llmcompressor

Test Result

See performance improvements, especially for small batch sizes.

- Avoid tensor concat
- Use silu_and_mul kernel

Signed-off-by: Zhang Xiangze <[email protected]>
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant performance optimizations for the dynamic 4-bit Mixture of Experts (MoE) implementation on CPU. The key changes involve pre-gathering token data and pre-allocating output tensors to avoid costly index_select and cat operations within the parallel loop, which is a substantial improvement. Additionally, it leverages a fused silu_and_mul kernel for the SiLU activation path. My review identifies two high-severity issues. First, the use of an extern function declaration without a corresponding header file poses a maintainability and correctness risk. Second, while the SiLU path is optimized with a fused kernel, the SwiGLUOAI activation path is not, creating a performance inconsistency that should be addressed to fully realize the performance goals of this PR. The rest of the changes are well-implemented and contribute positively to the performance.

#endif
}

extern void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using extern for function declarations across translation units is risky and can lead to subtle bugs if the function signature changes, as mismatches are only caught at link time. It's much safer and better practice to declare silu_and_mul in a header file (e.g., csrc/cpu/activation.h) and include that header here. This ensures type safety and improves code maintainability.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

silu_and_mul function is declared in csrc/cpu/op.h, however there is a function definition in this header file, include this header causes multiple definition error.

Comment on lines 118 to 126
if (activation_kind == ActivationKind::SwiGLUOAI) { // SwiGLUOAI
auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
constexpr double kAlpha = 1.702; // GPT-OSS default
constexpr double kLimit = 7.0; // GPT-OSS default
auto gate_c = at::clamp_max(g_part, kLimit);
auto up_c = at::clamp(u_part, -kLimit, kLimit);
auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
act = up_c.add(1.0).mul(glu);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This pull request focuses on improving MoE performance on CPU, and correctly introduces the silu_and_mul fused kernel for the SiLU activation path. However, the SwiGLUOAI activation path remains a sequence of multiple separate PyTorch operations. This will result in a significant performance discrepancy between the two activation paths, undermining the overall performance goal. To ensure consistent high performance, a fused CPU kernel for SwiGLUOAI should be implemented, similar to silu_and_mul.

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines 125 to +129
auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
act = up_c.add(1.0).mul(glu);
} else { // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul()
act = at::silu(g_part).mul(u_part);
act = at::empty({te, I}, y13.options());
silu_and_mul(act, y13);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard fused SiLU path for non-SIMD-aligned widths

The new branch replaces at::silu(...).mul(...) with silu_and_mul(act, y13). The CPU implementation of silu_and_mul (csrc/cpu/activation.cpp) asserts d % VEC_ELEM_NUM == 0, i.e. the intermediate dimension must be divisible by the SIMD width (8 for fp32, 16 for fp16/bf16). The previous code worked for any I, so experts whose hidden size yields an odd intermediate (e.g. I=1530) will now hit this TORCH_CHECK and the op will abort instead of computing the activation. A fallback to the unfused SiLU*mul path is needed for shapes that are not vector-aligned.

Useful? React with 👍 / 👎.

Signed-off-by: Zhang Xiangze <[email protected]>
@xiangze-arm
Copy link
Author

@mgoin @bigPYJ1151 Can you help to review this PR?

cc @nikhil-arm

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant