Skip to content

Commit 71dfce6

Browse files
hjjq and wzhao18 authored
[Kernel] Refactor FlashInfer allreduce for mnnvl backend (vllm-project#34109)
Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com> Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
1 parent 2aa4140 commit 71dfce6

File tree

7 files changed

+592
-179
lines changed

7 files changed

+592
-179
lines changed

benchmarks/kernels/benchmark_device_communicators.py

Lines changed: 88 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,9 @@
3030
from torch.distributed import ProcessGroup
3131

3232
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
33+
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
34+
FlashInferAllReduce,
35+
)
3336
from vllm.distributed.device_communicators.pynccl import (
3437
PyNcclCommunicator,
3538
register_nccl_symmetric_ops,
@@ -44,7 +47,7 @@
4447
logger = init_logger(__name__)
4548

4649
# Default sequence lengths to benchmark
47-
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
50+
DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192]
4851

4952
# Fixed hidden size and dtype for all benchmarks
5053
HIDDEN_SIZE = 8192
@@ -81,6 +84,7 @@ def __init__(
8184
self.symm_mem_comm = None
8285
self.symm_mem_comm_multimem = None
8386
self.symm_mem_comm_two_shot = None
87+
self.fi_ar_comm = None
8488

8589
self._init_communicators()
8690

@@ -161,6 +165,22 @@ def _init_communicators(self):
161165
)
162166
self.symm_mem_comm_two_shot = None
163167

168+
try:
169+
self.fi_ar_comm = FlashInferAllReduce(
170+
group=self.cpu_group,
171+
device=self.device,
172+
)
173+
if not self.fi_ar_comm.disabled:
174+
logger.info("Rank %s: FlashInferAllReduce initialized", self.rank)
175+
else:
176+
logger.info("Rank %s: FlashInferAllReduce disabled", self.rank)
177+
self.fi_ar_comm = None
178+
except Exception as e:
179+
logger.warning(
180+
"Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e
181+
)
182+
self.fi_ar_comm = None
183+
164184
def benchmark_allreduce(
165185
self, sequence_length: int, num_warmup: int, num_trials: int
166186
) -> dict[str, float]:
@@ -180,7 +200,8 @@ def benchmark_allreduce(
180200
lambda t, c=comm: c.custom_all_reduce(t),
181201
lambda t, c=comm: c.should_custom_ar(t),
182202
comm.capture(),
183-
"1stage", # env variable value
203+
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"},
204+
None, # no destroy function
184205
)
185206
)
186207
# CustomAllreduce two-shot
@@ -190,7 +211,8 @@ def benchmark_allreduce(
190211
lambda t, c=comm: c.custom_all_reduce(t),
191212
lambda t, c=comm: c.should_custom_ar(t),
192213
comm.capture(),
193-
"2stage", # env variable value
214+
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"},
215+
None, # no destroy function
194216
)
195217
)
196218

@@ -202,7 +224,8 @@ def benchmark_allreduce(
202224
lambda t, c=comm: c.all_reduce(t),
203225
lambda t: True, # Always available if initialized
204226
nullcontext(),
205-
None, # no env variable needed
227+
{}, # no env variable needed
228+
None, # no destroy function
206229
)
207230
)
208231
communicators.append(
@@ -211,7 +234,8 @@ def benchmark_allreduce(
211234
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
212235
lambda t: True, # Always available if initialized
213236
nullcontext(),
214-
None, # no env variable needed
237+
{}, # no env variable needed
238+
None, # no destroy function
215239
)
216240
)
217241

@@ -223,7 +247,8 @@ def benchmark_allreduce(
223247
lambda t, c=comm: c.all_reduce(t),
224248
lambda t, c=comm: c.should_use_symm_mem(t),
225249
nullcontext(),
226-
None, # no env variable needed
250+
{}, # no env variable needed
251+
None, # no destroy function
227252
)
228253
)
229254

@@ -235,29 +260,67 @@ def benchmark_allreduce(
235260
lambda t, c=comm: c.all_reduce(t),
236261
lambda t, c=comm: c.should_use_symm_mem(t),
237262
nullcontext(),
238-
None, # no env variable needed
263+
{}, # no env variable needed
264+
None, # no destroy function needed
239265
)
240266
)
241267

242-
# Benchmark each communicator
243-
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
244-
# Set environment variable if needed
245-
if env_var is not None:
246-
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
247-
else:
248-
# Clear the environment variable to avoid interference
249-
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
250-
251-
latency = self.benchmark_allreduce_single(
252-
sequence_length,
253-
allreduce_fn,
254-
should_use_fn,
255-
context,
256-
num_warmup,
257-
num_trials,
268+
if self.fi_ar_comm is not None:
269+
comm = self.fi_ar_comm
270+
communicators.append(
271+
(
272+
"flashinfer_trtllm",
273+
lambda t, c=comm: c.all_reduce(t),
274+
lambda t, c=comm: c.should_use_fi_ar(t),
275+
nullcontext(),
276+
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"},
277+
lambda c=comm: c.destroy(),
278+
)
258279
)
259-
if latency is not None:
260-
results[name] = latency
280+
communicators.append(
281+
(
282+
"flashinfer_mnnvl",
283+
lambda t, c=comm: c.all_reduce(t),
284+
lambda t, c=comm: c.should_use_fi_ar(t),
285+
nullcontext(),
286+
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"},
287+
lambda c=comm: c.destroy(),
288+
)
289+
)
290+
291+
# Benchmark each communicator
292+
for (
293+
name,
294+
allreduce_fn,
295+
should_use_fn,
296+
context,
297+
env_dict,
298+
destroy_fn,
299+
) in communicators:
300+
# Save original values and apply new environment variables
301+
saved_env = {key: os.environ.get(key) for key in env_dict}
302+
for key, value in env_dict.items():
303+
os.environ[key] = value
304+
try:
305+
latency = self.benchmark_allreduce_single(
306+
sequence_length,
307+
allreduce_fn,
308+
should_use_fn,
309+
context,
310+
num_warmup,
311+
num_trials,
312+
)
313+
if latency is not None:
314+
results[name] = latency
315+
finally:
316+
if destroy_fn is not None:
317+
destroy_fn()
318+
# Restore environment variables to their original state
319+
for key, original_value in saved_env.items():
320+
if original_value is None:
321+
os.environ.pop(key, None)
322+
else:
323+
os.environ[key] = original_value
261324

262325
return results
263326

0 commit comments

Comments
 (0)