2 changes: 0 additions & 2 deletions vllm_ascend/platform.py
@@ -306,7 +306,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
update_aclgraph_sizes(vllm_config)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -344,7 +343,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
update_aclgraph_sizes(vllm_config)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
39 changes: 28 additions & 11 deletions vllm_ascend/torchair/torchair_model_runner.py
@@ -117,12 +117,23 @@ def _init_mc2_tokens_capacity(self):
# NOTE: To be clear, we need to make sure that during graph capture, the number of
# tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
# the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
tp_size = self.parallel_config.tensor_parallel_size
# Use integer arithmetic for ceiling division.
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
# maintain the same calculation logic as the function _align_graph_size_divisible_by_tp_size()
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
max_num_tokens, tp_size)
self.mc2_tokens_capacity = max_graph_batch_size

if get_ascend_soc_version(
) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
logger.error(
f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
)
if get_ascend_soc_version(
) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
logger.error(
f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
)
Comment on lines +127 to +136 (review severity: high):

Using logger.error for violations of hard hardware limits might not be sufficient. An error log will be printed, but the execution will continue, potentially leading to more obscure failures later on. It would be better to raise a ValueError to halt execution immediately and provide a clear error message to the user. This also provides an opportunity to improve the error messages for clarity and grammatical correctness. Additionally, calling get_ascend_soc_version() once and storing it in a local variable would be more efficient.

        soc_version = get_ascend_soc_version()
        if soc_version == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
            raise ValueError(
                f"On Ascend A3, the max number of tokens for mc2 must be smaller than or equal to 512, but it is {self.mc2_tokens_capacity}"
            )
        if soc_version == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
            raise ValueError(
                f"On Ascend A2, the max number of tokens for mc2 must be smaller than or equal to 256, but it is {self.mc2_tokens_capacity}"
            )
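
For context, a minimal standalone sketch of the non-MTP path of the capacity arithmetic in this hunk, combined with the fail-fast check the review suggests. The numeric values are hypothetical, and A2_LIMIT / A3_LIMIT are illustrative names for the 256 / 512 ceilings stated in the PR; the MTP > 1 LCM variant is sketched separately after calculate_new_torchair_graph_batch_size below.

    # Hypothetical inputs; at runtime these come from the scheduler,
    # speculative and parallel configs.
    max_num_reqs = 16              # max_num_seqs
    uniform_decode_query_len = 3   # num_speculative_tokens + 1 (MTP = 2)
    tp_size = 4

    max_num_tokens = max_num_reqs * uniform_decode_query_len             # 48
    # Same alignment as _align_graph_size_divisible_by_tp_size():
    # round up to the next multiple of tp_size via integer ceiling division.
    mc2_tokens_capacity = (max_num_tokens + tp_size - 1) // tp_size * tp_size  # 48

    # Hard per-SoC ceilings quoted in the hunk above.
    A2_LIMIT, A3_LIMIT = 256, 512
    if mc2_tokens_capacity > A2_LIMIT:
        # Fail fast instead of only logging, per the review suggestion.
        raise ValueError(
            f"mc2_tokens_capacity {mc2_tokens_capacity} exceeds the A2 limit {A2_LIMIT}")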


def _sync_metadata_across_dp(
self, num_tokens: int,
@@ -464,6 +475,17 @@ def init_torchair_graph_batch_sizes(self):
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
start_graph_batch_size *= 2

def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
tp_size):
cur_graph_batch_size = (old_graph_batch_size + tp_size -
1) // tp_size * tp_size
# MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
# Both adapter multi-dp and FIA operator
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
cur_graph_batch_size = (tp_size * old_graph_batch_size) \
// math.gcd(tp_size, old_graph_batch_size)
return cur_graph_batch_size
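
A standalone sketch of the same rounding rule with illustrative inputs; the helper name and numbers here are not part of the PR.

    import math

    def round_graph_batch_size(graph_batch_size, tp_size, num_speculative_tokens):
        # Default: round up to the next multiple of tp_size (integer ceiling).
        new_size = (graph_batch_size + tp_size - 1) // tp_size * tp_size
        # MTP > 1: take the least common multiple of tp_size and the batch size,
        # so the padded size satisfies both the multi-DP and FIA constraints.
        if num_speculative_tokens > 1:
            new_size = tp_size * graph_batch_size // math.gcd(tp_size, graph_batch_size)
        return new_size

    print(round_graph_batch_size(6, 4, num_speculative_tokens=1))  # 8  -> ceil to multiple of 4
    print(round_graph_batch_size(6, 4, num_speculative_tokens=2))  # 12 -> lcm(6, 4)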

def select_torchair_padded_batch_size(self, batch_size: int):
for padded_batch_size in self.torchair_graph_batch_sizes:
if batch_size <= padded_batch_size:
@@ -515,13 +537,8 @@ def _align_graph_size_divisible_by_tp_size(self):
tp_size = self.parallel_config.tensor_parallel_size
new_graph_batch_sizes = []
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
# MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
# Both adapter multi-dp and FIA operator
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
cur_graph_batch_size = (tp_size * graph_batch_size) \
// math.gcd(tp_size, graph_batch_size)
cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
graph_batch_size, tp_size)
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)
65 changes: 18 additions & 47 deletions vllm_ascend/utils.py
@@ -349,12 +349,6 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,

def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
"""Update ACL graph capture sizes based on hardware limitations"""
from vllm.config.compilation import CUDAGraphMode
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
_update_spec_aclgraph_sizes(vllm_config)
return
# NOTE: Currently, we can only capture 1800 graphs at most,
# due to the limitation of ACL graph. This number is bounded by
# the number of streams, which is 2048, we save 248 streams
@@ -465,51 +459,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
vllm_config.model_config.architectures[0], num_hidden_layers,
len(original_sizes))

if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
_update_spec_aclgraph_sizes(vllm_config)


def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
# the maximum size cudagraph_capture_sizes[0] should be greater or equal than
# (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
from vllm.config.compilation import CUDAGraphMode
compilation_config = vllm_config.compilation_config
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
uniform_decode_query_len = num_speculative_tokens + 1
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
max_num_tokens = max_num_seqs * uniform_decode_query_len
original_sizes, compilation_config.cudagraph_capture_sizes = \
compilation_config.cudagraph_capture_sizes, None

assert len(original_sizes) > 0

if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
not all(size % uniform_decode_query_len == 0 for size in original_sizes):
enlarged_sizes = [
size * uniform_decode_query_len for size in original_sizes
if size >= uniform_decode_query_len and size *
uniform_decode_query_len <= max_num_tokens
]
if vllm_version_is("0.11.0"):
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
else:
update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
original_sizes, enlarged_sizes)
elif original_sizes[0] < max_num_tokens:
enlarged_sizes = [
size * uniform_decode_query_len for size in original_sizes
]
if vllm_version_is("0.11.0"):
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
original_sizes, compilation_config.cudagraph_capture_sizes = \
compilation_config.cudagraph_capture_sizes, None
assert len(original_sizes) > 0
if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
enlarged_sizes = [(num_speculative_tokens + 1) * size
for size in original_sizes]
if vllm_version_is("0.11.0"):
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
else:
update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
logger.info(
"Adjusted ACL graphs: %s → %s for speculative decoding",
original_sizes, enlarged_sizes)
else:
update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
original_sizes, enlarged_sizes)
else:
compilation_config.cudagraph_capture_sizes = original_sizes
compilation_config.cudagraph_capture_sizes = original_sizes
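
A worked example of this adjustment with hypothetical values; per the comment removed above, cudagraph_capture_sizes keeps the largest size first.

    num_speculative_tokens = 2            # hypothetical MTP depth
    max_num_seqs = 16                     # hypothetical scheduler limit
    original_sizes = [16, 8, 4, 2, 1]     # hypothetical capture sizes, descending

    required = (num_speculative_tokens + 1) * max_num_seqs   # 48
    if original_sizes[0] < required:
        # Scale every size by (num_speculative_tokens + 1), matching the
        # draft + target token count of each decode request, so the largest
        # capture size reaches (num_speculative_tokens + 1) * max_num_seqs.
        enlarged_sizes = [(num_speculative_tokens + 1) * s for s in original_sizes]
        # enlarged_sizes == [48, 24, 12, 6, 3]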


# TODO(wxy): Move to ops module
18 changes: 17 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -4035,7 +4035,23 @@ def _capture_model(self):

if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
aclgraph_mode.separate_routine():
compilation_cases_decode = sorted(self.aclgraph_batch_sizes)
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
and x >= self.uniform_decode_query_len
]
compilation_cases_decode = sorted(decode_cudagraph_batch_sizes)
# TODO: refactor this when vLLM supports mtp>1
if not all(x % self.uniform_decode_query_len == 0
for x in decode_cudagraph_batch_sizes):
raise ValueError(
"In the MTP fullgraph scenario, each graph size must be an integer multiple of "
f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, "
f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. "
"For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]."
)
self._capture_aclgraphs(
compilation_cases=compilation_cases_decode,
aclgraph_runtime_mode=CUDAGraphMode.FULL,
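
A small check mirroring the arithmetic in this validation, using the example from the error message (MTP = 2, max_num_seqs = 16); the list of capture sizes is hypothetical.

    uniform_decode_query_len = 3       # num_speculative_tokens + 1, MTP = 2
    max_num_seqs = 16
    max_num_tokens = max_num_seqs * uniform_decode_query_len   # 48

    aclgraph_batch_sizes = [8, 16, 32, 48, 64]                  # hypothetical
    decode_sizes = [x for x in aclgraph_batch_sizes
                    if uniform_decode_query_len <= x <= max_num_tokens]
    # decode_sizes == [8, 16, 32, 48]; 8, 16 and 32 are not multiples of 3,
    # so the check above would raise. Setting cudagraph_capture_sizes to [48]
    # satisfies both the multiple-of-3 and the <= 48 constraints.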