 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size_triton,
+    moe_align_block_size,
 )
 from vllm.triton_utils import triton
 
@@ -21,60 +20,6 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
     )
 
 
-def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
-    """
-    Verifies vllm vs. Triton
-    """
-    topk_ids = get_topk_ids(num_tokens, num_experts, topk)
-
-    # 1. malloc space for triton and vllm
-    # malloc enough space (max_num_tokens_padded) for the sorted ids
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids_triton = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
-    )
-    expert_ids_triton = torch.empty(
-        (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
-    )
-    num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
-
-    sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
-    expert_ids_vllm = torch.empty_like(expert_ids_triton)
-    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
-
-    # 2. run implementations
-    moe_align_block_size_triton(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids_triton,
-        expert_ids_triton,
-        num_tokens_post_pad_triton,
-    )
-
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids_vllm,
-        expert_ids_vllm,
-        num_tokens_post_pad_vllm,
-    )
-    print(f"✅ VLLM implementation works with {num_experts} experts!")
-
-    # 3. compare results
-    if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
-        num_tokens_post_pad_triton, num_tokens_post_pad_vllm
-    ):
-        print("✅ Triton and VLLM implementations match.")
-    else:
-        print("❌ Triton and VLLM implementations DO NOT match.")
-        print("Triton expert_ids:", expert_ids_triton)
-        print("VLLM expert_ids:", expert_ids_vllm)
-        print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
-        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
-
-
 # test configurations
 num_tokens_range = [1, 16, 256, 4096]
 num_experts_range = [16, 64, 224, 256, 280, 512]
@@ -87,8 +32,8 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
         x_names=["num_tokens", "num_experts", "topk"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["vllm", "triton"],  # "triton"
-        line_names=["VLLM", "Triton"],  # "Triton"
+        line_vals=["vllm"],
+        line_names=["vLLM"],
         plot_name="moe-align-block-size-performance",
         args={},
     )
@@ -98,36 +43,11 @@ def benchmark(num_tokens, num_experts, topk, provider):
     block_size = 256
     topk_ids = get_topk_ids(num_tokens, num_experts, topk)
 
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
-    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
-    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
-
     quantiles = [0.5, 0.2, 0.8]
 
     if provider == "vllm":
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: ops.moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
-            quantiles=quantiles,
-        )
-    elif provider == "triton":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: moe_align_block_size_triton(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
+            lambda: moe_align_block_size(topk_ids, block_size, num_experts),
             quantiles=quantiles,
         )
@@ -151,6 +71,4 @@ def benchmark(num_tokens, num_experts, topk, provider):
     )
     args = parser.parse_args()
 
-    print("Running correctness check...")
-    check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
     benchmark.run(print_data=True, show_plots=True)