[WIP][V0.9.1] Add support for FlashComm2 in Qwen2 #1850

Closed
wants to merge 3 commits into from
279 changes: 258 additions & 21 deletions vllm_ascend/models/qwen2.py
@@ -1,23 +1,28 @@
from collections.abc import Iterable
from typing import Optional, Union
from typing import Any, Optional, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ReplicatedLinear,
RowParallelLinear)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2Attention, Qwen2MLP
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -27,6 +32,20 @@
from vllm_ascend.attention.attention_v1 import AscendAttentionState


def pad(tensor, x):
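# Pad the token (first) dimension to the next multiple of x; return the
# padded tensor together with the number of padding rows added.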
length = tensor.size(0)
pad_size = (x - (length % x)) % x
if pad_size > 0:
return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
return tensor, pad_size


def unpad(tensor, pad_size):
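# Drop the trailing rows added by pad(); a pad_size of 0 is a no-op.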
if pad_size > 0:
return tensor[:-pad_size, :]
return tensor


def all_gather_and_maybe_unpad(
hidden_states: torch.Tensor,
pad_size: int,
@@ -47,30 +66,215 @@ def maybe_pad_and_reduce_scatter(
return hidden_states


class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
class CustomQwen2MLP(Qwen2MLP):

def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config,
cache_config=cache_config,
super().__init__(hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
quant_config=quant_config,
prefix=prefix)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_fc = ascend_envs.VLLM_ASCEND_ENABLE_FLASHCOMM
if self.enable_fc == 2:
# If FlashComm2 is enabled, replace RowParallelLinear + AllReduce with All2All + ReplicatedLinear
self.down_proj = ReplicatedLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
else:
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
pad_size = 0
if self.enable_fc == 2:
# Pad the input so the token count is divisible by tp_size, as required by the All2All/AllGather communication
x, pad_size = pad(x, self.tp_size)
output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
dist.all_to_all_single(output,
x,
group=get_tp_group().device_group)
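# After the All2All, this rank holds every rank's intermediate-dim shard for
# its own chunk of tokens; regroup them so each local token's full
# intermediate activation is contiguous before the replicated down_proj.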
x = output.reshape(self.tp_size, -1, output.size(-1)) \
.transpose(0, 1) \
.reshape(-1, output.size(-1)*self.tp_size)
x, _ = self.down_proj(x)
return x, pad_size


class CustomQwen2Attention(Qwen2Attention):

def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=prefix,
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_fc = ascend_envs.VLLM_ASCEND_ENABLE_FLASHCOMM
if self.enable_fc == 2:
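# FlashComm2: o_proj is applied after the All2All on a sequence shard, so a
# ReplicatedLinear replaces the usual RowParallelLinear + AllReduce.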
self.o_proj = ReplicatedLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
else:
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)

def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor,
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions,
q,
k,
cos=cos,
sin=sin,
skip_index_select=True)
attn_output = self.attn(q, k, v)
pad_size = 0
if self.enable_fc == 2:
# Pad the input so the token count is divisible by tp_size, as required by the All2All/AllGather communication
attn_output, pad_size = pad(attn_output, self.tp_size)
output = torch.empty(attn_output.shape,
dtype=attn_output.dtype,
device=attn_output.device)
dist.all_to_all_single(output,
attn_output,
group=get_tp_group().device_group)
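# Same regrouping as in the MLP: reassemble the full attention output for
# this rank's token chunk before the replicated o_proj.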
attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
.transpose(0, 1) \
.reshape(-1, output.size(-1)*self.tp_size)
output, _ = self.o_proj(attn_output)
return output, pad_size


class CustomQwen2DecoderLayer(nn.Module):

def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)

# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY

self.self_attn = CustomQwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = CustomQwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

self.enable_fc = ascend_envs.VLLM_ASCEND_ENABLE_FLASHCOMM
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
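# Cross-rank reduction is handled explicitly in this layer (all-reduce /
# reduce-scatter, or All2All inside the FlashComm2 modules), so disable the
# projections' built-in all-reduce.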
self.self_attn.o_proj.reduce_results = False
self.mlp.down_proj.reduce_results = False

def pre_attention_process(self, hidden_states, residual, pad_size=0):
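# FlashComm2 path: the hidden states are sequence-sharded, so apply the
# residual add + norm locally, then AllGather the full (padded) sequence and
# strip the padding before attention.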
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
hidden_states = unpad(hidden_states, pad_size)
return hidden_states, residual

def pre_mlp_process(self, hidden_states, residual, pad_size=0):
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
hidden_states = unpad(hidden_states, pad_size)
return hidden_states, residual

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
flashcomm_v1_enabled: bool,
cos,
sin,
pad_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -81,34 +285,47 @@ def forward(
residual = F.pad(residual, (0, 0, 0, pad_size))
residual = torch.chunk(residual, self.tp_size,
dim=0)[self.tp_rank]
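# FlashComm2: pad the initial residual and keep only this rank's chunk so the
# residual stream stays sequence-sharded across TP ranks.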
if self.enable_fc == 2:
residual, pad_size = pad(residual, self.tp_size)
chunk_size = residual.size(0) // self.tp_size
residual = residual[chunk_size * self.tp_rank:chunk_size *
(self.tp_rank + 1)]
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if self.enable_fc == 2:
hidden_states, residual = self.pre_attention_process(
hidden_states, residual, pad_size)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if flashcomm_v1_enabled:
hidden_states = all_gather_and_maybe_unpad(
hidden_states, pad_size)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states, pad_size = self.self_attn(positions=positions,
hidden_states=hidden_states,
cos=cos,
sin=sin)
if flashcomm_v1_enabled:
hidden_states = maybe_pad_and_reduce_scatter(
hidden_states, pad_size)
else:
elif self.enable_fc != 2:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.enable_fc == 2:
hidden_states, residual = self.pre_mlp_process(
hidden_states, residual, pad_size)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if flashcomm_v1_enabled:
hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
hidden_states = self.mlp(hidden_states)
hidden_states, pad_size = self.mlp(hidden_states)
if flashcomm_v1_enabled:
hidden_states = maybe_pad_and_reduce_scatter(
hidden_states, pad_size)
else:
elif self.enable_fc != 2:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states, residual
return hidden_states, residual, pad_size


@support_torch_compile(
@@ -132,6 +349,8 @@ def __init__(
prefix=prefix,
decoder_layer_type=decoder_layer_type)
self.tp_size = get_tensor_model_parallel_world_size()
self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
self.enable_fc = ascend_envs.VLLM_ASCEND_ENABLE_FLASHCOMM

def forward(
self,
@@ -161,12 +380,24 @@ def forward(
num_tokens = hidden_states.size(0)
pad_size = (self.tp_size -
(num_tokens % self.tp_size)) % self.tp_size

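# Precompute rotary cos/sin for all positions once per forward pass and
# reshape them to BSNH so each layer's attention can skip the per-layer
# index_select on the cache.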
cos_sin = self.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2,
last_dim // 2).repeat(1, 1, 2).chunk(2,
dim=-2)
# BSNH
cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
1, -1, 1, last_dim).contiguous()

for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
hidden_states, residual, pad_size = layer(
positions,
hidden_states,
residual,
flashcomm_v1_enabled,
cos,
sin,
pad_size,
)
if not get_pp_group().is_last_rank:
@@ -177,6 +408,12 @@ def forward(
hidden_states, _ = self.norm(hidden_states, residual)
if flashcomm_v1_enabled:
hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
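# FlashComm2: hidden_states and residual are still sequence-sharded after the
# last layer; gather the full sequence across TP ranks and strip the padding.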
if self.enable_fc == 2:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
residual = tensor_model_parallel_all_gather(residual, 0)
if pad_size > 0:
hidden_states = hidden_states[:-pad_size]
residual = residual[:-pad_size]
return hidden_states

