
Commit 0125904

feat: implement matmul_allreduce in prefill phase when tensor parallel is enabled
Signed-off-by: Ronald1995 <[email protected]>
1 parent e561a2c commit 0125904

File tree: 6 files changed, +321 −4 lines
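
For orientation, this is a minimal sketch (not part of the commit) of the unfused row-parallel linear path, a shard-local matmul followed by an all-reduce, which the new code replaces with a single fused torch_npu.npu_mm_all_reduce_base call during prefill. The function name and shapes here are illustrative only.

# Sketch only: the unfused baseline that this commit fuses into one NPU kernel.
from typing import Optional

import torch
import torch.distributed as dist


def row_parallel_linear_unfused(x_shard: torch.Tensor,
                                w_shard: torch.Tensor,
                                bias: Optional[torch.Tensor],
                                group: dist.ProcessGroup) -> torch.Tensor:
    # Each rank multiplies its input shard by its weight shard (partial result)...
    partial = x_shard.matmul(w_shard.t())
    # ...then the partial results are summed across tensor-parallel ranks.
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=group)
    # Bias is applied once after (or, as in this commit, only by rank 0 before) the reduction.
    if bias is not None:
        partial = partial + bias
    return partial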

tests/ut/base.py

Lines changed: 0 additions & 4 deletions
@@ -19,10 +19,6 @@
 
 from vllm_ascend.utils import adapt_patch, register_ascend_customop
 
-# fused moe ops test will hit the infer_schema error, we need add the patch
-# here to make the test pass.
-import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import] # isort: skip # noqa
-
 
 class TestBase(unittest.TestCase):
Lines changed: 167 additions & 0 deletions (new unit-test file for AscendRowParallelLinear)

@@ -0,0 +1,167 @@
from importlib import reload

import pytest
import torch
import vllm
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
from vllm_ascend import envs
from vllm_ascend.patch.worker.patch_common import patch_linear


class TestAscendRowParallelLinear(PytestBase):

    def init_row_parallel_linear(self, mocker: MockerFixture):
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
            return_value=None,
        )
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        return patch_linear.AscendRowParallelLinear(
            input_size=128,
            output_size=256,
        )

    @pytest.mark.parametrize(
        "version, expected",
        [
            ("1.0.0", 1),
            ("2.1.0", 1),
        ],
    )
    def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
        mock_group = mocker.MagicMock()
        backend = mocker.MagicMock()
        backend.get_hccl_comm_name = lambda x: x
        mock_group._get_backend = lambda x: backend
        mock_group.get_hccl_comm_name = lambda x: x
        mocker.patch("torch.distributed.get_rank", return_value=1)
        mocker.patch(
            "torch.distributed.get_global_rank",
            return_value=0,
        )
        mocker.patch("torch.__version__", new=version)
        hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
            mock_group)
        assert hcomm_info == expected

    @pytest.mark.parametrize(
        "skip_bias_add, return_bias, bias, expected",
        [
            (True, False, torch.tensor(1.0), torch.tensor(14.0)),
            (False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
            (
                True,
                True,
                torch.tensor(1.0),
                (torch.tensor(14.0), torch.tensor(1.0)),
            ),
        ],
    )
    def test_forward(
        self,
        skip_bias_add,
        return_bias,
        bias,
        expected,
        mocker: MockerFixture,
    ):
        mocker_tp_group = mocker.MagicMock()
        mocker_tp_group.device_group = mocker.MagicMock()
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["tp_rank"] = 0
        row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
        row_parallel_linear.__dict__["return_bias"] = return_bias
        row_parallel_linear.__dict__["bias"] = bias
        row_parallel_linear.__dict__["quant_method"] = mocker.MagicMock()
        row_parallel_linear.__dict__["calc_input"] = lambda x: x  # noqa
        row_parallel_linear.__dict__[
            "calc_output"] = lambda x: x.matmul(  # noqa
                torch.tensor([1.0, 2.0]))
        ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
        if isinstance(ret, tuple):
            assert torch.allclose(ret[0], expected[0])
            if ret[1] is None:
                assert ret[1] == expected[1]
            else:
                assert torch.allclose(ret[1], expected[1])
        else:
            assert torch.allclose(ret, expected)

    @pytest.mark.parametrize(
        "input_is_parallel, expected",
        [
            (True, torch.tensor([10.0, 2.0])),
            (False, torch.tensor([10.0])),
        ],
    )
    def test_calc_input(
        self,
        input_is_parallel,
        expected,
        mocker: MockerFixture,
    ):
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
        input_tensor = torch.Tensor([10, 2])
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank",  # noqa
            return_value=0,
        )
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim",  # noqa
            return_value=[torch.Tensor([10]),
                          torch.Tensor([2])],
        )
        input_parallel = row_parallel_linear.calc_input(input_tensor)
        assert torch.allclose(input_parallel, expected)

    @pytest.mark.parametrize(
        "reduce_results, tp_size, expected",
        [
            (True, 2, torch.tensor(56.0)),
            (True, 1, torch.tensor(14.0)),
            (False, 2, torch.tensor(14.0)),
        ],
    )
    def test_calc_output(
        self,
        reduce_results,
        tp_size,
        expected,
        mocker: MockerFixture,
    ):
        quant_method = mocker.MagicMock()
        quant_method.apply = lambda self, x, bias=None: x.matmul(  # noqa
            torch.tensor([1.0, 2.0]))
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["reduce_results"] = reduce_results
        row_parallel_linear.__dict__["tp_size"] = tp_size
        row_parallel_linear.__dict__["quant_method"] = quant_method
        row_parallel_linear.__dict__["tp_rank"] = 0
        row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None  # noqa

        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
            return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
        )
        mocker.patch(
            "torch_npu.npu_mm_all_reduce_base",
            side_effect=lambda input_, weight, hccl_info, bias: input_.
            matmul(  # noqa
                torch.tensor([4.0, 8.0])),
        )  # noqa
        ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
        assert torch.allclose(ret, expected)

    def test_enable_allreduce_matmul(self, mocker: MockerFixture):
        mocker.patch.object(envs,
                            "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
                            new=True)
        reload(patch_linear)
        assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
        assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
            patch_linear.AscendRowParallelLinear)
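
The expected values in test_calc_output follow directly from the two mocked matmuls above; the arithmetic (not part of the commit) is:

import torch

x = torch.tensor([10.0, 2.0])
# Fused path (reduce_results=True, tp_size=2) is mocked as a matmul with [4.0, 8.0]:
assert torch.allclose(x.matmul(torch.tensor([4.0, 8.0])), torch.tensor(56.0))
# Unfused path (tp_size=1 or reduce_results=False) goes through quant_method.apply,
# mocked as a matmul with [1.0, 2.0]:
assert torch.allclose(x.matmul(torch.tensor([1.0, 2.0])), torch.tensor(14.0))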

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions

@@ -133,6 +133,9 @@
     "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(
         int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
+    # Whether to enable the MatmulAllReduce fusion kernel when tensor parallel is enabled.
+    "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
 }

 # end-env-vars-definition
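
Since the flag is parsed with bool(int(...)), only integer strings are valid values. A minimal check (not part of the commit):

import os

# "1" enables the fusion path; "0" or an unset variable disables it.
os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "1"
assert bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0')))

# Note: a value such as "true" would raise ValueError in int(); it does not enable the flag.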

vllm_ascend/patch/__init__.py

Lines changed: 16 additions & 0 deletions

@@ -101,3 +101,19 @@
 #   - https://github.com/vllm-project/vllm-ascend/pull/1732
 #   Future Plan:
 #   Revert it when the ascend scatter performance improves.
+# ** File: worker/patch_common/patch_linear.py **
+#   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.model_executor.layers.linear.RowParallelLinear`
+#    Why:
+#       We need to fuse matmul and allreduce in `RowParallelLinear`
+#       to improve performance.
+#    How:
+#       Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
+#       In this class, we override the `forward` method to use
+#       torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
+#    Related PR (if no, explain why):
+#       - https://github.com/vllm-project/vllm-ascend/pull/1926
+#    Future Plan:
+#       Validate more models in all kinds of scenarios;
+#       if performance is consistently improved, we can enable this patch by default and remove the env
+#       variable `VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE` in the future.
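
A small sketch of how the patch is expected to take effect, mirroring the new test_enable_allreduce_matmul unit test: the rebinding happens at module import time, so the flag must be truthy before the worker imports patch_linear (the unit test reloads the module instead of relying on import order).

import os

# The flag has to be set before the patch module is (re)loaded.
os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "1"

import vllm
from vllm_ascend.patch.worker.patch_common import patch_linear

# With the flag enabled, vllm's RowParallelLinear symbol now points at the Ascend subclass.
assert vllm.model_executor.layers.linear.RowParallelLinear is patch_linear.AscendRowParallelLinear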

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -19,5 +19,6 @@
 # patch files.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_linear  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
vllm_ascend/patch/worker/patch_common/patch_linear.py

Lines changed: 134 additions & 0 deletions (new file)

@@ -0,0 +1,134 @@
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Union

import torch
import torch_npu
import vllm
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (get_tensor_model_parallel_rank,
                              split_tensor_along_last_dim)
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.linear import RowParallelLinear

from vllm_ascend import envs

_HCOMM_INFO = None


class AscendRowParallelLinear(RowParallelLinear):
    """
    AscendRowParallelLinear is a custom implementation of RowParallelLinear
    that overrides the forward method to handle Ascend-specific operations.
    """

    @staticmethod
    def get_hcomm_info(group: ProcessGroup) -> str:
        """Get the HCCL communication information for the given group.

        Args:
            group (ProcessGroup): The process group for which to get the HCCL communication info.

        Returns:
            str: The HCCL communication name for the given group.
        """
        global _HCOMM_INFO
        if _HCOMM_INFO is not None:
            return _HCOMM_INFO

        rank = torch.distributed.get_rank(group)
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            _HCOMM_INFO = group._get_backend(
                torch.device("npu")).get_hccl_comm_name(global_rank)
        else:
            _HCOMM_INFO = group.get_hccl_comm_name(rank)
        return _HCOMM_INFO

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Forward pass for the AscendRowParallelLinear layer.

        Args:
            input_ (torch.Tensor): the input tensor to the layer.

        Returns:
            Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
                The output tensor after applying the linear transformation,
                and optionally the bias if `return_bias` is True.
        """
        input_parallel = self.calc_input(input_)

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        output = self.calc_output(input_parallel)

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
        """Calculate the input tensor for parallel processing.

        Args:
            input_ (torch.Tensor): the input tensor to be processed.

        Returns:
            torch.Tensor: The input tensor split along the last dimension
            for tensor model parallelism, or the original input if not parallel.
        """
        if self.input_is_parallel:
            return input_
        tp_rank = get_tensor_model_parallel_rank()
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size)
        return splitted_input[tp_rank].contiguous()

    def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
        """Calculate the output tensor of forward by considering
        fusing communication and computation.

        Args:
            input_parallel (torch.Tensor): the input tensor to be processed in parallel.

        Returns:
            torch.Tensor: the output tensor after applying the linear transformation,
            optionally handling communication between tensor model parallel ranks.
        """
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        tp_group = get_tp_group().device_group
        hcomm_info = self.get_hcomm_info(tp_group)
        if self.reduce_results and self.tp_size > 1:
            output = torch_npu.npu_mm_all_reduce_base(input_parallel,
                                                      self.weight.t(),
                                                      hcomm_info,
                                                      bias=bias_)
        else:
            output = self.quant_method.apply(self, input_parallel, bias=bias_)
        return output


if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
    vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
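
A worked example (not part of the commit) of why calc_output passes the bias only when tp_rank == 0: with TP > 1 the per-rank partial products are summed by the all-reduce, so a bias added on every rank would be counted tp_size times.

import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])      # full activation
w = torch.eye(4)                               # full weight (output = x + bias)
bias = torch.tensor([10.0, 10.0, 10.0, 10.0])

# Row-parallel split with tp_size = 2: each rank holds half of the input dimension.
rank0 = x[:, :2].matmul(w[:, :2].t()) + bias   # rank 0 adds the bias
rank1 = x[:, 2:].matmul(w[:, 2:].t())          # rank 1 (tp_rank > 0) passes bias=None
reduced = rank0 + rank1                        # what the all-reduce would produce

assert torch.allclose(reduced, x.matmul(w.t()) + bias)
# If rank 1 had also added the bias, the result would be off by exactly one extra bias.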
