
Commit fa3d7a8

yewentao256 authored and epwalsh committed
[Refactor] Remove moe_align_block_size_triton (vllm-project#21335)
Signed-off-by: yewentao256 <[email protected]>
1 parent 51e0024 commit fa3d7a8

File tree

2 files changed: +6 -224 lines

benchmarks/kernels/benchmark_moe_align_block_size.py

Lines changed: 4 additions & 86 deletions
@@ -5,9 +5,8 @@
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size_triton,
+    moe_align_block_size,
 )
 from vllm.triton_utils import triton
 
@@ -21,60 +20,6 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
     )
 
 
-def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
-    """
-    Verifies vllm vs. Triton
-    """
-    topk_ids = get_topk_ids(num_tokens, num_experts, topk)
-
-    # 1. malloc space for triton and vllm
-    # malloc enough space (max_num_tokens_padded) for the sorted ids
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids_triton = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
-    )
-    expert_ids_triton = torch.empty(
-        (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
-    )
-    num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
-
-    sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
-    expert_ids_vllm = torch.empty_like(expert_ids_triton)
-    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
-
-    # 2. run implementations
-    moe_align_block_size_triton(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids_triton,
-        expert_ids_triton,
-        num_tokens_post_pad_triton,
-    )
-
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids_vllm,
-        expert_ids_vllm,
-        num_tokens_post_pad_vllm,
-    )
-    print(f"✅ VLLM implementation works with {num_experts} experts!")
-
-    # 3. compare results
-    if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
-        num_tokens_post_pad_triton, num_tokens_post_pad_vllm
-    ):
-        print("✅ Triton and VLLM implementations match.")
-    else:
-        print("❌ Triton and VLLM implementations DO NOT match.")
-        print("Triton expert_ids:", expert_ids_triton)
-        print("VLLM expert_ids:", expert_ids_vllm)
-        print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
-        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
-
-
 # test configurations
 num_tokens_range = [1, 16, 256, 4096]
 num_experts_range = [16, 64, 224, 256, 280, 512]
@@ -87,8 +32,8 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
         x_names=["num_tokens", "num_experts", "topk"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["vllm", "triton"],  # "triton"
-        line_names=["VLLM", "Triton"],  # "Triton"
+        line_vals=["vllm"],
+        line_names=["vLLM"],
         plot_name="moe-align-block-size-performance",
         args={},
     )
@@ -98,36 +43,11 @@ def benchmark(num_tokens, num_experts, topk, provider):
     block_size = 256
     topk_ids = get_topk_ids(num_tokens, num_experts, topk)
 
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
-    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
-    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
-
     quantiles = [0.5, 0.2, 0.8]
 
     if provider == "vllm":
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: ops.moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
-            quantiles=quantiles,
-        )
-    elif provider == "triton":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: moe_align_block_size_triton(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
+            lambda: moe_align_block_size(topk_ids, block_size, num_experts),
             quantiles=quantiles,
         )
 
@@ -151,6 +71,4 @@ def benchmark(num_tokens, num_experts, topk, provider):
     )
     args = parser.parse_args()
 
-    print("Running correctness check...")
-    check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
    benchmark.run(print_data=True, show_plots=True)
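
With the correctness check and the buffer bookkeeping gone, the benchmark exercises a single entry point. A minimal usage sketch of the new call path follows, assuming a CUDA device; the benchmark's lambda discards the return value, so the (sorted_ids, expert_ids, num_tokens_post_pad) return convention shown here is an assumption, not something this diff confirms:

import torch

from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size,
)

num_tokens, num_experts, topk, block_size = 16, 64, 8, 256
# One row of top-k expert ids per token, as get_topk_ids produces above.
topk_ids = torch.randint(
    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
)

# Assumed return convention: the wrapper allocates and returns its outputs,
# replacing the old out-parameter style of ops.moe_align_block_size.
sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, block_size, num_experts
)
print(num_tokens_post_pad.item())  # total token slots after block padding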

vllm/model_executor/layers/fused_moe/moe_align_block_size.py

Lines changed: 2 additions & 138 deletions
@@ -5,144 +5,8 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, round_up
-
-
-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    start_idx = pid * tokens_per_thread
-
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
-                                         numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-# Triton implementation based on:
-# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts, )
-    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
-                              dtype=torch.int32,
-                              device=topk_ids.device)
-    cumsum = torch.zeros((num_experts + 1, ),
-                         dtype=torch.int32,
-                         device=topk_ids.device)
-    tokens_per_thread = cdiv(numel, num_experts)
-    sorted_token_ids.fill_(numel)
-    expert_ids.zero_()
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1, )](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
+from vllm.triton_utils import triton
+from vllm.utils import round_up
 
 
 def moe_align_block_size(
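
Only the thin `moe_align_block_size` wrapper, backed by the CUDA op in `_custom_ops`, survives in this file. For readers who want the contract the deleted Triton stages implemented, here is a minimal eager-PyTorch sketch of the same computation (an illustration of the algorithm only; `moe_align_reference` is a hypothetical name, not a vLLM API): histogram tokens per expert, pad each expert's count to a multiple of `block_size`, then scatter token indices into per-expert segments.

import torch


def moe_align_reference(topk_ids, num_experts, block_size):
    numel = topk_ids.numel()
    flat = topk_ids.flatten().long()

    # Stages 1-2 of the deleted kernels: per-expert token histogram.
    counts = torch.bincount(flat, minlength=num_experts)

    # Stage 3: round each expert's count up to a multiple of block_size
    # and take the running sum to get each expert's segment offset.
    padded = (counts + block_size - 1) // block_size * block_size
    cumsum = torch.cat([padded.new_zeros(1), padded.cumsum(0)])
    num_tokens_post_pad = int(cumsum[-1])

    # Stage 4: every block in an expert's segment is owned by that expert;
    # token indices are scattered into the segments. Unfilled padding slots
    # keep the sentinel `numel`, matching sorted_token_ids.fill_(numel) above.
    expert_ids = torch.repeat_interleave(
        torch.arange(num_experts), padded // block_size
    )
    sorted_ids = torch.full((num_tokens_post_pad,), numel, dtype=torch.int32)
    fill = cumsum[:-1].clone()
    for i, e in enumerate(flat.tolist()):
        sorted_ids[fill[e]] = i
        fill[e] += 1
    return sorted_ids, expert_ids, torch.tensor([num_tokens_post_pad])

The scalar scatter loop here is what the deleted stage-4 kernel parallelized across `num_experts` Triton programs, and what the remaining CUDA op does on device.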
