43 changes: 18 additions & 25 deletions csrc/moe/dynamic_4bit_int_moe_cpu.cpp
@@ -21,6 +21,8 @@ inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
#endif
}

extern void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
Contributor
high

Using extern for function declarations across translation units is risky and can lead to subtle bugs if the function signature changes, as mismatches are only caught at link time. It's much safer and better practice to declare silu_and_mul in a header file (e.g., csrc/cpu/activation.h) and include that header here. This ensures type safety and improves code maintainability.

Author
The silu_and_mul function is declared in csrc/cpu/op.h; however, that header also contains a function definition, so including it here causes a multiple-definition error.
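For reference, one way to get the type safety the reviewer asks for without including csrc/cpu/op.h would be a small declaration-only header. A minimal sketch follows; the file name is hypothetical and nothing like it is implied to exist in the tree:

// csrc/cpu/activation_decls.h (hypothetical name, declaration-only)
#pragma once

#include <torch/all.h>

// Declaration only; the definition stays where it is today, so including
// this header in several translation units cannot cause multiple-definition errors.
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);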


enum ActivationKind : int64_t {
SwiGLU_Gu = 0, // act = SiLU(g) * u
SwiGLUOAI = 1, // act = SiLU(u) * g
@@ -87,30 +89,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;

// Per-expert outputs filled in parallel
std::vector<torch::Tensor> y_list(E);
y_list.resize(E);
auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
if (apply_router_weight_on_input) {
X_all = X_all.mul(expert_gates.unsqueeze(1));
}
auto Y_all = at::empty({offsets[E], H}, x_c.options());

at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
c10::InferenceMode guard;
for (int64_t e = e_begin; e < e_end; ++e) {
const int64_t te = counts[e];
if (te == 0) {
y_list[e] = at::empty({0, H}, x_c.options());
continue;
}

const int64_t start = offsets[e];

auto sel_tokens =
expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto gates_e =
expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);

auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);

if (apply_router_weight_on_input) {
x_e = x_e.mul(gates_e.unsqueeze(1));
}
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);

auto w13_e = w13_packed.select(/*dim=*/0, e);
auto w2_e = w2_packed.select(/*dim=*/0, e);
@@ -119,35 +114,33 @@
auto y13 =
mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);

auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);

torch::Tensor act;
if (activation_kind == ActivationKind::SwiGLUOAI) { // SwiGLUOAI
auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
constexpr double kAlpha = 1.702; // GPT-OSS default
constexpr double kLimit = 7.0; // GPT-OSS default
auto gate_c = at::clamp_max(g_part, kLimit);
auto up_c = at::clamp(u_part, -kLimit, kLimit);
auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
act = up_c.add(1.0).mul(glu);
Comment on lines 118 to 126
Contributor
high

This pull request focuses on improving MoE performance on CPU, and correctly introduces the silu_and_mul fused kernel for the SiLU activation path. However, the SwiGLUOAI activation path remains a sequence of multiple separate PyTorch operations. This will result in a significant performance discrepancy between the two activation paths, undermining the overall performance goal. To ensure consistent high performance, a fused CPU kernel for SwiGLUOAI should be implemented, similar to silu_and_mul.
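To illustrate what such a fusion could look like, here is a minimal scalar sketch of a fused SwiGLU-OAI loop. The function name, the float32-only layout, and the row-major [te, 2*I] packing of y13 are assumptions made for the sketch; a real kernel would be vectorized the way csrc/cpu/activation.cpp vectorizes silu_and_mul:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar sketch of SwiGLU-OAI: for each row, with g = y13[:, 0:I] and u = y13[:, I:2I],
// act = (clamp(u, -limit, limit) + 1) * g_c * sigmoid(alpha * g_c), where g_c = min(g, limit).
void swigluoai_and_mul_scalar(float* act, const float* y13, int64_t te, int64_t I) {
  constexpr float kAlpha = 1.702f;  // GPT-OSS default
  constexpr float kLimit = 7.0f;    // GPT-OSS default
  for (int64_t t = 0; t < te; ++t) {
    const float* g = y13 + t * 2 * I;
    const float* u = g + I;
    float* o = act + t * I;
    for (int64_t i = 0; i < I; ++i) {
      const float gate = std::min(g[i], kLimit);
      const float up = std::clamp(u[i], -kLimit, kLimit);
      const float glu = gate / (1.0f + std::exp(-kAlpha * gate));  // gate * sigmoid(alpha * gate)
      o[i] = (up + 1.0f) * glu;
    }
  }
}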

} else { // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul()
act = at::silu(g_part).mul(u_part);
act = at::empty({te, I}, y13.options());
silu_and_mul(act, y13);
Comment on lines 125 to +129

P1: Guard fused SiLU path for non-SIMD-aligned widths

The new branch replaces at::silu(...).mul(...) with silu_and_mul(act, y13). The CPU implementation of silu_and_mul (csrc/cpu/activation.cpp) asserts d % VEC_ELEM_NUM == 0, i.e. the intermediate dimension must be divisible by the SIMD width (8 for fp32, 16 for fp16/bf16). The previous code worked for any I, so experts whose intermediate size is not a multiple of that vector width (e.g. I=1530) will now hit the TORCH_CHECK and the op will abort instead of computing the activation. A fallback to the unfused SiLU*mul path is needed for shapes that are not vector-aligned.

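A shape guard along the lines this comment suggests could look roughly like the following, written against the variables already in this diff; the VEC_ELEM_NUM name is borrowed for illustration, and the real divisibility requirement comes from the TORCH_CHECK inside csrc/cpu/activation.cpp:

// Sketch: take the fused path only when I is SIMD-aligned, otherwise fall back.
if (I % VEC_ELEM_NUM == 0) {  // VEC_ELEM_NUM: dtype-dependent vector width (assumed visible here)
  act = at::empty({te, I}, y13.options());
  silu_and_mul(act, y13);  // fused kernel
} else {
  auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
  auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
  act = at::silu(g_part).mul(u_part);  // unfused fallback, valid for any I
}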

}

// W2
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);

if (!apply_router_weight_on_input) {
y = y.mul(gates_e.unsqueeze(1));
}

// Store per-expert result
y_list[e] = y;
Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
}
});

// Concatenate all expert outputs to match expert_tokens order
auto Y_all = at::cat(y_list, /*dim=*/0);
if (!apply_router_weight_on_input) {
Y_all = Y_all.mul(expert_gates.unsqueeze(1));
}

auto out = at::zeros({T, H}, x.options());
out =
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);