43 changes: 38 additions & 5 deletions vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -11,6 +11,7 @@
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_flashinfer_wrapper,
vit_torch_sdpa_wrapper,
)

@@ -31,6 +32,7 @@ def __init__(
scale: float | None = None,
num_kv_heads: int | None = None,
prefix: str = "",
workspace_buffer: torch.Tensor | None = None,
) -> None:
"""
Args:
Expand All @@ -40,14 +42,17 @@ def __init__(
num_kv_heads: number of kv heads.
prefix: This has no effect; it is only here to make it easier to
swap between Attention and MultiHeadAttention.
workspace_buffer: Pre-allocated workspace buffer for the FlashInfer
cuDNN backend.
"""
super().__init__()

self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.scale = 1.0 / (head_size**0.5) if scale is None else scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
self.workspace_buffer = workspace_buffer

assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
@@ -165,13 +170,34 @@ def _forward_fa(
output = output.reshape(bsz, q_len, -1)
return output

def _forward_flashinfer(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
return vit_flashinfer_wrapper(
q=query,
k=key,
v=value,
scale=self.scale,
workspace_buffer=self.workspace_buffer,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)

def forward_native(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)

@@ -181,10 +207,15 @@ def forward_cuda(
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
return self._forward_flashinfer(
query, key, value, cu_seqlens, max_seqlen, sequence_lengths
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
@@ -199,7 +230,8 @@ def forward_cpu(
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)

@@ -209,7 +241,8 @@ def forward_xpu(
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.is_flash_attn_backend, (
"XPU only supports FLASH_ATTN for vision attention."
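
A minimal usage sketch of the layer above with the new workspace_buffer parameter, assuming CUDA with the FlashInfer backend selected. The import path is inferred from the file path in this diff; the head counts, buffer size, and device are illustrative, not taken from the PR.

    import torch
    from vllm.model_executor.layers.attention.mm_encoder_attention import MultiHeadAttention

    # One workspace buffer allocated by the caller and shared by every attention
    # layer, so the FlashInfer cuDNN path reuses a single allocation.
    workspace = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    attn = MultiHeadAttention(
        num_heads=16,
        head_size=80,
        workspace_buffer=workspace,
    )
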
4 changes: 4 additions & 0 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -305,6 +305,7 @@ def __init__(
projection_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
workspace_buffer: torch.Tensor | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -346,6 +347,7 @@ def __init__(
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
prefix=f"{prefix}.attn",
workspace_buffer=workspace_buffer,
)

self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
@@ -357,6 +359,7 @@ def forward(
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -398,6 +401,7 @@ def forward(
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)

context_layer = einops.rearrange(
111 changes: 105 additions & 6 deletions vllm/model_executor/models/qwen3_vl.py
@@ -51,7 +51,7 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.distributed import get_pp_group, parallel_state
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.conv import Conv3dLayer
@@ -138,6 +138,18 @@
# of the maximum size.
DUMMY_VIDEO_NUM_FRAMES = 2048

# Batch buckets for cuDNN graph caching - graphs are cached per bucket size.
# This avoids creating a new graph for each unique batch size at runtime.
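# For example, a runtime batch of 20 sequences is padded up to the 32 bucket,
# so batch sizes 1-64 share at most len(BATCH_BUCKETS) = 4 cached graph shapes.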
BATCH_BUCKETS = [8, 16, 32, 64]

# Pre-allocated workspace buffer size for FlashInfer cuDNN backend (128MB).
FLASHINFER_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024

# Fixed max_seqlen for FlashInfer cuDNN to avoid graph recompilation.
# The actual sequence lengths are passed separately via actual_seq_lens_q/kv.
# TODO: use the real max_seqlen once cuDNN compilation is optimized.
FLASHINFER_MAX_SEQLEN_CUDNN = 128 * 1024


class Qwen3_VisionPatchEmbed(nn.Module):
def __init__(
@@ -215,6 +227,7 @@ def __init__(
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
workspace_buffer: torch.Tensor | None = None,
) -> None:
super().__init__()
if norm_layer is None:
Expand All @@ -227,6 +240,7 @@ def __init__(
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
workspace_buffer=workspace_buffer,
)
self.mlp = Qwen3_VisionMLP(
dim,
@@ -244,13 +258,15 @@ def forward(
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)

x = x + self.mlp(self.norm2(x))
@@ -332,6 +348,13 @@ def __init__(
)
self.num_grid_per_side = int(self.num_position_embeddings**0.5)

use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
else parallel_state.get_tensor_model_parallel_world_size()
)

# NOTE: This is used for creating empty tensor for all_gather for
# DP ViT. Here out_hidden_size is enlarged due to deepstack
self.out_hidden_size = vision_config.out_hidden_size * (
@@ -389,10 +412,20 @@ def __init__(
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASHINFER,
}:
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now."
)

workspace_buffer = (
None
if self.attn_backend != AttentionBackendEnum.FLASHINFER
else torch.zeros(
FLASHINFER_WORKSPACE_SIZE_BYTES, dtype=torch.uint8, device=self.device
)
)

self.blocks = nn.ModuleList(
[
Qwen3_VisionBlock(
Expand All @@ -403,6 +436,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
workspace_buffer=workspace_buffer,
)
for layer_idx in range(vision_config.depth)
]
@@ -526,13 +560,63 @@ def compute_attn_mask_seqlen(
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen

def add_padding_to_fi_seqlens(
self, seq: np.ndarray, batch_size: int, padding_value: int
) -> np.ndarray:
"""Pad sequence lengths to the next batch bucket size for cuDNN graph
caching. This avoids creating a new cuDNN graph for each unique batch
size at runtime."""
batch_size_padded = next(
(b for b in BATCH_BUCKETS if b >= batch_size), BATCH_BUCKETS[-1]
)
if batch_size_padded == batch_size:
return seq
return np.concatenate(
[
seq,
np.full(
(batch_size_padded - batch_size,), padding_value, dtype=seq.dtype
),
]
)
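
# Illustrative example (not part of the diff): with BATCH_BUCKETS = [8, 16, 32, 64],
# a batch of 3 sequences lands in the 8 bucket, so five padding entries are appended:
#   add_padding_to_fi_seqlens(np.array([100, 200, 50]), batch_size=3, padding_value=0)
#   -> array([100, 200, 50, 0, 0, 0, 0, 0])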

def compute_flashinfer_cu_seqlens(
self,
cu_seqlens: np.ndarray,
rotary_pos_emb_cos: torch.Tensor | None = None,
rotary_pos_emb_sin: torch.Tensor | None = None,
) -> np.ndarray:
"""Compute the 3-section cu_seqlens format required by the cuDNN batch
attention API. The format is [batch_offsets_qk, batch_offsets_v,
batch_offsets_o] concatenated together."""
batch_size = len(cu_seqlens) - 1
scale = self.hidden_size // self.tp_size
cu_seqlens = cu_seqlens * scale
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
cu_seqlens_qk = cu_seqlens * 2
else:
cu_seqlens_qk = cu_seqlens * 3
cu_seqlens_v = cu_seqlens * 3
cu_seqlens_o = cu_seqlens
cu_seqlens_qk = self.add_padding_to_fi_seqlens(
cu_seqlens_qk, batch_size, cu_seqlens_qk[-1]
)
cu_seqlens_v = self.add_padding_to_fi_seqlens(
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
)
cu_seqlens_o = self.add_padding_to_fi_seqlens(
cu_seqlens_o, batch_size, cu_seqlens_o[-1]
)
return np.concatenate([cu_seqlens_qk, cu_seqlens_v, cu_seqlens_o])
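
# Worked toy example (illustrative, not part of the diff): two sequences of 4 and
# 6 tokens give cu_seqlens = [0, 4, 10]; with hidden_size // tp_size = 8 and rotary
# cos/sin provided:
#   scaled offsets -> [0, 32, 80]
#   qk offsets     -> [0, 64, 160]   (x2 on the rotary branch, else x3)
#   v offsets      -> [0, 96, 240]   (x3)
#   o offsets      -> [0, 32, 80]
# Each section is padded from 3 to 9 entries (batch 2 -> bucket 8) by repeating its
# last offset, and the three sections are concatenated into a 27-entry array.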

def forward(
self,
x: torch.Tensor,
@@ -556,11 +640,25 @@ def forward(
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
sequence_lengths = self.add_padding_to_fi_seqlens(
sequence_lengths, len(sequence_lengths), 0
)
cu_seqlens = self.compute_flashinfer_cu_seqlens(
cu_seqlens, rotary_pos_emb_cos, rotary_pos_emb_sin
)
cu_seqlens = torch.from_numpy(cu_seqlens)
sequence_lengths = torch.from_numpy(sequence_lengths)

hidden_states = hidden_states.unsqueeze(1)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen = (
torch.tensor(FLASHINFER_MAX_SEQLEN_CUDNN, device=self.device)
if self.attn_backend == AttentionBackendEnum.FLASHINFER
else self.compute_attn_mask_seqlen(cu_seqlens)
)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
sequence_lengths = sequence_lengths.to(self.device, non_blocking=True)

deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
Expand All @@ -570,6 +668,7 @@ def forward(
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
@@ -378,6 +378,7 @@ def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
]

@classmethod