
Commit 4547ae7

feat: implement matmul_allreduce in prefill phase when tensor parallel is enabled
Signed-off-by: Ronald1995 <[email protected]>
1 parent e561a2c commit 4547ae7

File tree

6 files changed: +334 -4 lines changed


tests/ut/base.py

Lines changed: 0 additions & 4 deletions
@@ -19,10 +19,6 @@
 
 from vllm_ascend.utils import adapt_patch, register_ascend_customop
 
-# fused moe ops test will hit the infer_schema error, we need add the patch
-# here to make the test pass.
-import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import] # isort: skip # noqa
-
 
 class TestBase(unittest.TestCase):
 
Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
+from importlib import reload
+
+import pytest
+import torch
+import vllm
+from pytest_mock import MockerFixture
+from vllm.model_executor.layers.linear import RowParallelLinear
+
+from tests.ut.base import PytestBase
+from vllm_ascend import envs
+from vllm_ascend.patch.worker.patch_common import patch_linear
+
+
+class TestAscendRowParallelLinear(PytestBase):
+
+    def init_row_parallel_linear(self, mocker: MockerFixture):
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
+            return_value=None,
+        )
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        return patch_linear.AscendRowParallelLinear(
+            input_size=128,
+            output_size=256,
+        )
+
+    @pytest.mark.parametrize(
+        "version, expected",
+        [
+            ("1.0.0", 1),
+            ("2.1.0", 0),
+        ],
+    )
+    def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
+        mock_group = mocker.MagicMock()
+        backend = mocker.MagicMock()
+        backend.get_hccl_comm_name = lambda x: x
+        mock_group._get_backend = lambda x: backend
+        mock_group.get_hccl_comm_name = lambda x: x
+        mocker.patch("torch.distributed.get_rank", return_value=1)
+        mocker_func = mocker.patch(
+            "torch.distributed.get_global_rank",
+            return_value=0,
+        )
+        mocker.patch("torch.__version__", new=version)
+        hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
+            mock_group)
+        assert hcomm_info == expected
+        mocker_func.assert_called_once()
+        hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
+            mock_group)
+        mocker_func.assert_not_called()
+        assert hcomm_info == expected
+
+    @pytest.mark.parametrize(
+        "skip_bias_add, return_bias, bias, expected",
+        [
+            (True, False, torch.tensor(1.0), torch.tensor(14.0)),
+            (False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
+            (
+                True,
+                True,
+                torch.tensor(1.0),
+                (torch.tensor(14.0), torch.tensor(1.0)),
+            ),
+        ],
+    )
+    def test_forward(
+        self,
+        skip_bias_add,
+        return_bias,
+        bias,
+        expected,
+        mocker: MockerFixture,
+    ):
+        mocker_tp_group = mocker.MagicMock()
+        mocker_tp_group.device_group = mocker.MagicMock()
+        row_parallel_linear = self.init_row_parallel_linear(mocker)
+        row_parallel_linear.__dict__["tp_rank"] = 0
+        row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
+        row_parallel_linear.__dict__["return_bias"] = return_bias
+        row_parallel_linear.__dict__["bias"] = bias
+        row_parallel_linear.__dict__["quant_method"] = mocker.MagicMock()
+        row_parallel_linear.__dict__["calc_input"] = lambda x: x  # noqa
+        row_parallel_linear.__dict__[
+            "calc_output"] = lambda x: x.matmul(  # noqa
+                torch.tensor([1.0, 2.0]))
+        ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
+        if isinstance(ret, tuple):
+            assert torch.allclose(ret[0], expected[0])
+            if ret[1] is None:
+                assert ret[1] == expected[1]
+            else:
+                assert torch.allclose(ret[1], expected[1])
+        else:
+            assert torch.allclose(ret, expected)
+
+    @pytest.mark.parametrize(
+        "input_is_parallel, expected",
+        [
+            (True, torch.tensor([10.0, 2.0])),
+            (False, torch.tensor([10.0])),
+        ],
+    )
+    def test_calc_input(
+        self,
+        input_is_parallel,
+        expected,
+        mocker: MockerFixture,
+    ):
+        row_parallel_linear = self.init_row_parallel_linear(mocker)
+        row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
+        input_tensor = torch.Tensor([10, 2])
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank",  # noqa
+            return_value=0,
+        )
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim",  # noqa
+            return_value=[torch.Tensor([10]),
+                          torch.Tensor([2])],
+        )
+        input_parallel = row_parallel_linear.calc_input(input_tensor)
+        assert torch.allclose(input_parallel, expected)
+
+    @pytest.mark.parametrize(
+        "reduce_results, tp_size, is_prefill, expected",
+        [
+            (True, 2, True, torch.tensor(56.0)),
+            (True, 2, False, torch.tensor(28.0)),
+            (True, 1, False, torch.tensor(14.0)),
+            (False, 2, False, torch.tensor(14.0)),
+        ],
+    )
+    def test_calc_output(
+        self,
+        reduce_results,
+        tp_size,
+        is_prefill,
+        expected,
+        mocker: MockerFixture,
+    ):
+        quant_method = mocker.MagicMock()
+        quant_method.apply = lambda self, x, bias=None: x.matmul(  # noqa
+            torch.tensor([1.0, 2.0]))
+        row_parallel_linear = self.init_row_parallel_linear(mocker)
+        row_parallel_linear.__dict__["reduce_results"] = reduce_results
+        row_parallel_linear.__dict__["tp_size"] = tp_size
+        row_parallel_linear.__dict__["is_prefill"] = lambda: is_prefill  # noqa
+        row_parallel_linear.__dict__["quant_method"] = quant_method
+        row_parallel_linear.__dict__["tp_rank"] = 0
+        row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None  # noqa
+
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
+            return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
+        )
+        mocker.patch(
+            "vllm_ascend.patch.worker.patch_common.patch_linear.tensor_model_parallel_all_reduce",
+            side_effect=lambda output: output * 2,  # noqa
+        )  # noqa
+        mocker.patch(
+            "torch_npu.npu_mm_all_reduce_base",
+            side_effect=lambda input_, weight, hccl_info, bias: input_.
+            matmul(  # noqa
+                torch.tensor([4.0, 8.0])),
+        )  # noqa
+        ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
+        assert torch.allclose(ret, expected)
+
+    def test_enable_allreduce_matmul(self, mocker: MockerFixture):
+        mocker.patch.object(envs,
+                            "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
+                            new=True)
+        reload(patch_linear)
+        assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
+        assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
+            patch_linear.AscendRowParallelLinear)
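As a quick sanity check on the parametrized expectations above: in test_forward, calc_output is stubbed to a plain matmul against the fixed weight vector [1.0, 2.0], so the expected value follows directly. A minimal standalone sketch of that arithmetic (outside the mocked test harness):

import torch

# test_forward feeds [10.0, 2.0] through the stubbed calc_output,
# which is just a matmul with the constant weight vector [1.0, 2.0].
x = torch.tensor([10.0, 2.0])
w = torch.tensor([1.0, 2.0])
out = x.matmul(w)  # 10 * 1 + 2 * 2 = 14.0
assert torch.allclose(out, torch.tensor(14.0))

When return_bias and skip_bias_add are both True, the test additionally expects the untouched bias (1.0) back as the second element of the returned tuple.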

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
@@ -133,6 +133,9 @@
     "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(
         int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
+    # Whether to enable the MatmulAllReduce fusion kernel when tensor parallel is enabled.
+    "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
 }
 
 # end-env-vars-definition
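Usage note (not part of this diff): the lambda parses the variable with bool(int(...)), so the switch is off by default and is turned on by exporting VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE=1 before the worker process imports vllm_ascend. A hedged sketch of doing the same from Python:

import os

# Sketch only: must run before vllm_ascend's worker patches are imported,
# since the patch in patch_linear.py is applied at import time.
os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "1"  # read via bool(int(os.getenv(...)))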

vllm_ascend/patch/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -101,3 +101,19 @@
 #   - https://github.com/vllm-project/vllm-ascend/pull/1732
 #   Future Plan:
 #   Revert it when the ascend scatter performance improves.
+# ** File: worker/patch_common/patch_linear.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.model_executor.layers.linear.RowParallelLinear`
+#    Why:
+#       We need to fuse matmul and allreduce in `RowParallelLinear`
+#       to improve performance.
+#    How:
+#       Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
+#       In this class, we override the `forward` method to use
+#       torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
+#    Related PR (if no, explain why):
+#       - https://github.com/vllm-project/vllm-ascend/pull/1926
+#    Future Plan:
+#       Validate more models in all kinds of scenarios;
+#       if performance consistently improves, we can enable this patch by default and remove the env
+#       variable `VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE` in the future.
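To make the "Why"/"How" above concrete, here is a small conceptual sketch (illustrative only, not the patched class itself) contrasting the default row-parallel epilogue, a local matmul followed by a separate all-reduce, with the fused kernel this patch switches to; the helper names are invented for the example:

import torch
import torch_npu  # Ascend-only dependency


def unfused_row_parallel_output(x_parallel, weight, all_reduce):
    # Default path: each TP rank computes a partial matmul, then the partial
    # sums are combined with a standalone all-reduce collective.
    partial = torch.matmul(x_parallel, weight.t())
    return all_reduce(partial)


def fused_row_parallel_output(x_parallel, weight, hcomm_info):
    # Fused path used by AscendRowParallelLinear.calc_output: the matmul and
    # the HCCL all-reduce run as a single npu_mm_all_reduce_base kernel.
    return torch_npu.npu_mm_all_reduce_base(x_parallel, weight.t(),
                                            hcomm_info, bias=None)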

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,5 +19,6 @@
 # patch files.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_linear  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
vllm_ascend/patch/worker/patch_common/patch_linear.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+"""
+Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+This file is a part of the vllm-ascend project.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Optional, Union
+
+import torch
+import torch_npu
+import vllm
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              split_tensor_along_last_dim)
+from vllm.distributed.parallel_state import get_tp_group
+from vllm.model_executor.layers.linear import RowParallelLinear
+
+from vllm_ascend import envs
+
+_HCOMM_INFO = None
+
+
+class AscendRowParallelLinear(RowParallelLinear):
+    """
+    AscendRowParallelLinear is a custom implementation of RowParallelLinear
+    that overrides the forward method to handle Ascend-specific operations.
+    """
+
+    @staticmethod
+    def get_hcomm_info(group: ProcessGroup) -> str:
+        """Get the HCCL communication information for the given group.
+
+        Args:
+            group (ProcessGroup): The process group for which to get the HCCL communication info.
+
+        Returns:
+            str: The HCCL communication name for the given group.
+        """
+        global _HCOMM_INFO
+        if _HCOMM_INFO is not None:
+            return _HCOMM_INFO
+
+        rank = torch.distributed.get_rank(group)
+        if torch.__version__ > "2.0":
+            global_rank = torch.distributed.get_global_rank(group, rank)
+            _HCOMM_INFO = group._get_backend(
+                torch.device("npu")).get_hccl_comm_name(global_rank)
+
+        else:
+            _HCOMM_INFO = group.get_hccl_comm_name(rank)
+        return _HCOMM_INFO
+
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Forward pass for the AscendRowParallelLinear layer.
+
+        Args:
+            input_ (torch.Tensor): the input tensor to the layer.
+
+        Returns:
+            Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+                The output tensor after applying the linear transformation,
+                and optionally the bias if `return_bias` is True.
+        """
+        input_parallel = self.calc_input(input_)
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        output = self.calc_output(input_parallel)
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
+        """Calculate the input tensor for parallel processing.
+
+        Args:
+            input_ (torch.Tensor): the input tensor to be processed.
+
+        Returns:
+            torch.Tensor: The input tensor split along the last dimension
+            for tensor model parallelism, or the original input if not parallel.
+        """
+        if self.input_is_parallel:
+            return input_
+        tp_rank = get_tensor_model_parallel_rank()
+        splitted_input = split_tensor_along_last_dim(
+            input_, num_partitions=self.tp_size)
+        return splitted_input[tp_rank].contiguous()
+
+    def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
+        """Calculate the output tensor of forward by considering
+        fusing communication and computation.
+
+        Args:
+            input_parallel (torch.Tensor): the input tensor to be processed in parallel.
+
+        Returns:
+            torch.Tensor: the output tensor after applying the linear transformation,
+            optionally handling communication between tensor model parallel ranks.
+        """
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        tp_group = get_tp_group().device_group
+        hcomm_info = self.get_hcomm_info(tp_group)
+        if self.reduce_results and self.tp_size > 1:
+            output = torch_npu.npu_mm_all_reduce_base(input_parallel,
+                                                      self.weight.t(),
+                                                      hcomm_info,
+                                                      bias=bias_)
+        else:
+            output = self.quant_method.apply(self, input_parallel, bias=bias_)
+        return output
+
+
+if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
+    vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
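One way to confirm the substitution took effect at runtime, mirroring test_enable_allreduce_matmul above, is a check like the following (a sketch that assumes an Ascend environment with torch_npu available and the environment variable set before the patch module is imported):

import os
os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "1"

import vllm
# Importing the patch module applies the class substitution when the env switch is on.
import vllm_ascend.patch.worker.patch_common.patch_linear as patch_linear

assert (vllm.model_executor.layers.linear.RowParallelLinear
        is patch_linear.AscendRowParallelLinear)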
