[main] mlp weight prefetch in Qwen Dense Models #2816
Conversation
Code Review
This PR introduces MLP weight prefetching in Qwen Dense Models to optimize performance during the decode phase. The changes involve modifying the forward context, adding environment variables, and registering custom operations. The review focuses on identifying potential issues related to the added logic and ensuring the changes align with the intended optimization goals.
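For readers skimming the thread, here is a minimal sketch of the gating logic under review, assembled from the snippets quoted in the comments below; the wrapper function and the `import vllm_ascend.envs as envs_ascend` path are assumptions for illustration, not the PR's exact code:

```python
# Illustrative sketch only; assembled from the snippets quoted in this review.
# The wrapper function and import path are assumptions, not the PR's actual code.
import vllm_ascend.envs as envs_ascend


def _decide_dense_optimizations(forward_context, num_tokens, tp_world_size):
    dense_optimize_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
    # FlashComm v1 targets large (prefill-sized) batches under tensor parallelism.
    flashcomm_v1_enabled = (dense_optimize_enabled
                            and envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM
                            and tp_world_size > 1
                            and num_tokens is not None and num_tokens > 1000)
    # MLP weight prefetch targets small (decode-sized) batches.
    prefetch_mlp_enabled = (dense_optimize_enabled
                            and envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP
                            and forward_context.layer_idx is not None
                            and num_tokens is not None and num_tokens < 500)
    return flashcomm_v1_enabled, prefetch_mlp_enabled
```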
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
    envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
The condition envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE is repeated in multiple if statements. Consider defining a variable to store this value and reusing it to avoid redundancy and improve readability. This also centralizes the configuration, making it easier to manage.
Suggested change:
dense_optimize_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
flashcomm_v1_enabled = dense_optimize_enabled and \
    envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
    tp_world_size > 1 and \
    num_tokens is not None and num_tokens > 1000
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
    envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
    forward_context.layer_idx is not None and \
    num_tokens is not None and num_tokens < 500
Similar to the previous comment, the condition envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE is repeated here. Consider reusing the dense_optimize_enabled variable defined earlier to maintain consistency and readability.
Suggested change:
prefetch_mlp_enabled = dense_optimize_enabled and \
    envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
    forward_context.layer_idx is not None and \
    num_tokens is not None and num_tokens < 500
torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight, \
    x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
Consider adding a check to ensure that prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight is a valid tensor before calling torch_npu.npu_prefetch. This can prevent potential runtime errors if the weight is not properly initialized or is missing.
Suggested change:
MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
weight = prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight
if isinstance(weight, torch.Tensor):
    torch_npu.npu_prefetch(weight, x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \
    x_dependency, MLP_DOWN_PREFETCH_SIZE)
Similar to the previous comment, add a check to ensure that prefetch_model.model.layers[layer_idx].mlp.down_proj.weight is a valid tensor before calling torch_npu.npu_prefetch. This can prevent potential runtime errors if the weight is not properly initialized or is missing.
Suggested change:
with torch.npu.stream(prefetch_stream):
    MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
    weight = prefetch_model.model.layers[layer_idx].mlp.down_proj.weight
    if isinstance(weight, torch.Tensor):
        torch_npu.npu_prefetch(weight, x_dependency, MLP_DOWN_PREFETCH_SIZE)
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Force-pushed from 003689e to 4e80b82.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed from 4e80b82 to fa59c48.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed from 662a856 to 17b9ff4.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed from 17b9ff4 to c29bfc9.
Force-pushed from b3b2d43 to 40247c6.
max_model_len=8192,
enforce_eager=enforce_eager,
dtype="auto",
tensor_parallel_size=4,
We only have 2 cards on the multicard CI.
I have fixed it to '2' now.
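For reference, a hypothetical sketch of the adjusted test setup; the model name and the direct use of vllm.LLM are placeholders, since the actual test may go through a project test runner instead:

```python
# Hypothetical sketch of the adjusted multicard test configuration.
# The model name and the direct use of vllm.LLM are placeholders.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    max_model_len=8192,
    enforce_eager=True,
    dtype="auto",
    tensor_parallel_size=2,  # multicard CI provides only 2 cards
)
```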
Force-pushed from e284c9f to d03f050.
Force-pushed from 70e4bab to de36051.
Force-pushed from 3045dd6 to 79cbaee.
Signed-off-by: rjg-lyh <[email protected]> Co-authored-by: Shuming19 <[email protected]>
Force-pushed from 79cbaee to bad76ac.
Codecov Report
❌ Patch coverage is 42.00%.
❌ Your patch status has failed because the patch coverage (42.00%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:
@@            Coverage Diff             @@
##             main    #2816      +/-   ##
==========================================
+ Coverage   74.76%   75.12%   +0.36%
==========================================
  Files         150      155       +5
  Lines       20891    21351     +460
==========================================
+ Hits        15620    16041     +421
- Misses       5271     5310      +39
Flags with carried forward coverage won't be shown.
View full report in Codecov by Sentry.
I will add more UTs next week.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": | ||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), | ||
# buffer size for gate up prefetch | ||
"MLP_GATE_UP_PREFETCH_SIZE": |
Add the VLLM_ASCEND_ prefix.
"MLP_GATE_UP_PREFETCH_SIZE": | ||
lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), | ||
# buffer size for down proj prefetch | ||
"MLP_DOWN_PREFETCH_SIZE": |
ditto
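To make the request concrete, a sketch of the rename; the surrounding dict name env_variables is an assumption, and only the gate_up default is taken from the diff above (the down_proj entry would follow the same pattern with its own default):

```python
# Sketch of the reviewer's request: prefix the buffer-size switches with
# VLLM_ASCEND_. The dict name is an assumption; the default shown is the
# gate_up value from the diff above. The down_proj entry would follow the
# same pattern with its own default.
import os

env_variables = {
    "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
    lambda: int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE",
                          18 * 1024 * 1024)),
}
```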
### What this PR does / why we need it?
This PR prefetches the weight of MLP layers in Qwen Dense Models, mainly to optimize performance in the decode phase.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: main
- vLLM main: vllm-project/vllm@a1213fa
Signed-off-by: rjg-lyh <[email protected]>
Co-authored-by: Shuming19 <[email protected]>
Signed-off-by: Yizhou Liu <[email protected]>
if forward_context.prefetch_mlp_gate_up_proj:
    prefetch_stream.wait_stream(torch.npu.current_stream())

    with torch.npu.stream(prefetch_stream):
Can you provide the performance data for when the gbs increases and the stream becomes insufficient?
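For context, the stream interaction being discussed looks roughly like this; it is an illustrative composition of the snippets quoted in this thread, and the helper name and argument layout are assumptions:

```python
# Illustrative composition of the snippets quoted above; the helper name and
# argument layout are assumptions, not the PR's exact code.
import torch
import torch_npu


def _prefetch_gate_up_on_side_stream(prefetch_model, layer_idx, x_dependency,
                                     prefetch_stream, max_size):
    # Let the side stream wait for the compute stream that produced x_dependency,
    # then issue the weight prefetch so it overlaps with ongoing compute.
    prefetch_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(prefetch_stream):
        weight = prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight
        torch_npu.npu_prefetch(weight, x_dependency, max_size)
```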
What this PR does / why we need it?
This PR prefetches the weight of MLP layers in Qwen Dense Models, mainly to optimize performance in the decode phase.
Does this PR introduce any user-facing change?
No.
How was this patch tested?
CI passed with newly added and existing tests.
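For anyone trying this out, a hypothetical usage sketch; the environment variable names come from the diff above, the model name is a placeholder, and setting the switches from Python assumes they are read lazily, as the lambda-based definitions shown earlier suggest:

```python
# Hypothetical usage sketch. The switch names come from this PR's diff; setting
# them via os.environ before model construction assumes they are read lazily.
import os

os.environ["VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE"] = "1"
os.environ["VLLM_ASCEND_ENABLE_PREFETCH_MLP"] = "1"

from vllm import LLM  # noqa: E402

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
          tensor_parallel_size=2)
```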