@@ -69,7 +69,8 @@ def token_dispatch(self,
                        dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False):
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         raise NotImplementedError("Dispatch function not implemented.")
 
     @abstractmethod
@@ -156,21 +157,20 @@ def get_dispatch_mc2_kwargs(
         kwargs_mc2.update(stage1_kwargs)
         return kwargs_mc2
 
-    def token_dispatch(
-        self,
-        hidden_states: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        expert_map: Optional[torch.Tensor] = None,
-        log2phy: Optional[torch.Tensor] = None,
-        global_redundant_expert_num: int = 0,
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
-        mc2_mask: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        with_quant: bool = False,
-    ):
+    def token_dispatch(self,
+                       hidden_states: torch.Tensor,
+                       topk_weights: torch.Tensor,
+                       topk_ids: torch.Tensor,
+                       expert_map: Optional[torch.Tensor] = None,
+                       log2phy: Optional[torch.Tensor] = None,
+                       global_redundant_expert_num: int = 0,
+                       shared_experts: Optional[Any] = None,
+                       quantized_x_for_share: Optional[Any] = None,
+                       dynamic_scale_for_share: Optional[Any] = None,
+                       mc2_mask: Optional[torch.Tensor] = None,
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         self.with_quant = with_quant
 
         # Apply log2phy if needed
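For readers outside the MoE code: the `log2phy` argument accepted by all of these dispatchers maps logical expert ids to physical expert slots when redundant experts are deployed, which is the EPLB setting this change extends. A minimal sketch of such a remap, assuming it is a plain index gather; the table contents and shapes below are made up for illustration and are not taken from this file:

```python
import torch

# Illustrative only: a 1-D table mapping each logical expert id to one of its
# physical replicas; the real table may be per-rank and chosen dynamically.
log2phy = torch.tensor([0, 1, 2, 4])        # logical expert 3 -> physical 4
topk_ids = torch.tensor([[0, 3], [1, 2]])   # router output (logical ids)
physical_ids = log2phy[topk_ids]            # simple gather
print(physical_ids)                         # tensor([[0, 4], [1, 2]])
```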
@@ -221,8 +221,10 @@ def token_dispatch(
             "expand_scales": expand_scales
         }
 
+        group_list_type = 1 if dynamic_eplb else 0
+
         return {
-            "group_list_type": 0,
+            "group_list_type": group_list_type,
             "hidden_states": expand_x,
             "group_list": expert_token_nums,
             "dynamic_scale": dynamic_scale,
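The functional change is confined to this hunk: with `dynamic_eplb` enabled the dispatcher reports `group_list_type = 1` instead of `0`. A minimal sketch of what that switch typically means downstream, assuming the common convention that type 0 carries cumulative token offsets per expert while type 1 carries raw per-expert counts; that convention is an assumption, not stated in the diff:

```python
import torch

def group_list_for(tokens_per_expert: torch.Tensor, dynamic_eplb: bool):
    # Assumed convention: type 1 = per-expert token counts (convenient when the
    # expert placement can change between steps), type 0 = running cumulative sum.
    group_list_type = 1 if dynamic_eplb else 0
    if group_list_type == 1:
        group_list = tokens_per_expert
    else:
        group_list = torch.cumsum(tokens_per_expert, dim=0)
    return group_list_type, group_list

# counts [3, 0, 5] -> type 0: offsets [3, 3, 8]; type 1: counts [3, 0, 5]
print(group_list_for(torch.tensor([3, 0, 5]), dynamic_eplb=False))
print(group_list_for(torch.tensor([3, 0, 5]), dynamic_eplb=True))
```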
@@ -336,7 +338,8 @@ def token_dispatch(self,
                        dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False):
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         self.with_quant = with_quant
         self.original_shape = hidden_states.shape
@@ -426,7 +429,8 @@ def token_dispatch(self,
                        dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False):
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         self.bsz, _ = hidden_states.shape
         flatten_topk_ids = topk_ids.view(-1)
         self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -501,21 +505,20 @@ def __init__(self, **kwargs):
                     self.local_expert_indices[i + 1] -
                     1), "local_expert_indices must be continuous"
 
-    def token_dispatch(
-        self,
-        hidden_states: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        expert_map: Optional[torch.Tensor] = None,
-        log2phy: Optional[torch.Tensor] = None,
-        global_redundant_expert_num: int = 0,
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
-        mc2_mask: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        with_quant: bool = False,
-    ):
+    def token_dispatch(self,
+                       hidden_states: torch.Tensor,
+                       topk_weights: torch.Tensor,
+                       topk_ids: torch.Tensor,
+                       expert_map: Optional[torch.Tensor] = None,
+                       log2phy: Optional[torch.Tensor] = None,
+                       global_redundant_expert_num: int = 0,
+                       shared_experts: Optional[Any] = None,
+                       quantized_x_for_share: Optional[Any] = None,
+                       dynamic_scale_for_share: Optional[Any] = None,
+                       mc2_mask: Optional[torch.Tensor] = None,
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False,
+                       dynamic_eplb: bool = False):
         self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape
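Since `dynamic_eplb` is appended as a trailing keyword with a default of `False`, existing call sites of every dispatcher keep their current behavior. A toy stand-in class (not the real dispatcher, and reproducing only the pieces needed to show the new keyword) illustrating that compatibility:

```python
import torch

class _ToyDispatcher:
    # Stand-in with the same calling convention as the updated dispatchers;
    # the body is a placeholder that only reports the group-list type.
    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        return {"group_list_type": 1 if dynamic_eplb else 0}

d = _ToyDispatcher()
x, w = torch.randn(4, 8), torch.rand(4, 2)
ids = torch.randint(0, 4, (4, 2))
print(d.token_dispatch(x, w, ids)["group_list_type"])                     # 0
print(d.token_dispatch(x, w, ids, dynamic_eplb=True)["group_list_type"])  # 1
```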