5
5
6
6
import torch
7
7
8
- from vllm import _custom_ops as ops
9
8
from vllm .model_executor .layers .fused_moe .moe_align_block_size import (
10
- moe_align_block_size_triton ,
9
+ moe_align_block_size ,
11
10
)
12
11
from vllm .triton_utils import triton
13
12
@@ -21,60 +20,6 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
21
20
)
22
21
23
22
24
- def check_correctness (num_tokens , num_experts = 256 , block_size = 256 , topk = 8 ):
25
- """
26
- Verifies vllm vs. Triton
27
- """
28
- topk_ids = get_topk_ids (num_tokens , num_experts , topk )
29
-
30
- # 1. malloc space for triton and vllm
31
- # malloc enough space (max_num_tokens_padded) for the sorted ids
32
- max_num_tokens_padded = topk_ids .numel () + num_experts * (block_size - 1 )
33
- sorted_ids_triton = torch .empty (
34
- (max_num_tokens_padded ,), dtype = torch .int32 , device = "cuda"
35
- )
36
- expert_ids_triton = torch .empty (
37
- (max_num_tokens_padded // block_size ,), dtype = torch .int32 , device = "cuda"
38
- )
39
- num_tokens_post_pad_triton = torch .empty ((1 ,), dtype = torch .int32 , device = "cuda" )
40
-
41
- sorted_ids_vllm = torch .empty_like (sorted_ids_triton )
42
- expert_ids_vllm = torch .empty_like (expert_ids_triton )
43
- num_tokens_post_pad_vllm = torch .empty_like (num_tokens_post_pad_triton )
44
-
45
- # 2. run implementations
46
- moe_align_block_size_triton (
47
- topk_ids ,
48
- num_experts ,
49
- block_size ,
50
- sorted_ids_triton ,
51
- expert_ids_triton ,
52
- num_tokens_post_pad_triton ,
53
- )
54
-
55
- ops .moe_align_block_size (
56
- topk_ids ,
57
- num_experts ,
58
- block_size ,
59
- sorted_ids_vllm ,
60
- expert_ids_vllm ,
61
- num_tokens_post_pad_vllm ,
62
- )
63
- print (f"✅ VLLM implementation works with { num_experts } experts!" )
64
-
65
- # 3. compare results
66
- if torch .allclose (expert_ids_triton , expert_ids_vllm ) and torch .allclose (
67
- num_tokens_post_pad_triton , num_tokens_post_pad_vllm
68
- ):
69
- print ("✅ Triton and VLLM implementations match." )
70
- else :
71
- print ("❌ Triton and VLLM implementations DO NOT match." )
72
- print ("Triton expert_ids:" , expert_ids_triton )
73
- print ("VLLM expert_ids:" , expert_ids_vllm )
74
- print ("Triton num_tokens_post_pad:" , num_tokens_post_pad_triton )
75
- print ("VLLM num_tokens_post_pad:" , num_tokens_post_pad_vllm )
76
-
77
-
78
23
# test configurations
79
24
num_tokens_range = [1 , 16 , 256 , 4096 ]
80
25
num_experts_range = [16 , 64 , 224 , 256 , 280 , 512 ]
@@ -87,8 +32,8 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
87
32
x_names = ["num_tokens" , "num_experts" , "topk" ],
88
33
x_vals = configs ,
89
34
line_arg = "provider" ,
90
- line_vals = ["vllm" , "triton" ], # "triton"
91
- line_names = ["VLLM" , "Triton" ], # "Triton"
35
+ line_vals = ["vllm" ],
36
+ line_names = ["vLLM" ],
92
37
plot_name = "moe-align-block-size-performance" ,
93
38
args = {},
94
39
)
@@ -98,36 +43,11 @@ def benchmark(num_tokens, num_experts, topk, provider):
98
43
block_size = 256
99
44
topk_ids = get_topk_ids (num_tokens , num_experts , topk )
100
45
101
- max_num_tokens_padded = topk_ids .numel () + num_experts * (block_size - 1 )
102
- sorted_ids = torch .empty ((max_num_tokens_padded ,), dtype = torch .int32 , device = "cuda" )
103
- max_num_m_blocks = max_num_tokens_padded // block_size
104
- expert_ids = torch .empty ((max_num_m_blocks ,), dtype = torch .int32 , device = "cuda" )
105
- num_tokens_post_pad = torch .empty ((1 ,), dtype = torch .int32 , device = "cuda" )
106
-
107
46
quantiles = [0.5 , 0.2 , 0.8 ]
108
47
109
48
if provider == "vllm" :
110
49
ms , min_ms , max_ms = triton .testing .do_bench (
111
- lambda : ops .moe_align_block_size (
112
- topk_ids ,
113
- num_experts ,
114
- block_size ,
115
- sorted_ids .clone (),
116
- expert_ids .clone (),
117
- num_tokens_post_pad .clone (),
118
- ),
119
- quantiles = quantiles ,
120
- )
121
- elif provider == "triton" :
122
- ms , min_ms , max_ms = triton .testing .do_bench (
123
- lambda : moe_align_block_size_triton (
124
- topk_ids ,
125
- num_experts ,
126
- block_size ,
127
- sorted_ids .clone (),
128
- expert_ids .clone (),
129
- num_tokens_post_pad .clone (),
130
- ),
50
+ lambda : moe_align_block_size (topk_ids , block_size , num_experts ),
131
51
quantiles = quantiles ,
132
52
)
133
53
@@ -151,6 +71,4 @@ def benchmark(num_tokens, num_experts, topk, provider):
151
71
)
152
72
args = parser .parse_args ()
153
73
154
- print ("Running correctness check..." )
155
- check_correctness (num_tokens = 1024 , num_experts = args .num_experts , topk = args .topk )
156
74
benchmark .run (print_data = True , show_plots = True )
0 commit comments