From 3fb686fa680174448d7527a3fc761c1549250b3a Mon Sep 17 00:00:00 2001
From: XIn Li
Date: Mon, 21 Jul 2025 10:55:10 -0700
Subject: [PATCH 1/3] fix flashinfer enable/disable calculation

Signed-off-by: XIn Li
---
 vllm/compilation/collective_fusion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index a8b00aaf0842..e33099c39796 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -176,7 +176,7 @@ def call_trtllm_fused_allreduce_norm(
         use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
             1] * allreduce_in.element_size() <= min(
                 _FI_MAX_SIZES[world_size],
-                max_token_num * allreduce_in.shape[0] *
+                max_token_num * allreduce_in.shape[1] *
                 allreduce_in.element_size(),
             )
         if use_flashinfer:

From bf3fa63422e8d0a8520ac6a4c751fda84775e47e Mon Sep 17 00:00:00 2001
From: XIn Li
Date: Mon, 21 Jul 2025 13:38:08 -0700
Subject: [PATCH 2/3] address copilot feedback

Signed-off-by: XIn Li
---
 vllm/compilation/collective_fusion.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index e33099c39796..6c83a22d834c 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -173,12 +173,14 @@ def call_trtllm_fused_allreduce_norm(
         max_token_num: int,
         norm_out: Optional[torch.Tensor] = None,
     ) -> None:
-        use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
-            1] * allreduce_in.element_size() <= min(
-                _FI_MAX_SIZES[world_size],
-                max_token_num * allreduce_in.shape[1] *
-                allreduce_in.element_size(),
-            )
+
+        num_tokens, hidden_size = allreduce_in.shape
+        element_size = allreduce_in.element_size()
+        current_tensor_size = num_tokens * hidden_size * element_size
+        max_fusion_size = max_token_num * hidden_size * element_size
+        use_flashinfer = current_tensor_size <= min(_FI_MAX_SIZES[world_size],
+                                                    max_fusion_size)
+
         if use_flashinfer:
             assert (_FI_WORKSPACE_TENSOR is not None
                     ), "Flashinfer must be enabled when using flashinfer"

From 8195e6c5b1c0e37cba290a5e7cccca3780d8dde6 Mon Sep 17 00:00:00 2001
From: XIn Li
Date: Mon, 21 Jul 2025 14:28:27 -0700
Subject: [PATCH 3/3] address review feedback when world size is uncommon

Signed-off-by: XIn Li
---
 vllm/compilation/collective_fusion.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 6c83a22d834c..0e7961841bd3 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -159,6 +159,9 @@ def __call__(self, graph: fx.Graph):
         6: MiB // 2,  # 512KB
         8: MiB // 2,  # 512KB
     }
+    # opt for a more conservative default value
+    # when world size is not in _FI_MAX_SIZES
+    _DEFAULT_FI_MAX_SIZE = MiB // 2
 
     def call_trtllm_fused_allreduce_norm(
         allreduce_in: torch.Tensor,
@@ -178,8 +181,10 @@ def call_trtllm_fused_allreduce_norm(
         element_size = allreduce_in.element_size()
         current_tensor_size = num_tokens * hidden_size * element_size
         max_fusion_size = max_token_num * hidden_size * element_size
-        use_flashinfer = current_tensor_size <= min(_FI_MAX_SIZES[world_size],
-                                                    max_fusion_size)
+        use_flashinfer = current_tensor_size <= min(
+            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
+            max_fusion_size,
+        )
 
         if use_flashinfer:
             assert (_FI_WORKSPACE_TENSOR is not None
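
The net effect of the three patches is a single size gate: take FlashInfer's
fused allreduce + RMSNorm path only when the current tensor fits both the
per-world-size cap and the preallocated workspace. Below is a minimal,
self-contained sketch of that predicate. should_use_flashinfer is a
hypothetical helper named for this note (the real logic is inline in
call_trtllm_fused_allreduce_norm), and the example sizes are illustrative.

    # Per-world-size caps, mirroring the diff context above.
    MiB = 1024 * 1024
    _FI_MAX_SIZES = {
        2: 64 * MiB,  # 64MB
        4: MiB,  # 1MB
        6: MiB // 2,  # 512KB
        8: MiB // 2,  # 512KB
    }
    # Conservative fallback for world sizes missing from the table (PATCH 3/3).
    _DEFAULT_FI_MAX_SIZE = MiB // 2

    def should_use_flashinfer(num_tokens: int, hidden_size: int,
                              element_size: int, world_size: int,
                              max_token_num: int) -> bool:
        """Hypothetical helper mirroring the inline gate in the patches."""
        # Bytes in the tensor being allreduced right now.
        current_tensor_size = num_tokens * hidden_size * element_size
        # Largest tensor the workspace was sized for; PATCH 1/3 fixed this
        # term to scale with hidden_size (shape[1]), not num_tokens (shape[0]).
        max_fusion_size = max_token_num * hidden_size * element_size
        return current_tensor_size <= min(
            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
            max_fusion_size,
        )

    # 128 tokens x 4096 hidden x 2-byte dtype = 1MB exactly:
    assert should_use_flashinfer(128, 4096, 2, world_size=4, max_token_num=1024)
    # An uncommon world size (16) falls back to the 512KB cap, so 1MB is rejected:
    assert not should_use_flashinfer(128, 4096, 2, world_size=16, max_token_num=1024)

Switching from _FI_MAX_SIZES[world_size] to .get(world_size,
_DEFAULT_FI_MAX_SIZE) trades a KeyError on uncommon world sizes for the most
conservative cap in the table; the fallback can only disable the fusion, never
enable it for a tensor a larger cap would have rejected.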