2 changes: 1 addition & 1 deletion .github/workflows/vllm_ascend_test.yaml
@@ -258,4 +258,4 @@ jobs:
VLLM_WORKER_MULTIPROC_METHOD: spawn
VLLM_USE_MODELSCOPE: True
run: |
pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP
pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP
2 changes: 2 additions & 0 deletions .github/workflows/vllm_ascend_test_full.yaml
@@ -226,6 +226,8 @@ jobs:
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight

#pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
pytest -sv tests/e2e/multicard/test_prefix_caching.py
29 changes: 27 additions & 2 deletions tests/e2e/multicard/test_offline_inference_distributed.py
@@ -31,7 +31,9 @@

os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"]
QWEN_DENSE_MODELS = [
"vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"
]


def test_models_distributed_QwQ():
@@ -170,6 +172,29 @@ def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
max_model_len=8192,
enforce_eager=enforce_eager,
dtype="auto",
tensor_parallel_size=4,
tensor_parallel_size=2,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)


@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(
model, enforce_eager):
example_prompts = [
"Hello, my name is",
]
max_tokens = 5

with VllmRunner(
snapshot_download(model),
max_model_len=8192,
enforce_eager=enforce_eager,
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
13 changes: 12 additions & 1 deletion tests/ut/ops/test_activation.py
@@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):

@pytest.mark.parametrize("is_310p_return", [True, False])
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
side_effect=lambda x: None)
def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done, mock_swiglu,
is_310p_return, dummy_tensor):

with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = SiluAndMul()
@@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
else:
expected_arg = dummy_tensor

# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()

# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()

# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()

actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(
actual_arg,
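Note: the mock-argument order in the rewritten signature above, and in the layernorm and torchair deepseek tests below, follows from how stacked @patch decorators work: they are applied bottom-up, so the bottom-most patch supplies the first positional mock argument. A minimal illustration (not part of this PR):

from unittest.mock import patch


@patch("os.getcwd")  # outermost patch -> last mock argument
@patch("os.getpid")  # innermost patch -> first mock argument
def demo(mock_getpid, mock_getcwd):
    # Inside the call both names are MagicMock objects.
    print(mock_getpid, mock_getcwd)


demo()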
11 changes: 10 additions & 1 deletion tests/ut/ops/test_layernorm.py
@@ -30,9 +30,11 @@ def mock_add_rms_norm(x, residual, weight, eps):
[None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
def test_RMSNorm_forward(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
mock_rmsnorm, is_310p_return, residual, dummy_tensor):

with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
@@ -45,13 +47,17 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
expected_out_x = expected_arg_x + 1
expected_out_residual = expected_arg_x.to(residual.dtype)

mock_maybe_chunk_residual.assert_called_once()
mock_rmsnorm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual
mock_maybe_chunk_residual.assert_called_once()
mock_add_rmsnorm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
@@ -64,9 +70,11 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,

@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
mock_add_rms_norm, mock_is310p):
x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)
@@ -79,6 +87,7 @@ def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,

mock_maybe_chunk_residual.assert_called_once()
mock_add_rms_norm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert out_residual.size(0) == 4
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
7 changes: 6 additions & 1 deletion tests/ut/torchair/models/test_torchair_deepseek_v2.py
@@ -275,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,

@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=lambda x, residual: residual)
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
mock_rms_norm, mock_add_norm,
mock_distributed, base_config,
vllm_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
27 changes: 25 additions & 2 deletions vllm_ascend/ascend_forward_context.py
@@ -66,7 +66,9 @@ def set_ascend_forward_context(
moe_comm_method: str = "",
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -108,7 +110,8 @@
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods.
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
Comment on lines +113 to +114 (Contributor, severity: high):

The condition envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE is repeated in multiple if statements. Consider defining a variable to store this value and reusing it to avoid redundancy and improve readability. This also centralizes the configuration, making it easier to manage.

Suggested change (replace the two flagged lines with):
dense_optimize_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
flashcomm_v1_enabled = dense_optimize_enabled and \
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000

tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000

@@ -122,6 +125,26 @@ def set_ascend_forward_context(
# set this for rope forward_oot using
forward_context.is_first_layer = True

# set layer_idx to enable optimization features that depend on this information.
# This is only applicable to models that contain these necessary attributes.
forward_context.layer_idx = None
if model_instance is not None and \
hasattr(model_instance, "model") and \
hasattr(model_instance.model, "start_layer"):
forward_context.layer_idx = model_instance.model.start_layer

# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500
Comment on lines +137 to +140 (Contributor, severity: high):

Similar to the previous comment, the condition envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE is repeated here. Consider reusing the dense_optimize_enabled variable defined earlier to maintain consistency and readability.

Suggested change (replace the four flagged lines with):
prefetch_mlp_enabled = dense_optimize_enabled and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500

if prefetch_mlp_enabled:
forward_context.prefetch_stream = prefetch_stream
forward_context.model_instance = model_instance
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled

if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens

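Note: the prefetch_stream, model_instance and prefetch_mlp_* fields stored on the forward context above are consumed by the torch.ops.vllm.maybe_prefetch_mlp_* and maybe_wait_prefetch_done hooks added in ops/linear.py and ops/activation.py below. As a rough sketch of the stream-overlap pattern this presumably enables (illustrative only: the helper names and the torch_npu.npu_prefetch call and its signature are assumptions, not code from this PR):

import torch
import torch_npu  # registers the torch.npu backend; assumed to expose an npu_prefetch API


def prefetch_weight_async(weight: torch.Tensor, dependency: torch.Tensor,
                          prefetch_stream: torch.npu.Stream, max_bytes: int) -> None:
    # Let the side stream wait for work already queued on the main stream, then
    # launch the weight prefetch there so it overlaps with ongoing compute.
    prefetch_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(prefetch_stream):
        # Assumed API: npu_prefetch(tensor_to_prefetch, dependency, max_size_in_bytes)
        torch_npu.npu_prefetch(weight, dependency, max_bytes)


def wait_prefetch_done(prefetch_stream: torch.npu.Stream) -> None:
    # Block the main stream until the prefetch has landed, right before the
    # prefetched weight is consumed (e.g. after the swiglu in activation.py).
    torch.npu.current_stream().wait_stream(prefetch_stream)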
9 changes: 9 additions & 0 deletions vllm_ascend/envs.py
@@ -135,6 +135,15 @@
# This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
# buffer size for gate up prefetch
"MLP_GATE_UP_PREFETCH_SIZE":
Comment (Collaborator): add VLLM_ASCEND_ prefix

lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
# buffer size for down proj prefetch
"MLP_DOWN_PREFETCH_SIZE":
Comment (Collaborator): ditto (add VLLM_ASCEND_ prefix)

lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
# Whether to enable dense model and general optimizations for better performance.
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
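Note: a minimal usage sketch for the new knobs (the model name, quantization and tensor-parallel size mirror the new e2e test; everything else is illustrative and assumes the standard vllm.LLM entry point):

import os

# Set before importing vLLM so the env lambdas above pick the values up.
os.environ["VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE"] = "1"  # master switch for the dense-model optimizations
os.environ["VLLM_ASCEND_ENABLE_PREFETCH_MLP"] = "1"    # MLP weight prefetch, aimed at low concurrency
os.environ["MLP_GATE_UP_PREFETCH_SIZE"] = str(18 * 1024 * 1024)  # bytes; 18 MiB is the default above
os.environ["MLP_DOWN_PREFETCH_SIZE"] = str(18 * 1024 * 1024)

from vllm import LLM

llm = LLM(model="vllm-ascend/Qwen3-8B-W8A8",
          tensor_parallel_size=2,
          quantization="ascend",
          max_model_len=8192)
print(llm.generate(["Hello, my name is"]))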
2 changes: 2 additions & 0 deletions vllm_ascend/ops/activation.py
@@ -35,8 +35,10 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:

from vllm_ascend.utils import is_310p

torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
if is_310p():
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
else:
out = torch_npu.npu_swiglu(x)
torch.ops.vllm.maybe_wait_prefetch_done(out)
return out
16 changes: 4 additions & 12 deletions vllm_ascend/ops/layernorm.py
@@ -44,12 +44,7 @@ def forward(
import torch_npu

if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
@@ -58,6 +53,7 @@
self.layer.aclnn_input_scale,
self.layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual

x, residual = torch_npu.npu_rms_norm(x, self.weight,
@@ -76,12 +72,7 @@ def forward_oot(

from vllm_ascend.utils import is_310p
if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
if is_310p():
orig_dtype = residual.dtype
@@ -92,6 +83,7 @@
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual

x, residual = torch_npu.npu_rms_norm(x, self.weight,
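Note: with the explicit size check removed above, torch.ops.vllm.maybe_chunk_residual now runs unconditionally, so the op itself has to cover both the matching and the flashcomm_v1-chunked case. A hedged reference of that contract (illustrative only, consistent with the shapes exercised in tests/ut/ops/test_layernorm.py; not the real op's implementation):

import torch


def maybe_chunk_residual_reference(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # No flashcomm_v1 padding/split in effect: pass the residual through unchanged.
    if x.size(0) == residual.size(0):
        return residual
    # Otherwise return a chunk whose leading dimension matches x. The real op may
    # handle padding and TP-rank offsets differently; this plain slice is an assumption.
    return residual[:x.size(0)]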
1 change: 1 addition & 0 deletions vllm_ascend/ops/linear.py
@@ -390,6 +390,7 @@ def _forward_dense_optim(
input_parallel,
bias=bias_)
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)

output_bias = self.bias if self.skip_bias_add else None
