3030from torch .distributed import ProcessGroup
3131
3232from vllm .distributed .device_communicators .custom_all_reduce import CustomAllreduce
33+ from vllm .distributed .device_communicators .flashinfer_all_reduce import (
34+ FlashInferAllReduce ,
35+ )
3336from vllm .distributed .device_communicators .pynccl import (
3437 PyNcclCommunicator ,
3538 register_nccl_symmetric_ops ,
4447logger = init_logger (__name__ )
4548
4649# Default sequence lengths to benchmark
47- DEFAULT_SEQUENCE_LENGTHS = [128 , 512 , 1024 , 2048 , 4096 , 8192 ]
50+ DEFAULT_SEQUENCE_LENGTHS = [16 , 64 , 128 , 512 , 1024 , 2048 , 4096 , 8192 ]
4851
4952# Fixed hidden size and dtype for all benchmarks
5053HIDDEN_SIZE = 8192
@@ -81,6 +84,7 @@ def __init__(
8184 self .symm_mem_comm = None
8285 self .symm_mem_comm_multimem = None
8386 self .symm_mem_comm_two_shot = None
87+ self .fi_ar_comm = None
8488
8589 self ._init_communicators ()
8690
@@ -161,6 +165,22 @@ def _init_communicators(self):
161165 )
162166 self .symm_mem_comm_two_shot = None
163167
168+ try :
169+ self .fi_ar_comm = FlashInferAllReduce (
170+ group = self .cpu_group ,
171+ device = self .device ,
172+ )
173+ if not self .fi_ar_comm .disabled :
174+ logger .info ("Rank %s: FlashInferAllReduce initialized" , self .rank )
175+ else :
176+ logger .info ("Rank %s: FlashInferAllReduce disabled" , self .rank )
177+ self .fi_ar_comm = None
178+ except Exception as e :
179+ logger .warning (
180+ "Rank %s: Failed to initialize FlashInferAllReduce: %s" , self .rank , e
181+ )
182+ self .fi_ar_comm = None
183+
164184 def benchmark_allreduce (
165185 self , sequence_length : int , num_warmup : int , num_trials : int
166186 ) -> dict [str , float ]:
@@ -180,7 +200,8 @@ def benchmark_allreduce(
180200 lambda t , c = comm : c .custom_all_reduce (t ),
181201 lambda t , c = comm : c .should_custom_ar (t ),
182202 comm .capture (),
183- "1stage" , # env variable value
203+ {"VLLM_CUSTOM_ALLREDUCE_ALGO" : "1stage" },
204+ None , # no destroy function
184205 )
185206 )
186207 # CustomAllreduce two-shot
@@ -190,7 +211,8 @@ def benchmark_allreduce(
190211 lambda t , c = comm : c .custom_all_reduce (t ),
191212 lambda t , c = comm : c .should_custom_ar (t ),
192213 comm .capture (),
193- "2stage" , # env variable value
214+ {"VLLM_CUSTOM_ALLREDUCE_ALGO" : "2stage" },
215+ None , # no destroy function
194216 )
195217 )
196218
@@ -202,7 +224,8 @@ def benchmark_allreduce(
202224 lambda t , c = comm : c .all_reduce (t ),
203225 lambda t : True , # Always available if initialized
204226 nullcontext (),
205- None , # no env variable needed
227+ {}, # no env variable needed
228+ None , # no destroy function
206229 )
207230 )
208231 communicators .append (
@@ -211,7 +234,8 @@ def benchmark_allreduce(
211234 lambda t : torch .ops .vllm .all_reduce_symmetric_with_copy (t ),
212235 lambda t : True , # Always available if initialized
213236 nullcontext (),
214- None , # no env variable needed
237+ {}, # no env variable needed
238+ None , # no destroy function
215239 )
216240 )
217241
@@ -223,7 +247,8 @@ def benchmark_allreduce(
223247 lambda t , c = comm : c .all_reduce (t ),
224248 lambda t , c = comm : c .should_use_symm_mem (t ),
225249 nullcontext (),
226- None , # no env variable needed
250+ {}, # no env variable needed
251+ None , # no destroy function
227252 )
228253 )
229254
@@ -235,29 +260,67 @@ def benchmark_allreduce(
235260 lambda t , c = comm : c .all_reduce (t ),
236261 lambda t , c = comm : c .should_use_symm_mem (t ),
237262 nullcontext (),
238- None , # no env variable needed
263+ {}, # no env variable needed
264+ None , # no destroy function needed
239265 )
240266 )
241267
242- # Benchmark each communicator
243- for name , allreduce_fn , should_use_fn , context , env_var in communicators :
244- # Set environment variable if needed
245- if env_var is not None :
246- os .environ ["VLLM_CUSTOM_ALLREDUCE_ALGO" ] = env_var
247- else :
248- # Clear the environment variable to avoid interference
249- os .environ .pop ("VLLM_CUSTOM_ALLREDUCE_ALGO" , None )
250-
251- latency = self .benchmark_allreduce_single (
252- sequence_length ,
253- allreduce_fn ,
254- should_use_fn ,
255- context ,
256- num_warmup ,
257- num_trials ,
268+ if self .fi_ar_comm is not None :
269+ comm = self .fi_ar_comm
270+ communicators .append (
271+ (
272+ "flashinfer_trtllm" ,
273+ lambda t , c = comm : c .all_reduce (t ),
274+ lambda t , c = comm : c .should_use_fi_ar (t ),
275+ nullcontext (),
276+ {"VLLM_FLASHINFER_ALLREDUCE_BACKEND" : "trtllm" },
277+ lambda c = comm : c .destroy (),
278+ )
258279 )
259- if latency is not None :
260- results [name ] = latency
280+ communicators .append (
281+ (
282+ "flashinfer_mnnvl" ,
283+ lambda t , c = comm : c .all_reduce (t ),
284+ lambda t , c = comm : c .should_use_fi_ar (t ),
285+ nullcontext (),
286+ {"VLLM_FLASHINFER_ALLREDUCE_BACKEND" : "mnnvl" },
287+ lambda c = comm : c .destroy (),
288+ )
289+ )
290+
# Benchmark each communicator.
#
# Every entry is a 6-tuple:
#   (name, allreduce_fn, should_use_fn, context, env_dict, destroy_fn)
# env_dict maps environment-variable names to the values that must be set
# while this entry runs ({} when none are needed); destroy_fn is an optional
# zero-argument teardown callable (None when the entry needs no teardown).
for (
    name,
    allreduce_fn,
    should_use_fn,
    context,
    env_dict,
    destroy_fn,
) in communicators:
    # Remember the current value (None == unset) of every variable we are
    # about to override so it can be put back afterwards.
    saved_env = {key: os.environ.get(key) for key in env_dict}
    os.environ.update(env_dict)
    try:
        latency = self.benchmark_allreduce_single(
            sequence_length,
            allreduce_fn,
            should_use_fn,
            context,
            num_warmup,
            num_trials,
        )
        if latency is not None:
            results[name] = latency
    finally:
        try:
            # Teardown runs first, while this entry's overrides are still
            # in place.
            if destroy_fn is not None:
                destroy_fn()
        finally:
            # Restore the environment even if the benchmark or the teardown
            # raised; otherwise this entry's overrides would leak into the
            # following entries and the rest of the process.
            for key, original_value in saved_env.items():
                if original_value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = original_value
261324
262325 return results
263326
0 commit comments