
Commit b2b0940

committed: bugfix for max_num_seqs=14 in the mtp=2 scenario
Signed-off-by: zouyida2052 <[email protected]>
1 parent c4cd1de · commit b2b0940

File tree

1 file changed: +27 −11 lines


vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 27 additions & 11 deletions
@@ -117,12 +117,22 @@ def _init_mc2_tokens_capacity(self):
         # NOTE: To be clear, we need to make sure that during graph capture, the number of
         # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
         # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
-        max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
         tp_size = self.parallel_config.tensor_parallel_size
         # Use integer arithmetic for ceiling division.
-        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        # maintain the same calculation logic as the function _align_graph_size_divisible_by_tp_size()
-        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+            self.max_num_reqs, tp_size)
+        self.mc2_tokens_capacity = max_graph_batch_size
+
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
+            logger.error(
+                f"A3: the max number of tokens must be smaller than 512, but it is now {self.mc2_tokens_capacity}"
+            )
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
+            logger.error(
+                f"A2: the max number of tokens must be smaller than 256, but it is now {self.mc2_tokens_capacity}"
+            )
 
     def _sync_metadata_across_dp(
             self, num_tokens: int,
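
For intuition about why the capacity calculation changed, the arithmetic can be reproduced outside the runner. The snippet below is a minimal sketch, not part of the module: max_num_seqs=14 and the MTP depth of 2 come from the commit title, while tp_size=16 and uniform_decode_query_len=3 (one target token plus two speculative tokens) are assumed values chosen only to show how the old ceiling-division formula and the new LCM-based formula can disagree.

import math

# Assumed illustrative values: max_num_seqs=14 and MTP=2 come from the commit
# title; tp_size=16 and uniform_decode_query_len=3 are assumptions for this
# sketch only.
max_num_reqs = 14
uniform_decode_query_len = 3  # 1 target token + 2 speculative tokens
tp_size = 16

# Old logic: round the token count up to the next multiple of tp_size.
max_num_tokens = max_num_reqs * uniform_decode_query_len
old_capacity = (max_num_tokens + tp_size - 1) // tp_size * tp_size

# New logic for MTP > 1: the LCM of max_num_reqs and tp_size, mirroring
# calculate_new_torchair_graph_batch_size in the hunks below.
new_capacity = tp_size * max_num_reqs // math.gcd(tp_size, max_num_reqs)

print(old_capacity, new_capacity)  # 48 112 with these assumed values

Whether these particular numbers match the reported setup is an assumption; the point is only that the two formulas can diverge, which is why the capacity now reuses the same helper as the graph-size alignment.
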
@@ -464,6 +474,17 @@ def init_torchair_graph_batch_sizes(self):
             self.torchair_graph_batch_sizes.append(start_graph_batch_size)
             start_graph_batch_size *= 2
 
+    def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
+                                                 tp_size):
+        cur_graph_batch_size = (old_graph_batch_size + tp_size -
+                                1) // tp_size * tp_size
+        # MTP > 1: calculate the LCM (Least Common Multiple) of graph_batch_size and tp_size,
+        # which adapts both multi-DP and the FIA operator
+        if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
+            cur_graph_batch_size = (tp_size * old_graph_batch_size) \
+                // math.gcd(tp_size, old_graph_batch_size)
+        return cur_graph_batch_size
+
     def select_torchair_padded_batch_size(self, batch_size: int):
         for padded_batch_size in self.torchair_graph_batch_sizes:
             if batch_size <= padded_batch_size:
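
The new helper introduced in this hunk can be exercised on its own. Below is a standalone sketch of the same arithmetic for experimentation; the function name and the example arguments are illustrative assumptions, not part of torchair_model_runner.py.

import math

def padded_graph_batch_size(old_graph_batch_size: int, tp_size: int,
                            num_speculative_tokens: int) -> int:
    """Standalone sketch of calculate_new_torchair_graph_batch_size."""
    # Default behaviour: round up to the next multiple of tp_size.
    cur = (old_graph_batch_size + tp_size - 1) // tp_size * tp_size
    # MTP > 1: use the LCM so the result is a multiple of both tp_size and
    # the original batch size.
    if num_speculative_tokens > 1:
        cur = tp_size * old_graph_batch_size // math.gcd(
            tp_size, old_graph_batch_size)
    return cur

# Illustrative calls with assumed values:
print(padded_graph_batch_size(14, 16, num_speculative_tokens=1))  # 16
print(padded_graph_batch_size(14, 16, num_speculative_tokens=2))  # 112
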
@@ -515,13 +536,8 @@ def _align_graph_size_divisible_by_tp_size(self):
         tp_size = self.parallel_config.tensor_parallel_size
         new_graph_batch_sizes = []
         for graph_batch_size in self.torchair_graph_batch_sizes:
-            cur_graph_batch_size = (graph_batch_size + tp_size -
-                                    1) // tp_size * tp_size
-            # MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
-            # Both adapter multi-dp and FIA operator
-            if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
-                cur_graph_batch_size = (tp_size * graph_batch_size) \
-                    // math.gcd(tp_size, graph_batch_size)
+            cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+                graph_batch_size, tp_size)
             if cur_graph_batch_size not in new_graph_batch_sizes and \
                 cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                 new_graph_batch_sizes.append(cur_graph_batch_size)
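
Finally, the effect of routing _align_graph_size_divisible_by_tp_size through the shared helper can be walked through on a toy input. The batch-size list, tp_size, and token cap below are assumptions for illustration only; the loop mirrors the diff: align each size, skip duplicates, and drop anything above max_num_batched_tokens.

import math

def align_graph_batch_sizes(sizes, tp_size, num_speculative_tokens,
                            max_num_batched_tokens):
    """Sketch of the alignment loop after this commit (assumed inputs)."""
    aligned = []
    for size in sizes:
        if num_speculative_tokens > 1:
            cur = tp_size * size // math.gcd(tp_size, size)  # LCM branch
        else:
            cur = (size + tp_size - 1) // tp_size * tp_size  # round up
        if cur not in aligned and cur <= max_num_batched_tokens:
            aligned.append(cur)
    return aligned

# Assumed inputs: doubling batch sizes capped at max_num_seqs=14, tp_size=16,
# MTP depth 2, max_num_batched_tokens=512.
print(align_graph_batch_sizes([1, 2, 4, 8, 14], 16, 2, 512))  # [16, 112]
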
