Commit 1966885

Fix bug when max_num_seqs=14 in mtp=2 scenario and raise error when cudagraph_capture_sizes is not an integer multiple of uniform_decode_query_len (#3910)
### What this PR does / why we need it?
1. Revert the [bugfix for mtp in fullgraph](0948483) and re-introduce it once vLLM supports it.
2. Raise an error when cudagraph_capture_sizes is not an integer multiple of uniform_decode_query_len.
3. Fix a bug when max_num_seqs=14 in the mtp=2 scenario.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@83f478b

---

Signed-off-by: zouyida2052 <[email protected]>
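For intuition, here is a rough sketch of the constraint this PR enforces (illustrative numbers, not code from the commit): with MTP=2, each uniform decode step processes num_speculative_tokens + 1 = 3 tokens per request, so at max_num_seqs=14 a uniform decode batch holds at most 14 * 3 = 42 tokens, and every usable capture size must be a multiple of 3 no larger than 42. Typical power-of-two defaults satisfy neither condition.

```python
# Illustrative sketch only; values are assumptions, not code from this commit.
num_speculative_tokens = 2                                # MTP=2
uniform_decode_query_len = num_speculative_tokens + 1     # 3 tokens per request per step
max_num_seqs = 14
max_num_tokens = max_num_seqs * uniform_decode_query_len  # 42

candidate_sizes = [1, 2, 4, 8, 16, 32, 48, 64]            # power-of-two style defaults
usable = [s for s in candidate_sizes
          if uniform_decode_query_len <= s <= max_num_tokens
          and s % uniform_decode_query_len == 0]
assert usable == []  # nothing qualifies, hence the new explicit error
```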
1 parent 35a913c commit 1966885

File tree

4 files changed (+63 −61 lines):
- vllm_ascend/platform.py
- vllm_ascend/torchair/torchair_model_runner.py
- vllm_ascend/utils.py
- vllm_ascend/worker/model_runner_v1.py


vllm_ascend/platform.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -306,7 +306,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     **********************************************************************************\033[0m
             """
             logger.warning(warning_message)
-            update_aclgraph_sizes(vllm_config)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -344,7 +343,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     **********************************************************************************\033[0m
             """
             logger.warning(warning_message)
-            update_aclgraph_sizes(vllm_config)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
```

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 28 additions & 11 deletions

```diff
@@ -117,12 +117,23 @@ def _init_mc2_tokens_capacity(self):
         # NOTE: To be clear, we need to make sure that during graph capture, the number of
         # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
         # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
-        max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
+        max_num_tokens = self.parallel_config.tensor_parallel_size
         tp_size = self.parallel_config.tensor_parallel_size
         # Use integer arithmetic for ceiling division.
-        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        # maintain the same calculation logic as the function _align_graph_size_divisible_by_tp_size()
-        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+            max_num_tokens, tp_size)
+        self.mc2_tokens_capacity = max_graph_batch_size
+
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
+            logger.error(
+                f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
+            )
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
+            logger.error(
+                f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
+            )
 
     def _sync_metadata_across_dp(
             self, num_tokens: int,
@@ -464,6 +475,17 @@ def init_torchair_graph_batch_sizes(self):
             self.torchair_graph_batch_sizes.append(start_graph_batch_size)
             start_graph_batch_size *= 2
 
+    def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
+                                                tp_size):
+        cur_graph_batch_size = (old_graph_batch_size + tp_size -
+                                1) // tp_size * tp_size
+        # MTP > 1: Cal LCM (Least Common Multiple) with graph_batch_size and tp_size,
+        # Both adapter multi-dp and FIA operator
+        if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
+            cur_graph_batch_size = (tp_size * old_graph_batch_size) \
+                // math.gcd(tp_size, old_graph_batch_size)
+        return cur_graph_batch_size
+
     def select_torchair_padded_batch_size(self, batch_size: int):
         for padded_batch_size in self.torchair_graph_batch_sizes:
             if batch_size <= padded_batch_size:
@@ -515,13 +537,8 @@ def _align_graph_size_divisible_by_tp_size(self):
         tp_size = self.parallel_config.tensor_parallel_size
         new_graph_batch_sizes = []
         for graph_batch_size in self.torchair_graph_batch_sizes:
-            cur_graph_batch_size = (graph_batch_size + tp_size -
-                                    1) // tp_size * tp_size
-            # MTP > 1: Cal LCM (Least Common Multiple) with graph_batch_size and tp_size,
-            # Both adapter multi-dp and FIA operator
-            if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
-                cur_graph_batch_size = (tp_size * graph_batch_size) \
-                    // math.gcd(tp_size, graph_batch_size)
+            cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+                graph_batch_size, tp_size)
             if cur_graph_batch_size not in new_graph_batch_sizes and \
                 cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                 new_graph_batch_sizes.append(cur_graph_batch_size)
```
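As a standalone paraphrase of the arithmetic in the new calculate_new_torchair_graph_batch_size helper (a simplified sketch, not the committed method, with the speculative config reduced to a plain integer):

```python
import math

def new_graph_batch_size(old_graph_batch_size: int, tp_size: int,
                         num_speculative_tokens: int = 0) -> int:
    # Default path: round the batch size up to the next multiple of tp_size
    # via integer ceiling division.
    result = (old_graph_batch_size + tp_size - 1) // tp_size * tp_size
    # MTP > 1 path: take the least common multiple of tp_size and the batch
    # size instead, so the result divides evenly for both multi-DP and the
    # FIA operator.
    if num_speculative_tokens > 1:
        result = tp_size * old_graph_batch_size // math.gcd(
            tp_size, old_graph_batch_size)
    return result

assert new_graph_batch_size(14, 4) == 16                            # ceil to multiple of tp
assert new_graph_batch_size(14, 4, num_speculative_tokens=2) == 28  # lcm(14, 4)
```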

vllm_ascend/utils.py

Lines changed: 18 additions & 47 deletions

```diff
@@ -349,12 +349,6 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
 
 def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     """Update ACL graph capture sizes based on hardware limitations"""
-    from vllm.config.compilation import CUDAGraphMode
-    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-        if vllm_config.speculative_config is not None and \
-                vllm_config.speculative_config.num_speculative_tokens > 1:
-            _update_spec_aclgraph_sizes(vllm_config)
-        return
     # NOTE: Currently, we can only capture 1800 graphs at most,
     # due to the limitation of ACL graph. This number is bounded by
     # the number of streams, which is 2048, we save 248 streams
@@ -465,51 +459,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         vllm_config.model_config.architectures[0], num_hidden_layers,
         len(original_sizes))
 
-    if vllm_config.speculative_config is not None and \
-            vllm_config.speculative_config.num_speculative_tokens > 1:
-        _update_spec_aclgraph_sizes(vllm_config)
-
-
-def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     # default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
     # the maximum size cudagraph_capture_sizes[0] should be greater or equal than
     # (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
-    from vllm.config.compilation import CUDAGraphMode
-    compilation_config = vllm_config.compilation_config
-    num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
-    uniform_decode_query_len = num_speculative_tokens + 1
-    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
-    max_num_tokens = max_num_seqs * uniform_decode_query_len
-    original_sizes, compilation_config.cudagraph_capture_sizes = \
-        compilation_config.cudagraph_capture_sizes, None
-
-    assert len(original_sizes) > 0
-
-    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
-            not all(size % uniform_decode_query_len == 0 for size in original_sizes):
-        enlarged_sizes = [
-            size * uniform_decode_query_len for size in original_sizes
-            if size >= uniform_decode_query_len and size *
-            uniform_decode_query_len <= max_num_tokens
-        ]
-        if vllm_version_is("0.11.0"):
-            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
-        else:
-            update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
-        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
-                    original_sizes, enlarged_sizes)
-    elif original_sizes[0] < max_num_tokens:
-        enlarged_sizes = [
-            size * uniform_decode_query_len for size in original_sizes
-        ]
-        if vllm_version_is("0.11.0"):
-            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+    if vllm_config.speculative_config is not None and \
+            vllm_config.speculative_config.num_speculative_tokens > 1:
+        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        original_sizes, compilation_config.cudagraph_capture_sizes = \
+            compilation_config.cudagraph_capture_sizes, None
+        assert len(original_sizes) > 0
+        if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
+            enlarged_sizes = [(num_speculative_tokens + 1) * size
+                              for size in original_sizes]
+            if vllm_version_is("0.11.0"):
+                compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+            else:
+                update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
+            logger.info(
+                "Adjusted ACL graphs: %s → %s for speculative decoding",
+                original_sizes, enlarged_sizes)
         else:
-            update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
-        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
-                    original_sizes, enlarged_sizes)
-    else:
-        compilation_config.cudagraph_capture_sizes = original_sizes
+            compilation_config.cudagraph_capture_sizes = original_sizes
 
 
 # TODO(wxy): Move to ops module
```
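The surviving branch in utils.py reduces to the following rule (a minimal sketch with assumed inputs; the real code mutates vllm_config.compilation_config in place and goes through update_cudagraph_capture_sizes):

```python
# Minimal sketch of the enlargement rule; inputs are assumptions for illustration.
num_speculative_tokens = 2
max_num_seqs = 16
original_sizes = [32, 16, 8, 4, 2, 1]  # vLLM keeps capture sizes in descending order

# The largest captured graph must cover (num_speculative_tokens + 1) tokens per
# sequence at full batch, otherwise the draft model falls back to eager mode.
required_max = (num_speculative_tokens + 1) * max_num_seqs  # 48
if original_sizes[0] < required_max:
    new_sizes = [(num_speculative_tokens + 1) * size for size in original_sizes]
else:
    new_sizes = original_sizes
print(new_sizes)  # [96, 48, 24, 12, 6, 3]
```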

vllm_ascend/worker/model_runner_v1.py

Lines changed: 17 additions & 1 deletion

```diff
@@ -4035,7 +4035,23 @@ def _capture_model(self):
 
         if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                 aclgraph_mode.separate_routine():
-            compilation_cases_decode = sorted(self.aclgraph_batch_sizes)
+            max_num_tokens = self.scheduler_config.max_num_seqs * \
+                self.uniform_decode_query_len
+            decode_cudagraph_batch_sizes = [
+                x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
+                and x >= self.uniform_decode_query_len
+            ]
+            compilation_cases_decode = sorted(decode_cudagraph_batch_sizes)
+            # TODO: refactor this when vLLM supports mtp>1
+            if not all(x % self.uniform_decode_query_len == 0
+                       for x in decode_cudagraph_batch_sizes):
+                raise ValueError(
+                    "In the MTP fullgraph scenario, each graph size must be an integer multiple of "
+                    f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
+                    f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, "
+                    f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. "
+                    "For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]."
+                )
         self._capture_aclgraphs(
             compilation_cases=compilation_cases_decode,
             aclgraph_runtime_mode=CUDAGraphMode.FULL,
```
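Putting the new decode-graph selection together, roughly (a sketch with assumed standalone values, not the committed method): with MTP=2 and max_num_seqs=16, the PR's own recommendation of cudagraph_capture_sizes=[48] passes both the range filter and the divisibility check.

```python
# Sketch of the decode-size filter and divisibility check; values assumed.
uniform_decode_query_len = 3                 # num_speculative_tokens + 1 with MTP=2
max_num_seqs = 16
max_num_tokens = max_num_seqs * uniform_decode_query_len  # 48
aclgraph_batch_sizes = [48]                  # the setting recommended by the error message

decode_sizes = [x for x in aclgraph_batch_sizes
                if uniform_decode_query_len <= x <= max_num_tokens]
if not all(x % uniform_decode_query_len == 0 for x in decode_sizes):
    raise ValueError(f"graph sizes must be multiples of {uniform_decode_query_len}")
compilation_cases_decode = sorted(decode_sizes)
print(compilation_cases_decode)  # [48] -> captured as the FULL decode graph
```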
