
Commit 662a856

rjg-lyh and Shuming19 committed
[main] mlp weight prefetch in Qwen Dense Models
Signed-off-by: rjg-lyh <[email protected]> Co-authored-by: Shuming19 <[email protected]>
1 parent 5bcb4c1 commit 662a856

14 files changed, +273 -18 lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 1 deletion
@@ -258,4 +258,4 @@ jobs:
       VLLM_WORKER_MULTIPROC_METHOD: spawn
       VLLM_USE_MODELSCOPE: True
     run: |
-      pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP
+      pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP

.github/workflows/vllm_ascend_test_full.yaml

Lines changed: 2 additions & 0 deletions
@@ -226,6 +226,8 @@ jobs:
       pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
       pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
       pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
+      pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
+      pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight

       #pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
       pytest -sv tests/e2e/multicard/test_prefix_caching.py

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 20 additions & 0 deletions
@@ -173,3 +173,23 @@ def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
             tensor_parallel_size=4,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@pytest.mark.parametrize("enforce_eager", [True, False])
+@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
+def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(model, enforce_eager):
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download(model),
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            dtype="auto",
+            tensor_parallel_size=4,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
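
To reproduce the new case outside CI, the same selector used in the full-test workflow applies. A minimal sketch, assuming a host with 4 NPUs (the test pins tensor_parallel_size=4) and the e2e test dependencies installed:

import pytest

# Run only the new prefetch test; requires 4 NPUs because the test uses
# tensor_parallel_size=4 and downloads the Qwen dense model weights.
pytest.main([
    "-sv",
    "tests/e2e/multicard/test_offline_inference_distributed.py::"
    "test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight",
])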

vllm_ascend/ascend_forward_context.py

Lines changed: 25 additions & 2 deletions
@@ -66,7 +66,9 @@ def set_ascend_forward_context(
         moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        prefetch_stream: torch.npu.Stream = None,
+        prefetch_model: torch.nn.Module = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -108,7 +110,8 @@ def set_ascend_forward_context(
     # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
     # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
     # the performance may degrade due to the switching of communication methods.
-    flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
+    flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
+        envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
         tp_world_size > 1 and \
         num_tokens is not None and num_tokens > 1000

@@ -119,6 +122,26 @@ def set_ascend_forward_context(

     forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled

+    # set layer_idx to enable optimization features that depend on this information.
+    # This is only applicable to models that contain these necessary attributes.
+    forward_context.layer_idx = None
+    if prefetch_model is not None and \
+            hasattr(prefetch_model, "model") and \
+            hasattr(prefetch_model.model, "start_layer"):
+        forward_context.layer_idx = prefetch_model.model.start_layer
+
+    # set for mlp weight prefetch
+    prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
+        envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
+        forward_context.layer_idx is not None and \
+        num_tokens is not None and num_tokens < 500
+    if prefetch_mlp_enabled:
+        forward_context.prefetch_stream = prefetch_stream
+        forward_context.prefetch_model = prefetch_model
+        forward_context.prefetch_mlp_gate_up_proj = False
+        forward_context.prefetch_mlp_down_proj = False
+    forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
+
     if num_tokens is None and attn_metadata is not None:
         num_tokens = attn_metadata.num_actual_tokens
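
For readers skimming the hunk above: the two optimizations target opposite regimes. flashcomm_v1 only turns on past roughly 1000 tokens, while MLP weight prefetch only turns on below 500. A condensed sketch of the two gates (illustrative helper functions, not code from this change):

def flashcomm_v1_gate(dense_optimize: bool, flashcomm: bool,
                      tp_world_size: int, num_tokens) -> bool:
    # Large-batch path: needs TP > 1 and enough concurrency to amortize the
    # switch in communication pattern.
    return (dense_optimize and flashcomm and tp_world_size > 1
            and num_tokens is not None and num_tokens > 1000)


def prefetch_mlp_gate(dense_optimize: bool, prefetch_mlp: bool,
                      layer_idx, num_tokens) -> bool:
    # Small-batch path: prefetch helps when per-layer compute is short and the
    # model exposes start_layer, so layer_idx is known.
    return (dense_optimize and prefetch_mlp and layer_idx is not None
            and num_tokens is not None and num_tokens < 500)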

vllm_ascend/envs.py

Lines changed: 9 additions & 0 deletions
@@ -135,6 +135,15 @@
     # This feature will get better performance when concurrency is large.
     "VLLM_ASCEND_ENABLE_FLASHCOMM":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    # Whether to enable MLP weight prefetch, only used in small concurrency.
+    "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
+    # buffer size for gate up prefetch
+    "MLP_GATE_UP_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
+    # buffer size for down proj prefetch
+    "MLP_DOWN_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
     # Whether to enable dense model and general optimizations for better performance.
     # Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
     # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
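
The new knobs are read through os.getenv, so they can be set in the launching environment. A usage sketch (the size values shown are the defaults from this change, not tuning advice):

import os

# Prefetch is gated behind the dense-optimize switch in
# set_ascend_forward_context, so both must be enabled.
os.environ["VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE"] = "1"
os.environ["VLLM_ASCEND_ENABLE_PREFETCH_MLP"] = "1"

# Optional byte budgets passed to torch_npu.npu_prefetch (default 18 MiB each).
os.environ["MLP_GATE_UP_PREFETCH_SIZE"] = str(18 * 1024 * 1024)
os.environ["MLP_DOWN_PREFETCH_SIZE"] = str(18 * 1024 * 1024)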

vllm_ascend/ops/activation.py

Lines changed: 2 additions & 0 deletions
@@ -35,8 +35,10 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:

         from vllm_ascend.utils import is_310p

+        torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
         if is_310p():
             out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
         else:
             out = torch_npu.npu_swiglu(x)
+        torch.ops.vllm.maybe_wait_prefetch_done(out)
         return out

vllm_ascend/ops/layernorm.py

Lines changed: 4 additions & 12 deletions
@@ -44,12 +44,7 @@ def forward(
         import torch_npu

         if residual is not None:
-            # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
-            # is enabled, without interfering with the normal operation of components like torchair.
-            # The final solution should be to move this check into the operator and support
-            # integration with torchair.
-            if x.size(0) != residual.size(0):
-                residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             x, _, residual = torch_npu.npu_add_rms_norm_quant(
                 x,
@@ -58,6 +53,7 @@ def forward(
                 self.layer.aclnn_input_scale,
                 self.layer.aclnn_input_offset,
                 epsilon=self.variance_epsilon)
+            torch.ops.vllm.maybe_wait_prefetch_done(x)
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,
@@ -76,12 +72,7 @@ def forward_oot(

         from vllm_ascend.utils import is_310p
         if residual is not None:
-            # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
-            # is enabled, without interfering with the normal operation of components like torchair.
-            # The final solution should be to move this check into the operator and support
-            # integration with torchair.
-            if x.size(0) != residual.size(0):
-                residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             if is_310p():
                 orig_dtype = residual.dtype
@@ -92,6 +83,7 @@ def forward_oot(
             else:
                 x, _, residual = torch_npu.npu_add_rms_norm(
                     x, residual, self.weight, self.variance_epsilon)
+            torch.ops.vllm.maybe_wait_prefetch_done(x)
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,

vllm_ascend/ops/linear.py

Lines changed: 1 addition & 0 deletions
@@ -361,6 +361,7 @@ def _forward_dense_optim(
             input_parallel,
             bias=bias_)
         output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
+        torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)

         output_bias = self.bias if self.skip_bias_add else None
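
The second argument here is the module prefix, which the custom op parses to locate the layer whose MLP weights should be prefetched. A small illustration of that convention (the prefix string is hypothetical but follows the model.layers.<idx>.<submodule> layout the op expects):

# e.g. the o_proj of decoder layer 7
prefix = "model.layers.7.self_attn.o_proj"

layer_idx = int(prefix.split('.')[2])                   # -> 7
starts_gate_up = prefix.split('.')[-2] == "self_attn"   # -> True: kick off gate_up prefetch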

vllm_ascend/ops/register_custom_ops.py

Lines changed: 98 additions & 1 deletion
@@ -1,17 +1,24 @@
 import torch
 import torch.nn.functional as F
+import torch_npu
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
 from vllm.utils import direct_register_custom_op
+import vllm_ascend.envs as envs_ascend


 def _maybe_chunk_residual_impl(x: torch.Tensor,
                                residual: torch.Tensor) -> torch.Tensor:
-    if get_forward_context().flashcomm_v1_enabled:
+    if x.size(0) != residual.size(0):
+        flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+        assert flashcomm_v1_enabled is True, (
+            "Currently, this situation only occurs "
+            "when flashcomm_v1 is enabled"
+        )
         pad_size = get_forward_context().pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,6 +51,75 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     return tensor_model_parallel_all_reduce(x)


+def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
+                                          prefix: str) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = int(prefix.split('.')[2])
+
+    # start point of gate_up_proj weight prefetch
+    if prefix.split('.')[-2] == "self_attn":
+        forward_context.prefetch_mlp_gate_up_proj = True
+    if forward_context.prefetch_mlp_gate_up_proj:
+        prefetch_stream.wait_stream(torch.npu.current_stream())
+
+        with torch.npu.stream(prefetch_stream):
+            MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
+            torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight, \
+                                   x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
+    return
+
+
+def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
+                                               prefix: str) -> None:
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    forward_context.prefetch_mlp_down_proj = True
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = forward_context.layer_idx
+
+    # start point of down_proj weight prefetch
+    prefetch_stream.wait_stream(torch.npu.current_stream())
+
+    with torch.npu.stream(prefetch_stream):
+        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
+        torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \
+                               x_dependency, MLP_DOWN_PREFETCH_SIZE)
+    forward_context.layer_idx += 1
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl_fake(x_dependency: torch.Tensor) -> None:
+    return
+
+
+def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    if forward_context.prefetch_mlp_gate_up_proj or \
+            forward_context.prefetch_mlp_down_proj:
+        prefetch_stream = get_forward_context().prefetch_stream
+        # wait until prefetch done
+        torch.npu.current_stream().wait_stream(prefetch_stream)
+        forward_context.prefetch_mlp_gate_up_proj = False
+        forward_context.prefetch_mlp_down_proj = False
+    return
+
+
+def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
+    return
+
+
 direct_register_custom_op(op_name="maybe_chunk_residual",
                           op_func=_maybe_chunk_residual_impl,
                           fake_impl=lambda x, residual: residual,
@@ -60,4 +136,25 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
                           op_func=_maybe_pad_and_reduce_impl,
                           fake_impl=lambda x: x,
                           mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+
+direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
+                          op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
+                          fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+
+direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
+                          op_func=_maybe_prefetch_mlp_down_proj_impl,
+                          fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+
+direct_register_custom_op(op_name="maybe_wait_prefetch_done",
+                          op_func=_maybe_wait_prefetch_done_impl,
+                          fake_impl=_maybe_wait_prefetch_done_impl_fake,
+                          mutates_args=[],
                           dispatch_key="PrivateUse1")
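
The core mechanism behind these ops is a second NPU stream that pulls the next MLP weight closer to the compute units while the compute stream is still busy. A stripped-down sketch of that pattern, assuming an NPU device is available (torch_npu.npu_prefetch and the stream APIs are used as in the ops above; the helper names are illustrative):

import torch
import torch_npu

compute_stream = torch.npu.current_stream()
prefetch_stream = torch.npu.Stream()


def start_prefetch(weight: torch.Tensor, dependency: torch.Tensor,
                   max_bytes: int) -> None:
    # Order the prefetch after whatever produced `dependency` on the compute
    # stream, then issue it asynchronously on the side stream.
    prefetch_stream.wait_stream(compute_stream)
    with torch.npu.stream(prefetch_stream):
        torch_npu.npu_prefetch(weight, dependency, max_bytes)


def wait_prefetch_done() -> None:
    # Block the compute stream until the prefetch has landed, right before the
    # matmul that consumes the weight.
    compute_stream.wait_stream(prefetch_stream)
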
Lines changed: 37 additions & 0 deletions (new file)
@@ -0,0 +1,37 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+
+
+def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+    """AscendSiluAndMul forward in torchair mode.
+
+    The key difference from the original implementation is the removal of operators
+    from the torch.ops.vllm class, as these operators only function in non-torchair
+    modes. Adding them back would cause the graph compilation to fail.
+    """
+
+    import torch_npu
+
+    from vllm_ascend.utils import is_310p
+
+    if is_310p():
+        out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
+    else:
+        out = torch_npu.npu_swiglu(x)
+    return out
