
Conversation

Collaborator

@rjg-lyh commented Sep 8, 2025

What this PR does / why we need it?

This PR prefetches the weights of the MLP layers in Qwen dense models, mainly to optimize performance in the decode phase.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

CI passed with newly added and existing tests.
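
For context, here is a minimal sketch of the idea, assuming hypothetical helper and argument names; only the torch_npu.npu_prefetch call signature (weight tensor, dependency tensor, byte budget) mirrors the snippets quoted in the review below.

import torch_npu

# Default of the PR's MLP_GATE_UP_PREFETCH_SIZE env var (18 MiB).
GATE_UP_PREFETCH_BYTES = 18 * 1024 * 1024

def prefetch_gate_up_weight(model, layer_idx, dependency):
    # Prefetch this layer's gate_up_proj weight into cache ahead of use,
    # bounded by a byte budget; down_proj is handled analogously in the PR
    # with its own budget.
    weight = model.model.layers[layer_idx].mlp.gate_up_proj.weight
    torch_npu.npu_prefetch(weight, dependency, GATE_UP_PREFETCH_BYTES)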

Contributor

@gemini-code-assist bot left a comment

Code Review

This PR introduces MLP weight prefetching in Qwen Dense Models to optimize performance during the decode phase. The changes involve modifying the forward context, adding environment variables, and registering custom operations. The review focuses on identifying potential issues related to the added logic and ensuring the changes align with the intended optimization goals.

Comment on lines +113 to +114
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
Contributor

high

The condition envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE is repeated in multiple if statements. Consider defining a variable to store this value and reuse it to avoid redundancy and improve readability. This also centralizes the configuration, making it easier to manage.

Suggested change
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
dense_optimize_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
flashcomm_v1_enabled = dense_optimize_enabled and \
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000

Comment on lines +134 to +140
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500
Contributor

high

Similar to the previous comment, the condition envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE is repeated here. Consider reusing the dense_optimize_enabled variable defined earlier to maintain consistency and readability.

Suggested change
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500
prefetch_mlp_enabled = dense_optimize_enabled and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500

Comment on lines 66 to 72
torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight, \
x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
Contributor

high

Consider adding a check to ensure that prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight is a valid tensor before calling torch_npu.npu_prefetch. This can prevent potential runtime errors if the weight is not properly initialized or is missing.

Suggested change
torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight, \
x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
weight = prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight
if isinstance(weight, torch.Tensor):
torch_npu.npu_prefetch(weight, x_dependency, MLP_GATE_UP_PREFETCH_SIZE)

Comment on lines 90 to 96
torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \
x_dependency, MLP_DOWN_PREFETCH_SIZE)
Contributor

high

Similar to the previous comment, add a check to ensure that prefetch_model.model.layers[layer_idx].mlp.down_proj.weight is a valid tensor before calling torch_npu.npu_prefetch. This can prevent potential runtime errors if the weight is not properly initialized or is missing.

Suggested change
torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \
x_dependency, MLP_DOWN_PREFETCH_SIZE)
with torch.npu.stream(prefetch_stream):
MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
weight = prefetch_model.model.layers[layer_idx].mlp.down_proj.weight
if isinstance(weight, torch.Tensor):
torch_npu.npu_prefetch(weight, x_dependency, MLP_DOWN_PREFETCH_SIZE)
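
A compact way to apply both suggestions above is a small guard helper; the function name below is hypothetical and not part of the PR.

import torch
import torch_npu

def maybe_npu_prefetch(weight, dependency, max_bytes):
    # Prefetch only when the weight attribute is a materialized tensor, as the
    # two review suggestions above recommend; otherwise silently skip.
    if isinstance(weight, torch.Tensor):
        torch_npu.npu_prefetch(weight, dependency, max_bytes)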


github-actions bot commented Sep 8, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description, to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.


github-actions bot commented Sep 9, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.


github-actions bot commented Sep 9, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.


@rjg-lyh force-pushed the pr-prefetch-new branch 3 times, most recently from b3b2d43 to 40247c6 on September 10, 2025 at 13:44
max_model_len=8192,
enforce_eager=enforce_eager,
dtype="auto",
tensor_parallel_size=4,
Collaborator

we only have 2 cards on multicard ci

Collaborator Author

I have fixed it to '2' now.
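
For reference, a sketch of the adjusted arguments (illustrative only; the enclosing test runner is omitted and enforce_eager is hard-coded here just for the sketch):

enforce_eager = True  # the real test parametrizes this value

runner_kwargs = dict(
    max_model_len=8192,
    enforce_eager=enforce_eager,
    dtype="auto",
    tensor_parallel_size=2,  # reduced from 4 to match the two-card multicard CI
)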

@rjg-lyh force-pushed the pr-prefetch-new branch 3 times, most recently from e284c9f to d03f050 on September 11, 2025 at 01:56
@MengqingCao added the "ready" (read for review) label Sep 11, 2025
@rjg-lyh force-pushed the pr-prefetch-new branch 2 times, most recently from 70e4bab to de36051 on September 11, 2025 at 06:36
@rjg-lyh force-pushed the pr-prefetch-new branch 3 times, most recently from 3045dd6 to 79cbaee on September 11, 2025 at 09:03

codecov bot commented Sep 11, 2025

Codecov Report

❌ Patch coverage is 42.00000% with 58 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.12%. Comparing base (1bbb20e) to head (bad76ac).
⚠️ Report is 22 commits behind head on main.

Files with missing lines Patch % Lines
vllm_ascend/ops/register_custom_ops.py 20.75% 42 Missing ⚠️
vllm_ascend/ascend_forward_context.py 54.54% 5 Missing ⚠️
vllm_ascend/torchair/utils.py 0.00% 5 Missing ⚠️
vllm_ascend/worker/model_runner_v1.py 0.00% 3 Missing ⚠️
vllm_ascend/ops/layernorm.py 50.00% 2 Missing ⚠️
vllm_ascend/ops/linear.py 0.00% 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (42.00%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2816      +/-   ##
==========================================
+ Coverage   74.76%   75.12%   +0.36%     
==========================================
  Files         150      155       +5     
  Lines       20891    21351     +460     
==========================================
+ Hits        15620    16041     +421     
- Misses       5271     5310      +39     
Flag Coverage Δ
unittests 75.12% <42.00%> (+0.36%) ⬆️

Flags with carried forward coverage won't be shown.


@wangxiyuan added the "ready-for-test" (start test by label for PR) label Sep 11, 2025
Collaborator Author

@rjg-lyh commented Sep 11, 2025

I will add more unit tests next week.

"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
# buffer size for gate up prefetch
"MLP_GATE_UP_PREFETCH_SIZE":
Collaborator

add VLLM_ASCEND_ prefix

"MLP_GATE_UP_PREFETCH_SIZE":
lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
# buffer size for down proj prefetch
"MLP_DOWN_PREFETCH_SIZE":
Collaborator

ditto
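
A sketch of the requested rename (the VLLM_ASCEND_-prefixed names are an assumption about the follow-up change, not code from this PR; the down_proj default is likewise assumed):

import os

env_variables = {
    # buffer size for gate_up_proj prefetch
    "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
    lambda: int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
    # buffer size for down_proj prefetch (default assumed here)
    "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE":
    lambda: int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
}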

@wangxiyuan merged commit 0005479 into vllm-project:main Sep 11, 2025
30 of 31 checks passed
yiz-liu pushed a commit to linfeng-yuan/vllm-ascend that referenced this pull request Sep 12, 2025
### What this PR does / why we need it?
This PR prefetches the weights of the MLP layers in Qwen dense models,
mainly to optimize performance in the decode phase.

### Does this PR introduce _any_ user-facing change?
 No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: main
- vLLM main:
vllm-project/vllm@a1213fa

Signed-off-by: rjg-lyh <[email protected]>
Co-authored-by: Shuming19 <[email protected]>
Signed-off-by: Yizhou Liu <[email protected]>
if forward_context.prefetch_mlp_gate_up_proj:
prefetch_stream.wait_stream(torch.npu.current_stream())

with torch.npu.stream(prefetch_stream):
Collaborator

Can you provide performance data for the case where the global batch size (gbs) increases and the stream becomes insufficient?
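
For context, a minimal sketch of the side-stream pattern quoted above; only the wait_stream / torch.npu.stream / torch_npu.npu_prefetch calls visible in the diff are taken from the PR, and all other names are assumptions. The compute stream never waits on the prefetch stream, so the prefetch only overlaps with decode work and never blocks it.

import torch
import torch_npu

prefetch_stream = torch.npu.Stream()  # assumed to be created once per worker

def launch_gate_up_prefetch(forward_context, model, layer_idx, x_dependency, max_bytes):
    if not forward_context.prefetch_mlp_gate_up_proj:
        return
    # Order the side stream after the work already queued on the compute
    # stream, so x_dependency is valid when the prefetch actually runs.
    prefetch_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(prefetch_stream):
        weight = model.model.layers[layer_idx].mlp.gate_up_proj.weight
        torch_npu.npu_prefetch(weight, x_dependency, max_bytes)

As the batch grows, the prefetch has less idle time to hide behind, which is presumably why the PR also gates prefetching on num_tokens < 500.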
