Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions tests/ut/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@

from vllm_ascend.utils import adapt_patch, register_ascend_customop

# fused moe ops test will hit the infer_schema error, we need add the patch
# here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa


class TestBase(unittest.TestCase):

Expand Down
167 changes: 167 additions & 0 deletions tests/ut/patch/worker/patch_common/test_patch_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from importlib import reload

import pytest
import torch
import vllm
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
from vllm_ascend import envs
from vllm_ascend.patch.worker.patch_common import patch_linear


class TestAscendRowParallelLinear(PytestBase):

def init_row_parallel_linear(self, mocker: MockerFixture):
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
return patch_linear.AscendRowParallelLinear(
input_size=128,
output_size=256,
)

@pytest.mark.parametrize(
"version, expected",
[
("1.0.0", 1),
("2.1.0", 1),
],
)
def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
mock_group = mocker.MagicMock()
backend = mocker.MagicMock()
backend.get_hccl_comm_name = lambda x: x
mock_group._get_backend = lambda x: backend
mock_group.get_hccl_comm_name = lambda x: x
mocker.patch("torch.distributed.get_rank", return_value=1)
mocker.patch(
"torch.distributed.get_global_rank",
return_value=0,
)
mocker.patch("torch.__version__", new=version)
hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
mock_group)
assert hcomm_info == expected

@pytest.mark.parametrize(
"skip_bias_add, return_bias, bias, expected",
[
(True, False, torch.tensor(1.0), torch.tensor(14.0)),
(False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
(
True,
True,
torch.tensor(1.0),
(torch.tensor(14.0), torch.tensor(1.0)),
),
],
)
def test_forward(
self,
skip_bias_add,
return_bias,
bias,
expected,
mocker: MockerFixture,
):
mocker_tp_group = mocker.MagicMock()
mocker_tp_group.device_group = mocker.MagicMock()
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["tp_rank"] = 0
row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
row_parallel_linear.__dict__["return_bias"] = return_bias
row_parallel_linear.__dict__["bias"] = bias
row_parallel_linear.__dict__["qyuant_method"] = mocker.MagicMock()
row_parallel_linear.__dict__["calc_input"] = lambda x: x # noqa
row_parallel_linear.__dict__[
"calc_output"] = lambda x: x.matmul( # noqa
torch.tensor([1.0, 2.0]))
ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
if isinstance(ret, tuple):
assert torch.allclose(ret[0], expected[0])
if ret[1] is None:
assert ret[1] == expected[1]
else:
assert torch.allclose(ret[1], expected[1])
else:
assert torch.allclose(ret, expected)

@pytest.mark.parametrize(
"input_is_parallel, expected",
[
(True, torch.tensor([10.0, 2.0])),
(False, torch.tensor([10.0])),
],
)
def test_calc_input(
self,
input_is_parallel,
expected,
mocker: MockerFixture,
):
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
input_tensor = torch.Tensor([10, 2])
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank", # noqa
return_value=0,
)
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim", # noqa
return_value=[torch.Tensor([10]),
torch.Tensor([2])],
)
input_parallel = row_parallel_linear.calc_input(input_tensor)
assert torch.allclose(input_parallel, expected)

@pytest.mark.parametrize(
"reduce_results, tp_size, expected",
[
(True, 2, torch.tensor(56.0)),
(True, 1, torch.tensor(14.0)),
(False, 2, torch.tensor(14.0)),
],
)
def test_calc_output(
self,
reduce_results,
tp_size,
expected,
mocker: MockerFixture,
):
quant_method = mocker.MagicMock()
quant_method.apply = lambda self, x, bias=None: x.matmul( # noqa
torch.tensor([1.0, 2.0]))
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["reduce_results"] = reduce_results
row_parallel_linear.__dict__["tp_size"] = tp_size
row_parallel_linear.__dict__["quant_method"] = quant_method
row_parallel_linear.__dict__["tp_rank"] = 0
row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None # noqa

mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
)
mocker.patch(
"torch_npu.npu_mm_all_reduce_base",
side_effect=lambda input_, weight, hccl_info, bias: input_.
matmul( # noqa
torch.tensor([4.0, 8.0])),
) # noqa
ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
assert torch.allclose(ret, expected)

def test_enable_allreduce_matmul(self, mocker: MockerFixture):
mocker.patch.object(envs,
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
new=True)
reload(patch_linear)
assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
patch_linear.AscendRowParallelLinear)
6 changes: 5 additions & 1 deletion vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
# and the mla_pa will be the default path of deepseek decode path.
"VLLM_ASCEND_MLA_PA":
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)),
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
# this feature is supported in A2, and eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
}

# end-env-vars-definition
Expand Down
16 changes: 16 additions & 0 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,19 @@
# - https://github.com/vllm-project/vllm/pull/21591
# Future Plan:
# Revert it when vLLM merge #21591 and release new version
# ** File: worker/patch_common/patch_linear.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.linear.RowParallelLinear`
# Why:
# We need to fuse matmul and allreuce in `RowParallelLinear`
# to improve performance.
# How:
# Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
# In this class, we override the `forward` method to use
# torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm-ascend/pull/1926
# Future Plan:
# Validate more models in all kinds of scenario,
# if performance is always improved, we can enable this patch by default and remove the env
# variable `VLLM_ASCEND_ENABLE_FUSE_MATMUL_ALLREDUCE` in the future.
1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/patch_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@
# patch files.
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
145 changes: 145 additions & 0 deletions vllm_ascend/patch/worker/patch_common/patch_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Union

import torch
import torch_npu
import vllm
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (get_tensor_model_parallel_rank,
split_tensor_along_last_dim)
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.linear import RowParallelLinear

from vllm_ascend import envs

_HCOMM_INFO = None


class AscendRowParallelLinear(RowParallelLinear):
"""
AscendRowParallelLinear is a custom implementation of RowParallelLinear
that overrides the forward method to handle Ascend-specific operations.
"""

def __init__(self, *args, **kwargs):
"""Initialize the AscendRowParallelLinear layer.

Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
tp_group = get_tp_group().device_group
hcomm_info = self.get_hcomm_info(tp_group)
self.hcomm_info = hcomm_info
super().__init__(*args, **kwargs)
self.weight_t = self.weight.t()

@staticmethod
def get_hcomm_info(group: ProcessGroup) -> str:
"""Get the HCCL communication information for the given group.

Args:
group (ProcessGroup): The process group for which to get the HCCL communication info.

Returns:
str: The HCCL communication name for the given group.
"""
global _HCOMM_INFO
if _HCOMM_INFO is not None:
return _HCOMM_INFO

rank = torch.distributed.get_rank(group)
if torch.__version__ > "2.0":
global_rank = torch.distributed.get_global_rank(group, rank)
_HCOMM_INFO = group._get_backend(
torch.device("npu")).get_hccl_comm_name(global_rank)

else:
_HCOMM_INFO = group.get_hccl_comm_name(rank)
return _HCOMM_INFO

def forward(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Forward pass for the AscendRowParallelLinear layer.

Args:
input_ (torch.Tensor): the input tensor to the layer.

Returns:
Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
The output tensor after applying the linear transformation,
and optionally the bias if `return_bias` is True.
"""
input_parallel = self.calc_input(input_)

# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
output = self.calc_output(input_parallel)

output_bias = self.bias if self.skip_bias_add else None

if not self.return_bias:
return output
return output, output_bias

def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
"""Calculate the input tensor for parallel processing.

Args:
input_ (torch.Tensor): the input tensor to be processed.

Returns:
torch.Tensor: The input tensor split along the last dimension
for tensor model parallelism, or the original input if not parallel.
"""
if self.input_is_parallel:
return input_
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
return splitted_input[tp_rank].contiguous()

def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
"""Calculate the output tensor of forward by considering
fusing communication and computation.

Args:
input_parallel (_type_): the input tensor to be processed in parallel.

Returns:
torch.Tensor: the output tensor after applying the linear transformation
and optionally handle communication between tensor model parallel ranks.
"""
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if self.reduce_results and self.tp_size > 1:
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
self.weight_t,
self.hcomm_info,
bias=bias_)
else:
output = self.quant_method.apply(self, input_parallel, bias=bias_)
return output


if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
Loading