Commit 1605f6b
Enable Model Config Input via a Centralized Parser in utils.py (#13)
* refine kernel benchmark structure
* refine
* format fix
* enable native pytorch op path
* fix format
* refine kernel-level benchmarking for topk
* Update bench_moe_topk_softmax.py
* Update bench_moe_topk_softmax.py
* Update pr-test-xpu.yml
* format fix
1 parent 688c0b8 commit 1605f6b

File tree: 3 files changed, +330 / -47 lines changed

.github/workflows/pr-test-xpu.yml

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ jobs:
       timeout-minutes: 20
       run: |
         docker exec -w /root/sglang ci_sglang_xpu \
-          /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py "
+          /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py "

     - name: Run E2E Bfloat16 tests
       timeout-minutes: 20
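The added CI step can be reproduced outside the workflow, assuming the same container name (ci_sglang_xpu) and checkout layout that the workflow uses:

    docker exec -w /root/sglang ci_sglang_xpu \
        /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_moe_topk_softmax.py"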

benchmark/bench_moe_topk_softmax.py

Lines changed: 90 additions & 46 deletions
@@ -3,6 +3,7 @@
 import torch
 import triton
 from sgl_kernel import topk_softmax
+from utils import get_model_config, parse_args


 def vllm_topk_softmax(gating_output, topk):
@@ -23,7 +24,35 @@ def vllm_topk_softmax(gating_output, topk):
     return topk_weights, topk_indices


-def sglang_topk_softmax(gating_output, topk):
+def native_topk_softmax(
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    num_tokens, num_experts = gating_output.shape
+
+    import torch.nn.functional as F
+
+    topk_weights = torch.empty(
+        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
+    )
+    topk_indices = torch.empty(
+        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
+    )
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_indices = torch.topk(topk_weights, topk, dim=-1)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_indices
+
+
+def sglang_topk_softmax(
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
     num_tokens, num_experts = gating_output.shape

     topk_weights = torch.empty(
@@ -37,18 +66,18 @@ def sglang_topk_softmax(gating_output, topk):
     )

     topk_softmax(
-        topk_weights=topk_weights,
-        topk_ids=topk_indices,
-        token_expert_indices=token_expert_indices,
-        gating_output=gating_output,
+        topk_weights,
+        topk_indices,
+        gating_output,
+        renormalize=renormalize,
     )

     return topk_weights, topk_indices


 def calculate_diff(num_tokens, num_experts, topk):
     gating_output = torch.randn(
-        (num_tokens, num_experts), device="cuda", dtype=torch.float32
+        (num_tokens, num_experts), device="xpu", dtype=torch.float32
     )
     weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
     weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)
@@ -67,52 +96,67 @@ def calculate_diff(num_tokens, num_experts, topk):
     )


-num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
-num_experts_range = [32, 64, 128, 256, 12, 512]
-topk_range = [1, 2, 4, 8]
+def get_benchmark(device="xpu"):
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["num_tokens", "num_experts", "topk", "dtype", "renormalize"],
+            x_vals=configs,
+            line_arg="provider",
+            line_vals=["sglang", "native"],
+            line_names=["SGLang", "native"],
+            styles=[("blue", "-"), ("green", "-")],
+            ylabel="Latency (us)",
+            plot_name="topk-softmax-performance",
+            args={},
+        )
+    )
+    def benchmark(num_tokens, num_experts, topk, dtype, renormalize, provider):

-configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
+        gating_output = torch.randn(
+            (num_tokens, num_experts), device=device, dtype=dtype
+        )

+        if provider == "sglang" or provider == "sglang1":
+            fn = lambda: sglang_topk_softmax(gating_output, topk, renormalize)
+        elif provider == "native":
+            fn = lambda: native_topk_softmax(gating_output, topk, renormalize)

-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["num_tokens", "num_experts", "topk"],
-        x_vals=configs,
-        line_arg="provider",
-        line_vals=["sglang", "vllm"],
-        line_names=["SGLang", "VLLM"],
-        styles=[("blue", "-"), ("green", "-")],
-        ylabel="Latency (us)",
-        plot_name="topk-softmax-performance",
-        args={},
-    )
-)
-def benchmark(num_tokens, num_experts, topk, provider):
+        quantiles = [0.5, 0.2, 0.8]
+        ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

-    gating_output = torch.randn(
-        (num_tokens, num_experts), device="cuda", dtype=torch.float32
-    )
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+    return benchmark

-    if provider == "vllm" or provider == "vllm1":
-        fn = lambda: vllm_topk_softmax(gating_output, topk)
-    elif provider == "sglang" or provider == "sglang1":
-        fn = lambda: sglang_topk_softmax(gating_output, topk)

-    quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+if __name__ == "__main__":
+    # Run correctness test on small configs if not using a real model
+    args = parse_args()
+    params = get_model_config(args)
+
+    sweep_params = {
+        "num_tokens": args.num_tokens,
+        "num_experts": params["num_experts"] or [64],
+        "top_k": params["top_k"] or [2, 4],
+        "dtype": [torch.bfloat16],
+        "renormalize": [False],
+    }
+
+    keys = sweep_params.keys()
+    configs = list(itertools.product(*sweep_params.values()))
+    print(f"Testing {len(configs)} configurations...")
+    for config in configs:
+        num_tokens, num_experts, topk, dtype, renormalize = config
+        print(
+            f"Config: num_tokens={num_tokens}, num_experts={num_experts}, topk={topk}, dtype={dtype}, renormalize={renormalize}"
+        )

-    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+        # calculate_diff(num_tokens, num_experts, topk)

+    global benchmark_configs
+    benchmark_configs = configs

-if __name__ == "__main__":
-    configs = [
-        (20, 256, 4),
-        (20, 256, 8),
-        (20, 12, 4),
-        (20, 12, 1),
-        (20, 512, 4),
-        (20, 512, 1),
-    ]
-    for num_tokens, num_experts, topk in configs:
-        calculate_diff(num_tokens, num_experts, topk)
-        benchmark.run(print_data=True)
+    # Run benchmark
+    print("Starting performance benchmark...")
+    benchmark = get_benchmark()
+    benchmark.run(print_data=True, show_plots=False, save_path=".")
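One caveat worth noting: the commit changes sglang_topk_softmax to require a renormalize argument, but calculate_diff still calls it with the old two-argument form, which is presumably why the calculate_diff call is left commented out in __main__. A hypothetical updated check matching the new signatures could look like the sketch below (calculate_diff_v2 is not part of this commit; it assumes it lives in bench_moe_topk_softmax.py alongside the functions above):

    # Hypothetical follow-up, not in this commit: compare the native PyTorch
    # path against the sgl_kernel path using the new three-argument signatures.
    def calculate_diff_v2(num_tokens, num_experts, topk, renormalize=False, device="xpu"):
        gating_output = torch.randn(
            (num_tokens, num_experts), device=device, dtype=torch.float32
        )
        w_native, i_native = native_topk_softmax(gating_output.clone(), topk, renormalize)
        w_sglang, i_sglang = sglang_topk_softmax(gating_output.clone(), topk, renormalize)

        # Weights should agree to within kernel tolerance; indices can differ
        # only when softmax values tie exactly.
        weights_close = torch.allclose(w_native, w_sglang.float(), atol=1e-3, rtol=1e-3)
        indices_match = torch.equal(i_native.long(), i_sglang.long())
        print(f"weights close: {weights_close}, indices match: {indices_match}")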
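The centralized parser named in the commit title lives in benchmark/utils.py, presumably the third changed file (its diff is not shown above; the remaining +239 added lines would account for it). Based only on how bench_moe_topk_softmax.py uses it, the interface it must expose looks roughly like this sketch; all flag names, defaults, and config attributes here are assumptions, not the committed code:

    # Sketch of the interface bench_moe_topk_softmax.py assumes from utils.py.
    # Everything here (flags, defaults, config fields) is an assumption.
    import argparse


    def parse_args():
        parser = argparse.ArgumentParser(description="Kernel benchmark configuration")
        parser.add_argument("--model", type=str, default=None,
                            help="optional model id to derive expert config from")
        parser.add_argument("--num-tokens", type=int, nargs="+",
                            default=[128, 512, 1024, 2048, 4096],
                            help="token counts to sweep")
        return parser.parse_args()


    def get_model_config(args):
        # With no model given, return empty lists so callers can apply their
        # own fallbacks, as in: params["num_experts"] or [64].
        if args.model is None:
            return {"num_experts": [], "top_k": []}
        # Otherwise read the expert count and router top-k from the model
        # config (attribute names vary by architecture; these are illustrative).
        from transformers import AutoConfig

        cfg = AutoConfig.from_pretrained(args.model)
        num_experts = getattr(cfg, "num_local_experts", None) or getattr(
            cfg, "n_routed_experts", 64
        )
        top_k = getattr(cfg, "num_experts_per_tok", 2)
        return {"num_experts": [num_experts], "top_k": [top_k]}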
