
Commit 40247c6

rjg-lyh and Shuming19 committed

[main] mlp weight prefetch in Qwen Dense Models

Signed-off-by: rjg-lyh <[email protected]>
Co-authored-by: Shuming19 <[email protected]>
1 parent b7ee3fd commit 40247c6

14 files changed: +273 -19 lines

.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 1 deletion
@@ -258,4 +258,4 @@ jobs:
         VLLM_WORKER_MULTIPROC_METHOD: spawn
         VLLM_USE_MODELSCOPE: True
       run: |
-        pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP
+        pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP

.github/workflows/vllm_ascend_test_full.yaml

Lines changed: 2 additions & 0 deletions
@@ -226,6 +226,8 @@ jobs:
         pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
         pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
         pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
+        pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
+        pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight

         #pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
         pytest -sv tests/e2e/multicard/test_prefix_caching.py

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 21 additions & 0 deletions
@@ -173,3 +173,24 @@ def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
             tensor_parallel_size=4,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@pytest.mark.parametrize("enforce_eager", [True, False])
+@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
+def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(
+        model, enforce_eager):
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download(model),
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            dtype="auto",
+            tensor_parallel_size=4,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
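
For reference, the same two switches can be exercised outside pytest. Below is a minimal offline-inference sketch, assuming the standard vLLM LLM/SamplingParams API and a 4-NPU setup; the model name and prompt are placeholders, not part of the commit:

import os

# The prefetch switch defaults to "0" in vllm_ascend/envs.py; the values are
# read lazily, so setting them before the engine is constructed is enough.
os.environ["VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE"] = "1"
os.environ["VLLM_ASCEND_ENABLE_PREFETCH_MLP"] = "1"

from vllm import LLM, SamplingParams

# Placeholder dense Qwen checkpoint; any entry from QWEN_DENSE_MODELS would do.
llm = LLM(model="Qwen/Qwen3-8B", tensor_parallel_size=4, max_model_len=8192)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)

Per the forward-context change below, the prefetch path only activates for batches under 500 tokens and when the model exposes model.start_layer, so large-batch runs are unaffected.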

vllm_ascend/ascend_forward_context.py

Lines changed: 25 additions & 2 deletions
@@ -66,7 +66,9 @@ def set_ascend_forward_context(
         moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        prefetch_stream: torch.npu.Stream = None,
+        prefetch_model: torch.nn.Module = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -108,7 +110,8 @@ def set_ascend_forward_context(
     # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
     # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
     # the performance may degrade due to the switching of communication methods.
-    flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
+    flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
+                           envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
                            tp_world_size > 1 and \
                            num_tokens is not None and num_tokens > 1000

@@ -122,6 +125,26 @@ def set_ascend_forward_context(
         # set this for rope forward_oot using
         forward_context.is_first_layer = True

+        # set layer_idx to enable optimization features that depend on this information.
+        # This is only applicable to models that contain these necessary attributes.
+        forward_context.layer_idx = None
+        if prefetch_model is not None and \
+                hasattr(prefetch_model, "model") and \
+                hasattr(prefetch_model.model, "start_layer"):
+            forward_context.layer_idx = prefetch_model.model.start_layer
+
+        # set for mlp weight prefetch
+        prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
+            envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
+            forward_context.layer_idx is not None and \
+            num_tokens is not None and num_tokens < 500
+        if prefetch_mlp_enabled:
+            forward_context.prefetch_stream = prefetch_stream
+            forward_context.prefetch_model = prefetch_model
+            forward_context.prefetch_mlp_gate_up_proj = False
+            forward_context.prefetch_mlp_down_proj = False
+        forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
+
         if num_tokens is None and attn_metadata is not None:
             num_tokens = attn_metadata.num_actual_tokens
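
The two new arguments are expected to come from the model runner, which is one of the 14 changed files but is not shown in this excerpt. A rough sketch of that wiring, assuming the existing leading arguments (attention metadata and the vLLM config) stay as they are, a single dedicated NPU side stream is created at init, and the wrapped model exposes model.start_layer:

import torch
import torch_npu  # noqa: F401  (provides the torch.npu namespace used below)

from vllm_ascend.ascend_forward_context import set_ascend_forward_context


def execute_step(runner, input_ids, positions, attn_metadata, num_input_tokens):
    # Hypothetical helper, not commit code: create the side stream once and
    # pass it (plus the model) into the forward context on every step.
    if getattr(runner, "prefetch_stream", None) is None:
        runner.prefetch_stream = torch.npu.Stream()

    with set_ascend_forward_context(attn_metadata,
                                    runner.vllm_config,
                                    num_tokens=num_input_tokens,
                                    prefetch_stream=runner.prefetch_stream,
                                    prefetch_model=runner.model):
        return runner.model(input_ids=input_ids, positions=positions)

With num_tokens at or above 500 the prefetch flags stay off, so passing the extra arguments is harmless for large batches.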

vllm_ascend/envs.py

Lines changed: 9 additions & 0 deletions
@@ -135,6 +135,15 @@
     # This feature will get better performance when concurrency is large.
     "VLLM_ASCEND_ENABLE_FLASHCOMM":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    # Whether to enable MLP weight prefetch, only used in small concurrency.
+    "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
+    # buffer size for gate up prefetch
+    "MLP_GATE_UP_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
+    # buffer size for down proj prefetch
+    "MLP_DOWN_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
     # Whether to enable dense model and general optimizations for better performance.
     # Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
     # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
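
The two size knobs are plain byte counts that end up as the size argument of torch_npu.npu_prefetch in register_custom_ops.py below. A quick sketch of overriding them; the values here are illustrative only, the shipped default is 18 MiB for both:

import os

# Illustrative overrides (defaults: 18 * 1024 * 1024 bytes each). The value is
# passed as the size argument to torch_npu.npu_prefetch for the matching weight.
os.environ["MLP_GATE_UP_PREFETCH_SIZE"] = str(24 * 1024 * 1024)  # 24 MiB
os.environ["MLP_DOWN_PREFETCH_SIZE"] = str(12 * 1024 * 1024)     # 12 MiB

import vllm_ascend.envs as envs_ascend

# envs.py resolves each entry lazily through a lambda, so the override is
# picked up on first attribute access.
print(envs_ascend.MLP_GATE_UP_PREFETCH_SIZE)  # 25165824
print(envs_ascend.MLP_DOWN_PREFETCH_SIZE)     # 12582912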

vllm_ascend/ops/activation.py

Lines changed: 2 additions & 0 deletions
@@ -35,8 +35,10 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:

         from vllm_ascend.utils import is_310p

+        torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
         if is_310p():
             out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
         else:
             out = torch_npu.npu_swiglu(x)
+        torch.ops.vllm.maybe_wait_prefetch_done(out)
         return out

vllm_ascend/ops/layernorm.py

Lines changed: 4 additions & 12 deletions
@@ -44,12 +44,7 @@ def forward(
         import torch_npu

         if residual is not None:
-            # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
-            # is enabled, without interfering with the normal operation of components like torchair.
-            # The final solution should be to move this check into the operator and support
-            # integration with torchair.
-            if x.size(0) != residual.size(0):
-                residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             x, _, residual = torch_npu.npu_add_rms_norm_quant(
                 x,
@@ -58,6 +53,7 @@ def forward(
                 self.layer.aclnn_input_scale,
                 self.layer.aclnn_input_offset,
                 epsilon=self.variance_epsilon)
+            torch.ops.vllm.maybe_wait_prefetch_done(x)
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,
@@ -76,12 +72,7 @@ def forward_oot(

         from vllm_ascend.utils import is_310p
         if residual is not None:
-            # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
-            # is enabled, without interfering with the normal operation of components like torchair.
-            # The final solution should be to move this check into the operator and support
-            # integration with torchair.
-            if x.size(0) != residual.size(0):
-                residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             if is_310p():
                 orig_dtype = residual.dtype
@@ -92,6 +83,7 @@ def forward_oot(
             else:
                 x, _, residual = torch_npu.npu_add_rms_norm(
                     x, residual, self.weight, self.variance_epsilon)
+            torch.ops.vllm.maybe_wait_prefetch_done(x)
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,

vllm_ascend/ops/linear.py

Lines changed: 1 addition & 0 deletions
@@ -366,6 +366,7 @@ def _forward_dense_optim(
             input_parallel,
             bias=bias_)
         output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
+        torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)

         output_bias = self.bias if self.skip_bias_add else None
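
The self.prefix passed here is the module path of the calling linear layer; _maybe_prefetch_mlp_gate_up_proj_impl in register_custom_ops.py derives the layer index and the attention/MLP distinction from it. A small illustration of the assumed naming scheme, typical of Qwen dense checkpoints:

# Assumed module-path format, e.g. layer 5's attention output projection:
prefix = "model.layers.5.self_attn.o_proj"

layer_idx = int(prefix.split('.')[2])                 # -> 5
is_attn_site = prefix.split('.')[-2] == "self_attn"   # -> True, so the
                                                       # gate_up prefetch starts

Module paths that do not follow this layout would break the int(...) parse, which is presumably part of why the optimization is recommended for dense models only.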

vllm_ascend/ops/register_custom_ops.py

Lines changed: 97 additions & 2 deletions
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+import torch_npu
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -8,10 +9,16 @@
 from vllm.forward_context import get_forward_context
 from vllm.utils import direct_register_custom_op

+import vllm_ascend.envs as envs_ascend
+

 def _maybe_chunk_residual_impl(x: torch.Tensor,
                                residual: torch.Tensor) -> torch.Tensor:
-    if get_forward_context().flashcomm_v1_enabled:
+    if x.size(0) != residual.size(0):
+        flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+        assert flashcomm_v1_enabled is True, (
+            "Currently, this situation only occurs "
+            "when flashcomm_v1 is enabled")
         pad_size = get_forward_context().pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,6 +51,76 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     return tensor_model_parallel_all_reduce(x)


+def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
+                                          prefix: str) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = int(prefix.split('.')[2])
+
+    # start point of gate_up_proj weight prefetch
+    if prefix.split('.')[-2] == "self_attn":
+        forward_context.prefetch_mlp_gate_up_proj = True
+    if forward_context.prefetch_mlp_gate_up_proj:
+        prefetch_stream.wait_stream(torch.npu.current_stream())
+
+        with torch.npu.stream(prefetch_stream):
+            MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
+            torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight, \
+                                   x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
+    return
+
+
+def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
+                                               prefix: str) -> None:
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    forward_context.prefetch_mlp_down_proj = True
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = forward_context.layer_idx
+
+    # start point of down_proj weight prefetch
+    prefetch_stream.wait_stream(torch.npu.current_stream())
+
+    with torch.npu.stream(prefetch_stream):
+        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
+        torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \
+                               x_dependency, MLP_DOWN_PREFETCH_SIZE)
+    forward_context.layer_idx += 1
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl_fake(
+        x_dependency: torch.Tensor) -> None:
+    return
+
+
+def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    if forward_context.prefetch_mlp_gate_up_proj or \
+            forward_context.prefetch_mlp_down_proj:
+        prefetch_stream = get_forward_context().prefetch_stream
+        # wait until prefetch done
+        torch.npu.current_stream().wait_stream(prefetch_stream)
+        forward_context.prefetch_mlp_gate_up_proj = False
+        forward_context.prefetch_mlp_down_proj = False
+    return
+
+
+def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
+    return
+
+
 direct_register_custom_op(op_name="maybe_chunk_residual",
                           op_func=_maybe_chunk_residual_impl,
                           fake_impl=lambda x, residual: residual,
@@ -60,4 +137,22 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
                           op_func=_maybe_pad_and_reduce_impl,
                           fake_impl=lambda x: x,
                           mutates_args=[],
-                          dispatch_key="PrivateUse1")
+                          dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
+                          op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
+                          fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
+                          op_func=_maybe_prefetch_mlp_down_proj_impl,
+                          fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="maybe_wait_prefetch_done",
+                          op_func=_maybe_wait_prefetch_done_impl,
+                          fake_impl=_maybe_wait_prefetch_done_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
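
Taken together with the call sites added in linear.py, layernorm.py, and activation.py above, these ops implement a side-stream prefetch pattern: gate_up_proj is prefetched right after the attention output all-reduce, down_proj right before the swiglu, and maybe_wait_prefetch_done joins the streams before the prefetched weight is consumed. A condensed, stand-alone sketch of the pattern (not commit code, names are illustrative):

import torch
import torch_npu

prefetch_stream = torch.npu.Stream()  # dedicated copy stream, created once


def prefetch_weight(weight: torch.Tensor, dependency: torch.Tensor,
                    max_bytes: int) -> None:
    # Do not run ahead of the producer of `dependency` on the main stream.
    prefetch_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(prefetch_stream):
        # Asynchronously pull up to `max_bytes` of `weight` toward the NPU.
        torch_npu.npu_prefetch(weight, dependency, max_bytes)


def wait_prefetch_done() -> None:
    # Join point: the main stream waits for the copy stream before the weight is used.
    torch.npu.current_stream().wait_stream(prefetch_stream)

The registered ops add bookkeeping on top of this pattern (the prefetch_mlp_* flags and the layer_idx advance) so that a wait is only issued when a prefetch is actually in flight.
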
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+
+
+def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+    """AscendSiluAndMul forward in torchair mode.
+
+    The key difference from the original implementation is the removal of operators
+    from the torch.ops.vllm class, as these operators only function in non-torchair
+    modes. Adding them back would cause the graph compilation to fail.
+    """
+
+    import torch_npu
+
+    from vllm_ascend.utils import is_310p
+
+    if is_310p():
+        out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
+    else:
+        out = torch_npu.npu_swiglu(x)
+    return out
