@@ -117,12 +117,22 @@ def _init_mc2_tokens_capacity(self):
         # NOTE: To be clear, we need to make sure that during graph capture, the number of
         # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
         # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
-        max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
         tp_size = self.parallel_config.tensor_parallel_size
         # Use integer arithmetic for ceiling division.
-        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        # maintain the same calculation logic as the function _align_graph_size_divisible_by_tp_size()
-        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+            self.max_num_reqs, tp_size)
+        self.mc2_tokens_capacity = max_graph_batch_size
+
+        if get_ascend_soc_version() == AscendSocVersion.A3 \
+                and self.mc2_tokens_capacity > 512:
+            logger.error(
+                f"A3: the max number of tokens must be smaller than 512, "
+                f"but is now {self.mc2_tokens_capacity}")
+        if get_ascend_soc_version() == AscendSocVersion.A2 \
+                and self.mc2_tokens_capacity > 256:
+            logger.error(
+                f"A2: the max number of tokens must be smaller than 256, "
+                f"but is now {self.mc2_tokens_capacity}")

     def _sync_metadata_across_dp(
             self, num_tokens: int,
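As a standalone illustration (a sketch with a hypothetical helper name, not code from this patch): the capacity is just the max request count aligned for tensor parallelism, and the A3/A2 limits of 512/256 tokens come from the error checks above.

    import math

    def mc2_capacity(max_num_reqs: int, tp_size: int,
                     num_speculative_tokens: int = 0) -> int:
        # Mirrors calculate_new_torchair_graph_batch_size(): round up to a
        # multiple of tp_size, or take the LCM when MTP > 1.
        if num_speculative_tokens > 1:
            return tp_size * max_num_reqs // math.gcd(tp_size, max_num_reqs)
        return (max_num_reqs + tp_size - 1) // tp_size * tp_size

    # max_num_reqs=24, tp_size=16 -> 32 tokens, well within the A3 limit of 512.
    assert mc2_capacity(24, 16) == 32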
@@ -464,6 +474,17 @@ def init_torchair_graph_batch_sizes(self):
             self.torchair_graph_batch_sizes.append(start_graph_batch_size)
             start_graph_batch_size *= 2

+    def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
+                                                tp_size):
+        cur_graph_batch_size = (old_graph_batch_size + tp_size -
+                                1) // tp_size * tp_size
+        # MTP > 1: use the LCM (Least Common Multiple) of graph_batch_size and
+        # tp_size, so the result suits both multi-DP and the FIA operator.
+        if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
+            cur_graph_batch_size = (tp_size * old_graph_batch_size) \
+                // math.gcd(tp_size, old_graph_batch_size)
+        return cur_graph_batch_size
+
     def select_torchair_padded_batch_size(self, batch_size: int):
         for padded_batch_size in self.torchair_graph_batch_sizes:
             if batch_size <= padded_batch_size:
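To see the two branches of the new helper concretely (a worked example, not from the patch): with tp_size = 8 and old_graph_batch_size = 20, ceiling division rounds up to 24, while the MTP > 1 branch takes lcm(8, 20) = 40, which is divisible by both values.

    import math

    tp_size, batch = 8, 20
    # Default branch: round batch up to the next multiple of tp_size.
    assert (batch + tp_size - 1) // tp_size * tp_size == 24
    # MTP > 1 branch: LCM of tp_size and batch, divisible by both.
    assert tp_size * batch // math.gcd(tp_size, batch) == 40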
@@ -515,13 +536,8 @@ def _align_graph_size_divisible_by_tp_size(self):
         tp_size = self.parallel_config.tensor_parallel_size
         new_graph_batch_sizes = []
         for graph_batch_size in self.torchair_graph_batch_sizes:
-            cur_graph_batch_size = (graph_batch_size + tp_size -
-                                    1) // tp_size * tp_size
-            # MTP > 1: use the LCM (Least Common Multiple) of graph_batch_size and
-            # tp_size, so the result suits both multi-DP and the FIA operator.
-            if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
-                cur_graph_batch_size = (tp_size * graph_batch_size) \
-                    // math.gcd(tp_size, graph_batch_size)
+            cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+                graph_batch_size, tp_size)
             if cur_graph_batch_size not in new_graph_batch_sizes and \
                 cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                 new_graph_batch_sizes.append(cur_graph_batch_size)
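For example (a hypothetical run of the loop above, with small numbers): with tp_size = 8 and a cap of 16 batched tokens, the sizes [4, 8, 16, 30] align to [8, 8, 16, 32]; the duplicate 8 is dropped and 32 is filtered out by the cap.

    sizes, tp_size, max_tokens = [4, 8, 16, 30], 8, 16
    new_sizes = []
    for s in sizes:
        aligned = (s + tp_size - 1) // tp_size * tp_size  # 8, 8, 16, 32
        if aligned not in new_sizes and aligned <= max_tokens:
            new_sizes.append(aligned)
    assert new_sizes == [8, 16]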