Skip to content

Commit adadd50

Browse files
authored
bugfix for mtp fullgraph (#3845)
### What this PR does / why we need it? bugfix for mtp fullgraph ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: vllm-project/vllm@83f478b Signed-off-by: zouyida2052 <[email protected]>
1 parent d6ef3df commit adadd50

File tree

3 files changed

+51
-27
lines changed

3 files changed

+51
-27
lines changed

vllm_ascend/platform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
306306
**********************************************************************************\033[0m
307307
"""
308308
logger.warning(warning_message)
309+
update_aclgraph_sizes(vllm_config)
309310
else:
310311
logger.info(
311312
"%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -343,6 +344,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
343344
**********************************************************************************\033[0m
344345
"""
345346
logger.warning(warning_message)
347+
update_aclgraph_sizes(vllm_config)
346348
else:
347349
logger.info(
348350
"%s cudagraph_mode is not support on NPU. falling back to NONE",

vllm_ascend/utils.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,12 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
349349

350350
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
351351
"""Update ACL graph capture sizes based on hardware limitations"""
352+
from vllm.config.compilation import CUDAGraphMode
353+
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
354+
if vllm_config.speculative_config is not None and \
355+
vllm_config.speculative_config.num_speculative_tokens > 1:
356+
_update_spec_aclgraph_sizes(vllm_config)
357+
return
352358
# NOTE: Currently, we can only capture 1800 graphs at most,
353359
# due to the limitation of ACL graph. This number is bounded by
354360
# the number of streams, which is 2048, we save 248 streams
@@ -459,28 +465,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
459465
vllm_config.model_config.architectures[0], num_hidden_layers,
460466
len(original_sizes))
461467

468+
if vllm_config.speculative_config is not None and \
469+
vllm_config.speculative_config.num_speculative_tokens > 1:
470+
_update_spec_aclgraph_sizes(vllm_config)
471+
472+
473+
def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
462474
# default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
463475
# the maximum size cudagraph_capture_sizes[0] should be greater or equal than
464476
# (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
465-
if vllm_config.speculative_config is not None and \
466-
vllm_config.speculative_config.num_speculative_tokens > 1:
467-
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
468-
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
469-
original_sizes, compilation_config.cudagraph_capture_sizes = \
470-
compilation_config.cudagraph_capture_sizes, None
471-
assert len(original_sizes) > 0
472-
if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
473-
enlarged_sizes = [(num_speculative_tokens + 1) * size
474-
for size in original_sizes]
475-
if vllm_version_is("0.11.0"):
476-
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
477-
else:
478-
update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
479-
logger.info(
480-
"Adjusted ACL graphs: %s → %s for speculative decoding",
481-
original_sizes, enlarged_sizes)
477+
from vllm.config.compilation import CUDAGraphMode
478+
compilation_config = vllm_config.compilation_config
479+
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
480+
uniform_decode_query_len = num_speculative_tokens + 1
481+
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
482+
max_num_tokens = max_num_seqs * uniform_decode_query_len
483+
original_sizes, compilation_config.cudagraph_capture_sizes = \
484+
compilation_config.cudagraph_capture_sizes, None
485+
486+
assert len(original_sizes) > 0
487+
488+
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
489+
not all(size % uniform_decode_query_len == 0 for size in original_sizes):
490+
enlarged_sizes = [
491+
size * uniform_decode_query_len for size in original_sizes
492+
if size >= uniform_decode_query_len and size *
493+
uniform_decode_query_len <= max_num_tokens
494+
]
495+
if vllm_version_is("0.11.0"):
496+
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
482497
else:
483-
compilation_config.cudagraph_capture_sizes = original_sizes
498+
update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
499+
logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
500+
original_sizes, enlarged_sizes)
501+
elif original_sizes[0] < max_num_tokens:
502+
enlarged_sizes = [
503+
size * uniform_decode_query_len for size in original_sizes
504+
]
505+
if vllm_version_is("0.11.0"):
506+
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
507+
else:
508+
update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
509+
logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
510+
original_sizes, enlarged_sizes)
511+
else:
512+
compilation_config.cudagraph_capture_sizes = original_sizes
484513

485514

486515
# TODO(wxy): Move to ops module

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3885,7 +3885,7 @@ def _capture_model(self):
38853885
if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
38863886
aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
38873887

3888-
compilation_cases = list(reversed(self.aclgraph_batch_sizes))
3888+
compilation_cases = sorted(self.aclgraph_batch_sizes)
38893889

38903890
try:
38913891
self._capture_aclgraphs(
@@ -3914,14 +3914,7 @@ def _capture_model(self):
39143914

39153915
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
39163916
aclgraph_mode.separate_routine():
3917-
max_num_tokens = self.scheduler_config.max_num_seqs * \
3918-
self.uniform_decode_query_len
3919-
decode_cudagraph_batch_sizes = [
3920-
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
3921-
and x >= self.uniform_decode_query_len
3922-
]
3923-
compilation_cases_decode = list(
3924-
reversed(decode_cudagraph_batch_sizes))
3917+
compilation_cases_decode = sorted(self.aclgraph_batch_sizes)
39253918
self._capture_aclgraphs(
39263919
compilation_cases=compilation_cases_decode,
39273920
aclgraph_runtime_mode=CUDAGraphMode.FULL,

0 commit comments

Comments
 (0)