[Main] fix bug when max_seqs=14 in mtp=2 scenario and raise error when cudagraph_capture_sizes can't be an integer multiple of uniform_decode_query_len #3910
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
This reverts commit adadd50. Signed-off-by: zouyida2052 <[email protected]>
… of uniform_decode_query_len Signed-off-by: zouyida2052 <[email protected]>
Signed-off-by: zouyida2052 <[email protected]>
Code Review
This pull request introduces several changes, including reverting a previous bugfix, adding a validation check for cudagraph_capture_sizes, and fixing a bug related to max_num_seqs. The changes in vllm_ascend/worker/model_runner_v1.py to raise an error for invalid cudagraph_capture_sizes are well-implemented with a clear error message. The refactoring in vllm_ascend/utils.py and vllm_ascend/torchair/torchair_model_runner.py improves code structure.
However, I've identified a critical issue in vllm_ascend/torchair/torchair_model_runner.py where mc2_tokens_capacity is calculated incorrectly, which could lead to insufficient memory allocation. Additionally, there's an opportunity to improve error handling for hardware limit violations by raising exceptions instead of just logging errors. Please see my detailed comments for suggestions.
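To make the new validation concrete, here is a minimal standalone sketch of the kind of check this PR adds. The function name and signature are assumptions for illustration; identifiers such as `cudagraph_capture_sizes` and `uniform_decode_query_len` mirror the PR, but this is not the actual vLLM Ascend implementation.

```python
# Hypothetical sketch: reject any captured graph batch size that is not an
# integer multiple of uniform_decode_query_len, as the PR's new check does.
def validate_cudagraph_capture_sizes(capture_sizes, uniform_decode_query_len):
    for size in capture_sizes:
        if size % uniform_decode_query_len != 0:
            raise ValueError(
                f"cudagraph_capture_sizes entry {size} must be an integer "
                f"multiple of uniform_decode_query_len "
                f"({uniform_decode_query_len}).")

# With mtp=2, each decode step handles 3 tokens per request, so sizes
# like 3, 6, 12 are valid while 4 would be rejected.
validate_cudagraph_capture_sizes([3, 6, 12], uniform_decode_query_len=3)
```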
```python
max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
    self.max_num_reqs, tp_size)
```
There appears to be a bug in the calculation of max_graph_batch_size. The function calculate_new_torchair_graph_batch_size expects a token count, but it's being called with self.max_num_reqs, which is a request count. This will lead to an incorrect and much smaller mc2_tokens_capacity. The previous implementation correctly calculated max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len. This logic should be restored before calling the new helper function.
```diff
-max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
-    self.max_num_reqs, tp_size)
+max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
+max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+    max_num_tokens, tp_size)
```
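A quick numeric illustration of why the distinction matters (the numbers are assumed for this PR's `max_seqs=14`, `mtp=2` scenario, where each decode step would process 3 tokens per request):

```python
# Illustrative values only, not taken from the actual runtime:
max_num_reqs = 14            # request count (what the buggy call passes)
uniform_decode_query_len = 3  # tokens per request per decode step with mtp=2

# Buggy call sizes the capacity from the request count alone ...
buggy_input = max_num_reqs                                # 14
# ... but the graph must hold every decode token across all requests.
correct_input = max_num_reqs * uniform_decode_query_len   # 42

print(buggy_input, correct_input)  # 14 42
```

Passing the request count therefore undersizes `mc2_tokens_capacity` by a factor of `uniform_decode_query_len`.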
```python
if get_ascend_soc_version() == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
    logger.error(
        f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
    )
if get_ascend_soc_version() == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
    logger.error(
        f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
    )
```
Using logger.error for violations of hard hardware limits might not be sufficient. An error log will be printed, but the execution will continue, potentially leading to more obscure failures later on. It would be better to raise a ValueError to halt execution immediately and provide a clear error message to the user. This also provides an opportunity to improve the error messages for clarity and grammatical correctness. Additionally, calling get_ascend_soc_version() once and storing it in a local variable would be more efficient.
```python
soc_version = get_ascend_soc_version()
if soc_version == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
    raise ValueError(
        f"On Ascend A3, the max number of tokens for mc2 must be smaller than or equal to 512, but it is {self.mc2_tokens_capacity}"
    )
if soc_version == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
    raise ValueError(
        f"On Ascend A2, the max number of tokens for mc2 must be smaller than or equal to 256, but it is {self.mc2_tokens_capacity}"
    )
```
Signed-off-by: zouyida2052 <[email protected]>
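The fail-fast pattern suggested above can be exercised in isolation. The sketch below is a self-contained stand-in, not the vLLM Ascend code: `AscendSocVersion` is re-declared as a plain enum, the 512/256 limits follow the review comment, and the function name `check_mc2_tokens_capacity` is invented for illustration.

```python
from enum import Enum


class AscendSocVersion(Enum):
    A2 = "A2"
    A3 = "A3"


# Per-SoC hard limits on mc2 token capacity, per the review comment.
_MC2_TOKEN_LIMITS = {AscendSocVersion.A3: 512, AscendSocVersion.A2: 256}


def check_mc2_tokens_capacity(soc_version, mc2_tokens_capacity):
    """Raise immediately instead of logging, so misconfiguration
    surfaces at startup rather than as an obscure failure later."""
    limit = _MC2_TOKEN_LIMITS.get(soc_version)
    if limit is not None and mc2_tokens_capacity > limit:
        raise ValueError(
            f"On Ascend {soc_version.value}, the max number of tokens for "
            f"mc2 must be smaller than or equal to {limit}, but it is "
            f"{mc2_tokens_capacity}")
```

A boundary value such as 512 on A3 passes, while 513 raises, which is the behavior the suggested change intends.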
What this PR does / why we need it?
Does this PR introduce any user-facing change?
no
How was this patch tested?