
Commit bad76ac

rjg-lyh and Shuming19 committed

[main] mlp weight prefetch in Qwen Dense Models

Signed-off-by: rjg-lyh <[email protected]>
Co-authored-by: Shuming19 <[email protected]>

1 parent bd3dede · commit bad76ac

17 files changed: +313 −24 lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 1 deletion

@@ -258,4 +258,4 @@ jobs:
         env:
           VLLM_WORKER_MULTIPROC_METHOD: spawn
           VLLM_USE_MODELSCOPE: True
         run: |
-          pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP
+          pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP

.github/workflows/vllm_ascend_test_full.yaml

Lines changed: 2 additions & 0 deletions

@@ -226,6 +226,8 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight

           #pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
           pytest -sv tests/e2e/multicard/test_prefix_caching.py

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 27 additions & 2 deletions

@@ -31,7 +31,9 @@

 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

-QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"]
+QWEN_DENSE_MODELS = [
+    "vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"
+]


 def test_models_distributed_QwQ():

@@ -170,6 +172,29 @@ def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
             max_model_len=8192,
             enforce_eager=enforce_eager,
             dtype="auto",
-            tensor_parallel_size=4,
+            tensor_parallel_size=2,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@pytest.mark.parametrize("enforce_eager", [True, False])
+@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
+def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(
+        model, enforce_eager):
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download(model),
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            dtype="auto",
+            tensor_parallel_size=2,
+            quantization="ascend",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)

tests/ut/ops/test_activation.py

Lines changed: 12 additions & 1 deletion

@@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):

 @pytest.mark.parametrize("is_310p_return", [True, False])
 @patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
-def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
+@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
+@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
+       side_effect=lambda x: None)
+def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
+                            mock_maybe_wait_prefetch_done, mock_swiglu,
+                            is_310p_return, dummy_tensor):

     with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
         layer = SiluAndMul()

@@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
     else:
         expected_arg = dummy_tensor

+    # assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
+    mock_maybe_prefetch_mlp_down_proj.assert_called_once()
+
     # assert mock_swiglu.call_count == 1
     mock_swiglu.assert_called_once()

+    # assert mock_maybe_wait_prefetch_done.call_count == 1
+    mock_maybe_wait_prefetch_done.assert_called_once()
+
     actual_arg = mock_swiglu.call_args[0][0]
     assert torch.allclose(
         actual_arg,

tests/ut/ops/test_layernorm.py

Lines changed: 10 additions & 1 deletion

@@ -30,9 +30,11 @@ def mock_add_rms_norm(x, residual, weight, eps):
                          [None, torch.randn(4, 8, dtype=torch.float32)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
 @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
 @patch("torch.ops.vllm.maybe_chunk_residual",
        side_effect=mock_maybe_chunk_residual)
-def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
+def test_RMSNorm_forward(mock_maybe_chunk_residual,
+                         mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
                          mock_rmsnorm, is_310p_return, residual, dummy_tensor):

     with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):

@@ -45,13 +47,17 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
                 expected_out_x = expected_arg_x + 1
                 expected_out_residual = expected_arg_x.to(residual.dtype)

+                mock_maybe_chunk_residual.assert_called_once()
                 mock_rmsnorm.assert_called_once()
+                mock_maybe_wait_prefetch_done.assert_called_once()
                 assert torch.allclose(out_x, expected_out_x)
                 assert torch.allclose(out_residual, expected_out_residual)
             else:
                 expected_out_x = 2 * dummy_tensor
                 expected_out_residual = 2 * residual
+                mock_maybe_chunk_residual.assert_called_once()
                 mock_add_rmsnorm.assert_called_once()
+                mock_maybe_wait_prefetch_done.assert_called_once()
                 assert torch.allclose(out_x, expected_out_x)
                 assert torch.allclose(out_residual, expected_out_residual)
         else:

@@ -64,9 +70,11 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,

 @patch("vllm_ascend.utils.is_310p", return_value=False)
 @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
 @patch("torch.ops.vllm.maybe_chunk_residual",
        side_effect=mock_maybe_chunk_residual)
 def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
+                                           mock_maybe_wait_prefetch_done,
                                            mock_add_rms_norm, mock_is310p):
     x = torch.randn(4, 512, dtype=torch.bfloat16)
     residual = torch.randn(16, 512, dtype=torch.bfloat16)

@@ -79,6 +87,7 @@ def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,

     mock_maybe_chunk_residual.assert_called_once()
     mock_add_rms_norm.assert_called_once()
+    mock_maybe_wait_prefetch_done.assert_called_once()
     assert out_residual.size(0) == 4
     assert torch.allclose(out_x, expected_out_x)
     assert torch.allclose(out_residual, expected_out_residual)

tests/ut/torchair/models/test_torchair_deepseek_v2.py

Lines changed: 6 additions & 1 deletion

@@ -275,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,

 @patch("torch_npu.npu_add_rms_norm")
 @patch("torch_npu.npu_rms_norm")
-def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
+@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
+@patch("torch.ops.vllm.maybe_chunk_residual",
+       side_effect=lambda x, residual: residual)
+def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
+                                            mock_maybe_wait_prefetch_done,
+                                            mock_rms_norm, mock_add_norm,
                                             mock_distributed, base_config,
                                             vllm_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))

vllm_ascend/ascend_forward_context.py

Lines changed: 25 additions & 2 deletions

@@ -66,7 +66,9 @@ def set_ascend_forward_context(
         moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        prefetch_stream: torch.npu.Stream = None,
+        model_instance: torch.nn.Module = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.

@@ -108,7 +110,8 @@ def set_ascend_forward_context(
         # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
         # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
         # the performance may degrade due to the switching of communication methods.
-        flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
+        flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
+            envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
             tp_world_size > 1 and \
             num_tokens is not None and num_tokens > 1000

@@ -122,6 +125,26 @@ def set_ascend_forward_context(
         # set this for rope forward_oot using
         forward_context.is_first_layer = True

+        # set layer_idx to enable optimization features that depend on this information.
+        # This is only applicable to models that contain these necessary attributes.
+        forward_context.layer_idx = None
+        if model_instance is not None and \
+                hasattr(model_instance, "model") and \
+                hasattr(model_instance.model, "start_layer"):
+            forward_context.layer_idx = model_instance.model.start_layer
+
+        # set for mlp weight prefetch
+        prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
+            envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
+            forward_context.layer_idx is not None and \
+            num_tokens is not None and num_tokens < 500
+        if prefetch_mlp_enabled:
+            forward_context.prefetch_stream = prefetch_stream
+            forward_context.model_instance = model_instance
+            forward_context.prefetch_mlp_gate_up_proj = False
+            forward_context.prefetch_mlp_down_proj = False
+        forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
+
         if num_tokens is None and attn_metadata is not None:
             num_tokens = attn_metadata.num_actual_tokens
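
The two gates added here are complementary: flashcomm_v1 only kicks in above roughly 1000 tokens, MLP weight prefetch only pays off below roughly 500 tokens, and both now require VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE. A minimal sketch of the combined decision, with the env flags passed in as plain booleans instead of being read from vllm_ascend.envs and with standalone inputs rather than the real forward_context:

from typing import Optional, Tuple


def dense_optimization_gates(num_tokens: Optional[int],
                             tp_world_size: int,
                             layer_idx: Optional[int],
                             enable_dense: bool,
                             enable_flashcomm: bool,
                             enable_prefetch_mlp: bool) -> Tuple[bool, bool]:
    # flashcomm_v1: worthwhile only above the empirical ~1000-token threshold
    # and only when tensor parallelism is actually in use.
    flashcomm_v1_enabled = (enable_dense and enable_flashcomm
                            and tp_world_size > 1
                            and num_tokens is not None and num_tokens > 1000)
    # MLP weight prefetch: worthwhile only below ~500 tokens, and only for
    # models exposing model.start_layer (which seeds layer_idx).
    prefetch_mlp_enabled = (enable_dense and enable_prefetch_mlp
                            and layer_idx is not None
                            and num_tokens is not None and num_tokens < 500)
    return flashcomm_v1_enabled, prefetch_mlp_enabled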

vllm_ascend/envs.py

Lines changed: 9 additions & 0 deletions

@@ -135,6 +135,15 @@
     # This feature will get better performance when concurrency is large.
     "VLLM_ASCEND_ENABLE_FLASHCOMM":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    # Whether to enable MLP weight prefetch, only used in small concurrency.
+    "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
+    # buffer size for gate up prefetch
+    "MLP_GATE_UP_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
+    # buffer size for down proj prefetch
+    "MLP_DOWN_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
     # Whether to enable dense model and general optimizations for better performance.
     # Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
     # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
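
These switches are read lazily from the environment (the new e2e test sets them via patch.dict), so they can simply be exported before the engine is constructed. A usage sketch for offline inference; the model name and the explicit buffer-size overrides are illustrative only, not required:

import os

# Turn on the dense-model optimizations and the MLP weight prefetch path.
os.environ["VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE"] = "1"
os.environ["VLLM_ASCEND_ENABLE_PREFETCH_MLP"] = "1"
# Optional: override the 18 MiB default prefetch buffers (values in bytes).
os.environ["MLP_GATE_UP_PREFETCH_SIZE"] = str(18 * 1024 * 1024)
os.environ["MLP_DOWN_PREFETCH_SIZE"] = str(18 * 1024 * 1024)

from vllm import LLM, SamplingParams  # noqa: E402  (import after env setup)

llm = LLM(model="vllm-ascend/Qwen3-8B-W8A8",  # illustrative model choice
          tensor_parallel_size=2,
          quantization="ascend",
          max_model_len=8192)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)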

vllm_ascend/ops/activation.py

Lines changed: 2 additions & 0 deletions

@@ -35,8 +35,10 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:

         from vllm_ascend.utils import is_310p

+        torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
         if is_310p():
             out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
         else:
             out = torch_npu.npu_swiglu(x)
+        torch.ops.vllm.maybe_wait_prefetch_done(out)
         return out
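
The custom ops called here are registered elsewhere in this commit and are not shown in this excerpt. As a rough illustration of the pattern only: the prefetch op presumably starts copying the current layer's down_proj weight on the side prefetch_stream while npu_swiglu runs on the main stream, and the wait op makes the main stream block until that copy finishes before down_proj consumes the weight. A hedged sketch under those assumptions; torch_npu.npu_prefetch and the exact attribute paths are assumptions, not the committed implementation:

import torch
import torch_npu

import vllm_ascend.envs as envs_ascend
from vllm.forward_context import get_forward_context


def maybe_prefetch_mlp_down_proj_sketch(x_dependency: torch.Tensor) -> None:
    # Skip entirely unless set_ascend_forward_context enabled prefetching
    # for this forward pass (small batch, dense optimizations on).
    forward_context = get_forward_context()
    if not getattr(forward_context, "prefetch_mlp_enabled", False):
        return
    model = forward_context.model_instance
    layer_idx = forward_context.layer_idx
    # Assumed attribute path; real model definitions may differ.
    down_proj_weight = model.model.layers[layer_idx].mlp.down_proj.weight
    prefetch_stream = forward_context.prefetch_stream
    # Let the side stream start only after the main stream has produced
    # x_dependency, so the copy overlaps the activation kernel.
    prefetch_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(prefetch_stream):
        torch_npu.npu_prefetch(down_proj_weight, x_dependency,
                               envs_ascend.MLP_DOWN_PREFETCH_SIZE)
    forward_context.prefetch_mlp_down_proj = True


def maybe_wait_prefetch_done_sketch(x: torch.Tensor) -> None:
    forward_context = get_forward_context()
    if not getattr(forward_context, "prefetch_mlp_enabled", False):
        return
    if forward_context.prefetch_mlp_gate_up_proj or \
            forward_context.prefetch_mlp_down_proj:
        # Block the main stream until the side-stream copy has finished,
        # so the next matmul sees the prefetched weight.
        torch.npu.current_stream().wait_stream(forward_context.prefetch_stream)
        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False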

vllm_ascend/ops/layernorm.py

Lines changed: 4 additions & 12 deletions

@@ -44,12 +44,7 @@ def forward(
         import torch_npu

         if residual is not None:
-            # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
-            # is enabled, without interfering with the normal operation of components like torchair.
-            # The final solution should be to move this check into the operator and support
-            # integration with torchair.
-            if x.size(0) != residual.size(0):
-                residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             x, _, residual = torch_npu.npu_add_rms_norm_quant(
                 x,

@@ -58,6 +53,7 @@ def forward(
                 self.layer.aclnn_input_scale,
                 self.layer.aclnn_input_offset,
                 epsilon=self.variance_epsilon)
+            torch.ops.vllm.maybe_wait_prefetch_done(x)
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,

@@ -76,12 +72,7 @@ def forward_oot(

         from vllm_ascend.utils import is_310p
         if residual is not None:
-            # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
-            # is enabled, without interfering with the normal operation of components like torchair.
-            # The final solution should be to move this check into the operator and support
-            # integration with torchair.
-            if x.size(0) != residual.size(0):
-                residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             if is_310p():
                 orig_dtype = residual.dtype

@@ -92,6 +83,7 @@ def forward_oot(
             else:
                 x, _, residual = torch_npu.npu_add_rms_norm(
                     x, residual, self.weight, self.variance_epsilon)
+            torch.ops.vllm.maybe_wait_prefetch_done(x)
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,
