Commit 86afe57

Revert "bugfix for mtp fullgraph (#3845)"
This reverts commit adadd50.
1 parent 655a229 commit 86afe57

File tree

3 files changed, +27 -51 lines changed


vllm_ascend/platform.py

Lines changed: 0 additions & 2 deletions
@@ -306,7 +306,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             **********************************************************************************\033[0m
             """
             logger.warning(warning_message)
-            update_aclgraph_sizes(vllm_config)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -344,7 +343,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             **********************************************************************************\033[0m
             """
             logger.warning(warning_message)
-            update_aclgraph_sizes(vllm_config)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",

vllm_ascend/utils.py

Lines changed: 18 additions & 47 deletions
@@ -349,12 +349,6 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
 
 def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     """Update ACL graph capture sizes based on hardware limitations"""
-    from vllm.config.compilation import CUDAGraphMode
-    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-        if vllm_config.speculative_config is not None and \
-                vllm_config.speculative_config.num_speculative_tokens > 1:
-            _update_spec_aclgraph_sizes(vllm_config)
-        return
     # NOTE: Currently, we can only capture 1800 graphs at most,
     # due to the limitation of ACL graph. This number is bounded by
     # the number of streams, which is 2048, we save 248 streams
@@ -465,51 +459,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         vllm_config.model_config.architectures[0], num_hidden_layers,
         len(original_sizes))
 
-    if vllm_config.speculative_config is not None and \
-            vllm_config.speculative_config.num_speculative_tokens > 1:
-        _update_spec_aclgraph_sizes(vllm_config)
-
-
-def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     # default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
     # the maximum size cudagraph_capture_sizes[0] should be greater or equal than
     # (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
-    from vllm.config.compilation import CUDAGraphMode
-    compilation_config = vllm_config.compilation_config
-    num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
-    uniform_decode_query_len = num_speculative_tokens + 1
-    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
-    max_num_tokens = max_num_seqs * uniform_decode_query_len
-    original_sizes, compilation_config.cudagraph_capture_sizes = \
-        compilation_config.cudagraph_capture_sizes, None
-
-    assert len(original_sizes) > 0
-
-    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
-            not all(size % uniform_decode_query_len == 0 for size in original_sizes):
-        enlarged_sizes = [
-            size * uniform_decode_query_len for size in original_sizes
-            if size >= uniform_decode_query_len and size *
-            uniform_decode_query_len <= max_num_tokens
-        ]
-        if vllm_version_is("0.11.0"):
-            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
-        else:
-            update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
-        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
-                    original_sizes, enlarged_sizes)
-    elif original_sizes[0] < max_num_tokens:
-        enlarged_sizes = [
-            size * uniform_decode_query_len for size in original_sizes
-        ]
-        if vllm_version_is("0.11.0"):
-            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+    if vllm_config.speculative_config is not None and \
+            vllm_config.speculative_config.num_speculative_tokens > 1:
+        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        original_sizes, compilation_config.cudagraph_capture_sizes = \
+            compilation_config.cudagraph_capture_sizes, None
+        assert len(original_sizes) > 0
+        if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
+            enlarged_sizes = [(num_speculative_tokens + 1) * size
+                              for size in original_sizes]
+            if vllm_version_is("0.11.0"):
+                compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+            else:
+                update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
+            logger.info(
+                "Adjusted ACL graphs: %s → %s for speculative decoding",
+                original_sizes, enlarged_sizes)
         else:
-            update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
-        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
-                    original_sizes, enlarged_sizes)
-    else:
-        compilation_config.cudagraph_capture_sizes = original_sizes
+            compilation_config.cudagraph_capture_sizes = original_sizes
 
 
 # TODO(wxy): Move to ops module
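
Note on the restored logic above: as the retained comment says, with speculative decoding each scheduled sequence contributes (num_speculative_tokens + 1) tokens per decode step, so cudagraph_capture_sizes[0] must be at least (num_speculative_tokens + 1) * max_num_seqs or the draft model falls back to eager mode. The snippet below is only a minimal sketch of that scaling rule with hypothetical numbers; in vllm-ascend the real values come from the speculative, scheduler and compilation configs at engine start-up.

# Hypothetical inputs; the real ones come from vllm_config.speculative_config,
# vllm_config.scheduler_config and vllm_config.compilation_config.
num_speculative_tokens = 1                       # e.g. one MTP draft token per step
max_num_seqs = 256                               # scheduler batch limit
original_sizes = [120, 64, 32, 16, 8, 4, 2, 1]   # capture sizes, largest first

# Restored rule: if the largest size cannot hold a full batch of
# (num_speculative_tokens + 1) tokens per sequence, scale every size.
if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
    enlarged_sizes = [(num_speculative_tokens + 1) * size
                      for size in original_sizes]
else:
    enlarged_sizes = original_sizes

print(enlarged_sizes)  # [240, 128, 64, 32, 16, 8, 4, 2]

Which registration path the real function then takes (init_with_cudagraph_sizes versus update_cudagraph_capture_sizes) depends only on the vllm_version_is("0.11.0") check visible in the diff.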

vllm_ascend/worker/model_runner_v1.py

Lines changed: 9 additions & 2 deletions
@@ -4006,7 +4006,7 @@ def _capture_model(self):
         if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
 
-            compilation_cases = sorted(self.aclgraph_batch_sizes)
+            compilation_cases = list(reversed(self.aclgraph_batch_sizes))
 
             try:
                 self._capture_aclgraphs(
@@ -4035,7 +4035,14 @@ def _capture_model(self):
 
         if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                 aclgraph_mode.separate_routine():
-            compilation_cases_decode = sorted(self.aclgraph_batch_sizes)
+            max_num_tokens = self.scheduler_config.max_num_seqs * \
+                self.uniform_decode_query_len
+            decode_cudagraph_batch_sizes = [
+                x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
+                and x >= self.uniform_decode_query_len
+            ]
+            compilation_cases_decode = list(
+                reversed(decode_cudagraph_batch_sizes))
             self._capture_aclgraphs(
                 compilation_cases=compilation_cases_decode,
                 aclgraph_runtime_mode=CUDAGraphMode.FULL,
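
The restored decode-only capture path above no longer captures every batch size: it keeps only the sizes a uniform decode step can actually produce, bounded below by uniform_decode_query_len and above by max_num_seqs * uniform_decode_query_len, then reverses the list for the capture loop. A minimal standalone sketch with hypothetical values follows; in the runner the inputs are self.aclgraph_batch_sizes, self.uniform_decode_query_len and self.scheduler_config.max_num_seqs.

# Hypothetical inputs mirroring the restored filter above.
uniform_decode_query_len = 2                     # e.g. num_speculative_tokens + 1
max_num_seqs = 64
max_num_tokens = max_num_seqs * uniform_decode_query_len   # 128 tokens per decode step

aclgraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

# Keep only token counts a uniform decode step can produce, then reverse
# the list to set the capture order.
decode_cudagraph_batch_sizes = [
    x for x in aclgraph_batch_sizes
    if uniform_decode_query_len <= x <= max_num_tokens
]
compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
print(compilation_cases_decode)  # [128, 64, 32, 16, 8, 4, 2]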
