diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp
index 1d06fc6b5b0a..0b2c524b2cb2 100644
--- a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp
+++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp
@@ -21,6 +21,8 @@ inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
 #endif
 }
 
+extern void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
+
 enum ActivationKind : int64_t {
   SwiGLU_Gu = 0,  // act = SiLU(g) * u
   SwiGLUOAI = 1,  // act = SiLU(u) * g
@@ -87,30 +89,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
   const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
   const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
 
-  // Per-expert outputs filled in parallel
-  std::vector<torch::Tensor> y_list(E);
-  y_list.resize(E);
+  auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
+  if (apply_router_weight_on_input) {
+    X_all = X_all.mul(expert_gates.unsqueeze(1));
+  }
+  auto Y_all = at::empty({offsets[E], H}, x_c.options());
 
   at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
+    c10::InferenceMode guard;
     for (int64_t e = e_begin; e < e_end; ++e) {
       const int64_t te = counts[e];
       if (te == 0) {
-        y_list[e] = at::empty({0, H}, x_c.options());
         continue;
       }
 
       const int64_t start = offsets[e];
-      auto sel_tokens =
-          expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
-      auto gates_e =
-          expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
-
-      auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
-
-      if (apply_router_weight_on_input) {
-        x_e = x_e.mul(gates_e.unsqueeze(1));
-      }
+      auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
 
       auto w13_e = w13_packed.select(/*dim=*/0, e);
       auto w2_e = w2_packed.select(/*dim=*/0, e);
 
@@ -119,35 +114,33 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
       auto y13 =
          mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);
 
-      auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
-      auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
-
       torch::Tensor act;
       if (activation_kind == ActivationKind::SwiGLUOAI) {  // SwiGLUOAI
-        constexpr double kAlpha = 1.702;  // GPT-OSS default
-        constexpr double kLimit = 7.0;    // GPT-OSS default
+        auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
+        auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
+        constexpr double kAlpha = 1.702;  // GPT-OSS default
+        constexpr double kLimit = 7.0;    // GPT-OSS default
         auto gate_c = at::clamp_max(g_part, kLimit);
         auto up_c = at::clamp(u_part, -kLimit, kLimit);
         auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
         act = up_c.add(1.0).mul(glu);
       } else {  // SiLU, SwiGLU_GU, vLLM maps silu to SiluAndMul()
-        act = at::silu(g_part).mul(u_part);
+        act = at::empty({te, I}, y13.options());
+        silu_and_mul(act, y13);
       }
 
       // W2
       auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
 
-      if (!apply_router_weight_on_input) {
-        y = y.mul(gates_e.unsqueeze(1));
-      }
-
-      // Store per-expert result
-      y_list[e] = y;
+      Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
     }
   });
 
-  // Concatenate all expert outputs to match expert_tokens order
-  auto Y_all = at::cat(y_list, /*dim=*/0);
+  if (!apply_router_weight_on_input) {
+    Y_all = Y_all.mul(expert_gates.unsqueeze(1));
+  }
+
   auto out = at::zeros({T, H}, x.options());
   out = at::index_add(out, /*dim=*/0, /*index=*/expert_tokens,
                       /*source=*/Y_all);
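
Reviewer note: the refactor assumes the extern silu_and_mul kernel declared above computes the same result as the at::silu(g_part).mul(u_part) path it replaces, i.e. for an input laid out as [g | u] along dim 1 it writes SiLU(g) * u into a preallocated output. A minimal eager-mode ATen sketch of that assumed contract (the name silu_and_mul_reference is hypothetical, for illustration only; the real kernel is fused):

#include <torch/torch.h>

// Sketch of the contract assumed for the extern silu_and_mul declaration:
// `input` is [T, 2*I] with the gate g in columns [0, I) and the up
// projection u in columns [I, 2*I); `out` is a preallocated [T, I]
// tensor that receives SiLU(g) * u.
void silu_and_mul_reference(torch::Tensor& out, torch::Tensor& input) {
  const int64_t I = input.size(1) / 2;
  auto g = input.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
  auto u = input.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
  out.copy_(at::silu(g).mul(u));
}

Given that contract, preallocating Y_all once and letting each worker copy_ into its narrow view replaces the per-expert y_list plus the final at::cat, and hoisting the expert_gates multiply out of the loop applies the router weights once over all gathered tokens instead of once per expert.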