@@ -159,6 +159,9 @@ def __call__(self, graph: fx.Graph):
     6: MiB // 2,  # 512KB
     8: MiB // 2,  # 512KB
 }
+# opt for a more conservative default value
+# when world size is not in _FI_MAX_SIZES
+_DEFAULT_FI_MAX_SIZE = MiB // 2
 
 def call_trtllm_fused_allreduce_norm(
     allreduce_in: torch.Tensor,
@@ -173,12 +176,16 @@ def call_trtllm_fused_allreduce_norm(
     max_token_num: int,
     norm_out: Optional[torch.Tensor] = None,
 ) -> None:
-    use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
-        1] * allreduce_in.element_size() <= min(
-            _FI_MAX_SIZES[world_size],
-            max_token_num * allreduce_in.shape[0] *
-            allreduce_in.element_size(),
-        )
+
+    num_tokens, hidden_size = allreduce_in.shape
+    element_size = allreduce_in.element_size()
+    current_tensor_size = num_tokens * hidden_size * element_size
+    max_fusion_size = max_token_num * hidden_size * element_size
+    use_flashinfer = current_tensor_size <= min(
+        _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
+        max_fusion_size,
+    )
+
     if use_flashinfer:
         assert (_FI_WORKSPACE_TENSOR is not None
                 ), "Flashinfer must be enabled when using flashinfer"