from typing import Optional, Union

import torch
-from torch.distributed import ProcessGroup, all_gather, all_reduce
+from torch.distributed import ProcessGroup, all_reduce

from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
@@ -112,13 +112,21 @@ class EplbState:
    Expert load during this forward pass.
    We use the token count each expert processes as the load.

-    Shape: (num_moe_layers, num_local_physical_experts)
+    Shape: (num_moe_layers, num_physical_experts)
    """
    expert_load_window: torch.Tensor
    """
    A sliding window of expert load.

-    Shape: (window_size, num_moe_layers, num_local_physical_experts)
+    Shape: (window_size, num_moe_layers, num_physical_experts)
+
+    NOTE: The expert_load_view now records load for all physical experts
+    rather than just local experts. This ensures consistent load statistics
+    across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
+    With naive all-to-all, the recorded load is multiplied by dp_size, since
+    each DP rank contributes the same token set to the calculation.
+    See:
+    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
    """
    expert_load_window_step: int = 0
    """
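
To make the dp_size caveat in the NOTE above concrete: a uniform dp_size factor on the recorded load drops out of any ratio-based statistic. A minimal standalone sketch, with toy numbers and a mean/max balancedness definition that are purely illustrative (not taken from the PR):

import torch

# Toy values, purely illustrative: one MoE layer, four physical experts.
true_load = torch.tensor([[10., 30., 20., 40.]])
dp_size = 4  # assumed DP degree for the example

# Under naive all-to-all every DP rank contributes the same token set,
# so the all-reduced counts are dp_size times the true per-expert load.
recorded_load = true_load * dp_size

def balancedness(load: torch.Tensor) -> torch.Tensor:
    # Per layer: mean load / max load (1.0 means perfectly balanced).
    return load.mean(dim=-1) / load.max(dim=-1).values

# The constant factor cancels, so the statistic is unchanged.
assert torch.allclose(balancedness(true_load), balancedness(recorded_load))
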
@@ -232,14 +240,14 @@ def build(
        ).contiguous()

        expert_load_pass = torch.zeros(
-            (model.num_moe_layers, model.num_local_physical_experts),
+            (model.num_moe_layers, model.num_physical_experts),
            dtype=torch.int32,
            device=device,
        )
        expert_load_window_size = parallel_config.eplb_window_size
        expert_load_window = torch.zeros(
            (expert_load_window_size, model.num_moe_layers,
-             model.num_local_physical_experts),
+             model.num_physical_experts),
            dtype=torch.int32,
            device=device,
        )
@@ -353,18 +361,18 @@ def step(self,
            self.expert_load_pass.zero_()

        if log_stats:
-            # `num_tokens`: (num_moe_layers,)
-            num_tokens = self.expert_load_pass.sum(dim=-1)
+            # total_expert_load_pass: (num_moe_layers, num_physical_experts)
+            total_expert_load_pass = self.expert_load_pass.clone()

            # Collect load metrics from all ranks
            ep_group = get_ep_group().device_group
            assert ep_group is not None
-            num_tokens_list = [
-                torch.empty_like(num_tokens) for _ in range(ep_group.size())
-            ]
-            all_gather(num_tokens_list, num_tokens, group=ep_group)
-            # Stack to get (num_ranks, num_moe_layers)
-            num_tokens_per_rank = torch.stack(num_tokens_list).float()
+            all_reduce(total_expert_load_pass, group=ep_group)
+
+            # num_tokens_per_rank: (num_moe_layers, num_ranks)
+            num_tokens_per_rank = total_expert_load_pass.reshape(
+                total_expert_load_pass.shape[0], ep_group.size(),
+                -1).sum(dim=-1).float()

            # Compute balancedness ratio:
            # for each layer:
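
The reshape in the new num_tokens_per_rank computation relies on physical experts being laid out rank-major along the last dimension. A self-contained sketch of that index arithmetic, using assumed toy shapes rather than anything from the PR:

import torch

# Assumed toy shapes: 2 MoE layers, 2 EP ranks, 3 local experts per rank.
num_layers, num_ranks, local_experts = 2, 2, 3
total_load = torch.randint(0, 100, (num_layers, num_ranks * local_experts))

# Grouping the last dimension into (num_ranks, local_experts) and summing
# over the local-expert axis recovers the token count each rank processed.
num_tokens_per_rank = total_load.reshape(num_layers, num_ranks,
                                         -1).sum(dim=-1).float()

# Sanity check against direct per-rank slicing.
expected = torch.stack([
    total_load[:, r * local_experts:(r + 1) * local_experts].sum(dim=-1)
    for r in range(num_ranks)
], dim=-1).float()
assert torch.equal(num_tokens_per_rank, expected)
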
@@ -426,17 +434,7 @@ def rearrange(self,
                        "(profile)" if is_profile else "")

        if global_expert_load is None:
-            # This mapping is only used here, so we do not store it in the state
-            physical_expert_start = ep_rank * model.num_local_physical_experts
-            physical_expert_end = (physical_expert_start +
-                                   model.num_local_physical_experts)
-            # (num_moe_layers, num_local_physical_experts)
-            local_physical_to_logical_map = self.physical_to_logical_map[
-                :,
-                physical_expert_start:physical_expert_end,
-            ]
-
-            # Map the local physical expert load to global logical experts
+            # Map the physical expert load to global logical experts
            logical_expert_load_window = torch.zeros(
                self.expert_load_window_size,
                model.num_moe_layers,
@@ -446,7 +444,7 @@ def rearrange(self,
            )
            logical_expert_load_window.scatter_add_(
                dim=-1,
-                index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+                index=self.physical_to_logical_map.unsqueeze(0).expand_as(
                    self.expert_load_window).long(),
                src=self.expert_load_window,
            )
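
With the per-rank slicing removed, the full physical-to-logical map drives a single scatter_add_ over the whole load window. A minimal sketch of that aggregation under assumed toy shapes (one window step, one layer, four physical experts backed by three logical experts, with logical expert 0 replicated twice):

import torch

# Assumed toy mapping and load window, purely illustrative.
physical_to_logical_map = torch.tensor([[0, 1, 2, 0]])  # (layers, physical)
expert_load_window = torch.tensor([[[5, 7, 9, 11]]])    # (window, layers, physical)

logical_load = torch.zeros(1, 1, 3, dtype=expert_load_window.dtype)
# Each physical expert's load is added into its logical expert's bucket,
# so replicas of the same logical expert accumulate.
logical_load.scatter_add_(
    dim=-1,
    index=physical_to_logical_map.unsqueeze(0).expand_as(
        expert_load_window).long(),
    src=expert_load_window,
)
print(logical_load)  # tensor([[[16, 7, 9]]]): 5 + 11 for logical expert 0
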
@@ -618,4 +616,4 @@ def _node_count_with_rank_mapping(
            if is_same_node and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id

-    return next_node_id
+    return next_node_id