
Commit 73b7ef7

feat: implement matmul_allreduce in prefill phase when tensor parallel is enabled
Signed-off-by: Ronald1995 <[email protected]>
1 parent 8cfd257 commit 73b7ef7
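
The fused matmul + all-reduce path added by this commit is opt-in. A minimal sketch of how it is expected to be enabled, using only the environment variable and patch module introduced in the diff below (the surrounding vLLM worker setup is assumed and not shown):

    import os

    # Opt in before the vllm-ascend worker patches are imported; the flag defaults to "0".
    os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "1"

    import vllm
    # Importing the patch module applies the monkey patch when the flag is set.
    from vllm_ascend.patch.worker.patch_common import patch_linear  # noqa: F401

    # vLLM's RowParallelLinear now resolves to the Ascend variant that fuses
    # the matmul with the tensor-parallel all-reduce during prefill.
    assert (vllm.model_executor.layers.linear.RowParallelLinear
            is patch_linear.AscendRowParallelLinear)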

4 files changed: +308 -0 lines changed
Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
from importlib import reload

import pytest
import torch
from pytest_mock import MockerFixture
from vllm.model_executor.layers.linear import RowParallelLinear

from vllm_ascend import envs
from vllm_ascend.patch.worker.patch_common import patch_linear


class TestAscendRowParallelLinear:

    def init_row_parallel_linear(self, mocker: MockerFixture):
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
            return_value=None,
        )
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        return patch_linear.AscendRowParallelLinear(
            input_size=128,
            output_size=256,
        )

    @pytest.mark.parametrize(
        "version, expected",
        [
            ("1.0.0", 1),
            ("2.1.0", 0),
        ],
    )
    def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
        mock_group = mocker.MagicMock()
        backend = mocker.MagicMock()
        backend.get_hccl_comm_name = lambda x: x
        mock_group._get_backend = lambda x: backend
        mock_group.get_hccl_comm_name = lambda x: x
        mocker.patch("torch.distributed.get_rank", return_value=1)
        mocker.patch("torch.distributed.get_global_rank", return_value=0)
        mocker.patch("torch.__version__", new=version)
        hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
            mock_group)
        assert hcomm_info == expected

    @pytest.mark.parametrize(
        "skip_bias_add, return_bias, bias, expected",
        [
            (True, False, torch.tensor(1.0), torch.tensor(14.0)),
            (False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
            (
                True,
                True,
                torch.tensor(1.0),
                (torch.tensor(14.0), torch.tensor(1.0)),
            ),
        ],
    )
    def test_forward(
        self,
        skip_bias_add,
        return_bias,
        bias,
        expected,
        mocker: MockerFixture,
    ):
        mocker_tp_group = mocker.MagicMock()
        mocker_tp_group.device_group = mocker.MagicMock()
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["tp_rank"] = 0
        row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
        row_parallel_linear.__dict__["return_bias"] = return_bias
        row_parallel_linear.__dict__["bias"] = bias
        row_parallel_linear.__dict__["quant_method"] = mocker.MagicMock()
        row_parallel_linear.__dict__["calc_input"] = lambda x: x
        row_parallel_linear.__dict__["calc_output"] = lambda x: x.matmul(
            torch.tensor([1.0, 2.0]))
        ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
        if isinstance(ret, tuple):
            assert torch.allclose(ret[0], expected[0])
            if ret[1] is None:
                assert ret[1] == expected[1]
            else:
                assert torch.allclose(ret[1], expected[1])
        else:
            assert torch.allclose(ret, expected)

    @pytest.mark.parametrize(
        "input_is_parallel, expected",
        [
            (True, torch.tensor([10.0, 2.0])),
            (False, torch.tensor([10.0])),
        ],
    )
    def test_calc_input(
        self,
        input_is_parallel,
        expected,
        mocker: MockerFixture,
    ):
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
        input_tensor = torch.Tensor([10, 2])
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank",  # noqa
            return_value=0,
        )
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim",  # noqa
            return_value=[torch.Tensor([10]),
                          torch.Tensor([2])],
        )
        input_parallel = row_parallel_linear.calc_input(input_tensor)
        assert torch.allclose(input_parallel, expected)

    @pytest.mark.parametrize(
        "reduce_results, tp_size, is_prefill, expected",
        [
            (True, 2, True, torch.tensor(56.0)),
            (True, 2, False, torch.tensor(28.0)),
            (True, 1, False, torch.tensor(14.0)),
            (False, 2, False, torch.tensor(14.0)),
        ],
    )
    def test_calc_output(
        self,
        reduce_results,
        tp_size,
        is_prefill,
        expected,
        mocker: MockerFixture,
    ):
        quant_method = mocker.MagicMock()
        quant_method.apply = lambda self, x, bias=None: x.matmul(
            torch.tensor([1.0, 2.0]))
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["reduce_results"] = reduce_results
        row_parallel_linear.__dict__["tp_size"] = tp_size
        row_parallel_linear.__dict__["is_prefill"] = lambda: is_prefill
        row_parallel_linear.__dict__["quant_method"] = quant_method
        row_parallel_linear.__dict__["tp_rank"] = 0
        row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None

        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
            return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
        )
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.tensor_model_parallel_all_reduce",  # noqa
            side_effect=lambda output: output * 2,
        )
        mocker.patch(
            "torch_npu.npu_mm_all_reduce_base",
            side_effect=lambda input_, weight, hccl_info, bias: input_.matmul(
                torch.tensor([4.0, 8.0])),
        )
        ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
        assert torch.allclose(ret, expected)

    def test_enable_allreduce_matmul(self, mocker: MockerFixture):
        mocker.patch.object(envs,
                            "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
                            new=True)
        reload(patch_linear)
        assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
        assert id(RowParallelLinear) == id(
            patch_linear.AscendRowParallelLinear)

vllm_ascend/envs.py

Lines changed: 4 additions & 0 deletions
@@ -133,6 +133,10 @@
     "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(
         int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
+    # Whether to enable the MatmulAllReduce fusion kernel when tensor parallel is enabled.
+    "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
+    lambda: bool(
+        int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
 }

 # end-env-vars-definition
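
The new flag follows the same lazy bool(int(os.getenv(...))) pattern as the existing entries, so it is parsed when the attribute is read rather than when envs.py is imported. A small standalone illustration of that parsing behaviour (the helper name here is hypothetical and not part of the commit):

    import os

    def matmul_allreduce_enabled() -> bool:
        # Mirrors the lambda above: unset or "0" -> False, any other integer -> True.
        return bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", "0")))

    os.environ.pop("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", None)
    print(matmul_allreduce_enabled())  # False

    os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "1"
    print(matmul_allreduce_enabled())  # True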

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,5 +19,7 @@
 # patch files.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_linear  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
+
vllm_ascend/patch/worker/patch_common/patch_linear.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Union

import torch
import torch_npu
import vllm
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (get_tensor_model_parallel_rank,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.linear import RowParallelLinear

from vllm_ascend import envs

_HCOMM_INFO = None


class AscendRowParallelLinear(RowParallelLinear):
    """
    AscendRowParallelLinear is a custom implementation of RowParallelLinear
    that overrides the forward method to handle Ascend-specific operations.
    """

    @staticmethod
    def get_hcomm_info(group: ProcessGroup) -> str:
        """Get the HCCL communication information for the given group.

        Args:
            group (ProcessGroup): The process group for which to get the HCCL communication info.

        Returns:
            str: The HCCL communication name for the given group.
        """
        global _HCOMM_INFO
        if _HCOMM_INFO is not None:
            return _HCOMM_INFO

        rank = torch.distributed.get_rank(group)
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            _HCOMM_INFO = group._get_backend(
                torch.device("npu")).get_hccl_comm_name(global_rank)
        else:
            _HCOMM_INFO = group.get_hccl_comm_name(rank)
        return _HCOMM_INFO

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Forward pass for the AscendRowParallelLinear layer.

        Args:
            input_ (torch.Tensor): the input tensor to the layer.

        Returns:
            Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
                The output tensor after applying the linear transformation,
                and optionally the bias if `return_bias` is True.
        """
        input_parallel = self.calc_input(input_)

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        output = self.calc_output(input_parallel)

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
        """Calculate the input tensor for parallel processing.

        Args:
            input_ (torch.Tensor): the input tensor to be processed.

        Returns:
            torch.Tensor: The input tensor split along the last dimension
            for tensor model parallelism, or the original input if it is
            already parallel.
        """
        if self.input_is_parallel:
            return input_
        tp_rank = get_tensor_model_parallel_rank()
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size)
        return splitted_input[tp_rank].contiguous()

    def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
        """Calculate the output tensor of the forward pass, fusing
        communication and computation when possible.

        Args:
            input_parallel (torch.Tensor): the input tensor to be processed in parallel.

        Returns:
            torch.Tensor: the output tensor after applying the linear transformation,
            with any communication between tensor model parallel ranks already handled.
        """
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        tp_group = get_tp_group().device_group
        hcomm_info = self.get_hcomm_info(tp_group)
        if self.reduce_results and self.tp_size > 1 and self.is_prefill():
            # Prefill phase (as reported by is_prefill()): fuse the matmul and
            # the all-reduce into a single npu_mm_all_reduce_base kernel call.
            output = torch_npu.npu_mm_all_reduce_base(input_parallel,
                                                      self.weight.t(),
                                                      hcomm_info,
                                                      bias=bias_)
        else:
            # Decode phase, or no reduction required: plain matmul, followed
            # by a separate all-reduce when tensor parallelism needs it.
            output = self.quant_method.apply(self, input_parallel, bias=bias_)
            if self.reduce_results and self.tp_size > 1:
                output = tensor_model_parallel_all_reduce(output)
        return output


if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
    vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
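
For intuition, calc_output above chooses between two computations that produce the same reduced result: during prefill, a single torch_npu.npu_mm_all_reduce_base call performs the partial matmul and the HCCL all-reduce together, while otherwise the matmul and the reduction stay separate operations. A simplified sketch of the two paths (illustrative only: it assumes torch_npu on an Ascend device with an initialized HCCL communicator, and uses torch.distributed.all_reduce as a stand-in for vLLM's tensor_model_parallel_all_reduce):

    import torch
    import torch.distributed as dist
    import torch_npu  # Ascend extension that provides the fused kernel

    def unfused_row_parallel(x: torch.Tensor, w: torch.Tensor, tp_group) -> torch.Tensor:
        # Non-prefill path: local partial matmul, then a separate all-reduce
        # across the tensor-parallel group.
        y = x.matmul(w.t())
        dist.all_reduce(y, group=tp_group)
        return y

    def fused_row_parallel(x: torch.Tensor, w: torch.Tensor, hcomm_info: str) -> torch.Tensor:
        # Prefill path: one kernel computes the partial matmul and performs the
        # HCCL all-reduce, overlapping computation with communication.
        return torch_npu.npu_mm_all_reduce_base(x, w.t(), hcomm_info, bias=None)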

0 commit comments
