refine bf16 group gemm dispatch policy #177

Open
xinyu-intel wants to merge 1 commit into vllm-project:main from xinyu-intel:dev/group-gemm-bf16-policy

Conversation

@xinyu-intel
Collaborator


Purpose

Refine the dispatch policy of the bf16 grouped GEMM, especially for cases where M and N are small.
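For context, the quantity that makes these cases "small" is the average number of rows each expert processes. A minimal sketch (the helper name is hypothetical, not part of this PR's code):

```python
# Hypothetical helper: average number of rows (tokens) per expert in a
# grouped GEMM. With uniform routing, total_tokens * topk rows are spread
# across num_experts experts.
def avg_rows_per_expert(total_tokens: int, topk: int, num_experts: int) -> float:
    return total_tokens * topk / num_experts

# The small-M benchmark below: 80 tokens, topk=8, 128 experts.
print(avg_rows_per_expert(80, 8, 128))    # -> 5.0
# The large-M benchmark: 8192 tokens, topk=8, 128 experts.
print(avg_rows_per_expert(8192, 8, 128))  # -> 512.0
```

At an average of only 5 rows per expert, a large default workgroup tile leaves most of the tile idle, which is why the small-M benchmark benefits most from the refined policy.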

Test Plan

Script:

# SPDX-License-Identifier: Apache-2.0
import torch

import random
import numpy as np

from vllm_xpu_kernels.fused_moe_interface import cutlass_grouped_gemm_xe2

DEVICE = "xpu"


def seed_everything(seed) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def init_rows_for_experts(tokens, topk, num_rows_per_expert):
    if num_rows_per_expert.shape[0] == 1:
        num_rows_per_expert[0] = tokens * topk
        return
    n_experts = num_rows_per_expert.numel()
    rand = torch.rand(tokens, n_experts, device=num_rows_per_expert.device)
    topk_idx = torch.topk(rand, topk, dim=1).indices  # [tokens, topk]
    flat_idx = topk_idx.flatten()
    num_rows_per_expert += torch.bincount(flat_idx, minlength=n_experts)


def test_xe_grouped_gemm(m, n, k, e, topk, dtype, has_bias):
    seed_everything(7)
    num_experts = e
    total_m = m * topk
    # input
    input_A = torch.randn((total_m, k), dtype=dtype,
                          device=DEVICE).contiguous()
    ref_A = input_A
    # weight
    input_B = torch.randn((num_experts, k, n), dtype=dtype, device=DEVICE)
    if has_bias:
        bias = torch.randn((num_experts, n), dtype=dtype, device=DEVICE)
    else:
        bias = None

    # output offset
    num_rows_per_expert = torch.zeros(num_experts,
                                      device=DEVICE,
                                      dtype=torch.int32)
    init_rows_for_experts(m, topk, num_rows_per_expert)
    output = torch.empty((total_m, n), dtype=dtype, device=DEVICE)

    cutlass_grouped_gemm_xe2(input_A, input_B, None, bias, output,
                             num_rows_per_expert, n, k, num_experts, False,
                             False)

    # reference grouped GEMM: per-expert matmul in fp32
    ref = []
    pre_token_sum = 0
    for i in range(num_experts):
        cur_token_num = num_rows_per_expert[i]
        if cur_token_num == 0:
            continue
        input = ref_A[pre_token_sum:pre_token_sum + cur_token_num, :].to(
            torch.float32)
        weight = input_B[i, :, :].to(torch.float32)
        expert_output_fp32 = input @ weight
        if has_bias:
            expert_output_fp32 += bias[i]
        ref.append(expert_output_fp32.to(dtype))
        pre_token_sum += cur_token_num
    ref = torch.cat(ref, dim=0)

    torch.testing.assert_close(output, ref, rtol=2e-2, atol=1e-2)
    
    iters = 100
    # Synchronize before timing so pending work from the correctness check
    # is excluded from the measurement.
    torch.accelerator.synchronize()
    startEvent = torch.Event(enable_timing=True)
    endEvent = torch.Event(enable_timing=True)
    startEvent.record()
    for _ in range(iters):
        cutlass_grouped_gemm_xe2(input_A, input_B, None, bias, output,
                                 num_rows_per_expert, n, k, num_experts, False,
                                 False)
    endEvent.record()
    torch.accelerator.synchronize()
    # elapsed_time() returns milliseconds; * 1000 converts to microseconds.
    print(f"Average latency: "
          f"{startEvent.elapsed_time(endEvent) * 1000 / iters:.2f} us")


if __name__ == "__main__":
    print("Testing grouped GEMM on Xe...")
    print("Testing Qwen3-30B-A3B-Instruct with MNK factors (80, 768 * 2 // 4, 2048), num_experts=128, topk=8")
    test_xe_grouped_gemm(80, 768 * 2 // 4, 2048, 128, 8, torch.bfloat16, False)
    print("Testing Qwen3-30B-A3B-Instruct with MNK factors (8192, 768 * 2 // 4, 2048), num_experts=128, topk=8")
    test_xe_grouped_gemm(8192, 768 * 2 // 4, 2048, 128, 8, torch.bfloat16, False)
    # print("Testing Llama-4-scout with MNK factors (30, 8192 * 2, 5120), num_experts=16, topk=1")
    # test_xe_grouped_gemm(30, 8192 * 2, 5120, 16, 1, torch.bfloat16, False)
    # print("Testing Llama-4-scout with MNK factors (8192, 8192 * 2, 5120), num_experts=16, topk=1")
    # test_xe_grouped_gemm(8192, 8192 * 2, 5120, 16, 1, torch.bfloat16, False)

Test Result

Performance comparison on B60:

| Case | Token distribution | Before (us) | After (us) |
| --- | --- | --- | --- |
| Qwen3-30B-A3B-Instruct, MNK factors (80, 768 * 2 // 4, 2048), num_experts=128, topk=8 | Uniform | 1520.04 | 530.11 |
| Qwen3-30B-A3B-Instruct, MNK factors (8192, 768 * 2 // 4, 2048), num_experts=128, topk=8 | Uniform | 2677.47 | 1543.67 |


Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Copilot AI review requested due to automatic review settings March 5, 2026 07:41
Contributor

Copilot AI left a comment


Pull request overview

Refines the Xe2 bf16/FP16 grouped GEMM dispatch and kernel tiling to improve performance, especially for smaller per-expert M (and generally smaller-N workloads).

Changes:

  • Updates the w16a16 (bf16/fp16) dispatch thresholds to use specialized policies for A_avg_M <= 8/16/32.
  • Adjusts the default Xe2 GEMM policy base workgroup tile and subgroup layout (WGTile and SGLayout) to new shapes.
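The threshold-based selection described above can be illustrated with a small sketch, assuming the selector buckets on A_avg_M at the <= 8 / 16 / 32 boundaries; the policy names here are hypothetical (the real selection is C++ in grouped_gemm_xe2_interface.hpp):

```python
# Hypothetical illustration of the A_avg_M-bucketed w16a16 policy dispatch.
# Policy names are placeholders, not the actual C++ policy types.
def select_w16a16_policy(a_avg_m: int) -> str:
    if a_avg_m <= 8:
        return "small_m_8"
    if a_avg_m <= 16:
        return "small_m_16"
    if a_avg_m <= 32:
        return "small_m_32"
    return "default"

print(select_w16a16_policy(5))    # small-M benchmark -> "small_m_8"
print(select_w16a16_policy(512))  # large-M benchmark -> "default"
```

The idea is simply that workloads with few rows per expert get a policy with a smaller workgroup tile, so the hardware is not underutilized on small-M cases.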

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
csrc/xpu/grouped_gemm/xe_2/grouped_gemm_xe2_interface.hpp Refines w16a16 policy selection based on A_avg_M to better target small-M cases.
csrc/xpu/grouped_gemm/xe_2/gemm_xe2_policy.hpp Changes default policy base tiling/subgroup layout, impacting the default w16a16 kernel configuration.


@jikunshang jikunshang requested a review from mayuyuace March 5, 2026 10:31
@xinyu-intel xinyu-intel mentioned this pull request Mar 6, 2026
@mayuyuace
Collaborator

Please test more shapes, not only for the qwen a3b model.

@xinyu-intel
Collaborator Author

> Please test more shapes, not only for the qwen a3b model.

Sure. I won't merge this before we confirm there are no regressions.
