
Commit 10b9b0b

Support flashinfer for Gemma3 prefill
Gemma3 uses bidirectional attention for images, and flashinfer supports custom masks. Hook the mask up to flashinfer so that we do not have to fall back to the slower SDPA implementation for prefills that contain images.
1 parent 32cc319 commit 10b9b0b
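
As a rough illustration of the masking scheme the commit message describes (this is not the repository's get_attention_mask; the helper below is hypothetical): a causal mask, except that tokens belonging to the same contiguous run of image tokens may attend to each other bidirectionally.

import torch

def gemma3_style_prefill_mask(image_token_mask: torch.Tensor) -> torch.Tensor:
    # image_token_mask: bool tensor of shape (seq_len,), True where the token is an image token.
    device = image_token_mask.device
    seq_len = image_token_mask.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

    # Give each contiguous run of image tokens its own id (0 for text tokens).
    starts = image_token_mask & ~torch.cat(
        (torch.tensor([False], device=device), image_token_mask[:-1])
    )
    block_id = torch.cumsum(starts.to(torch.int64), dim=0) * image_token_mask

    # Allow bidirectional attention between tokens that share a non-zero image block id.
    same_image = (block_id[:, None] == block_id[None, :]) & (block_id[:, None] > 0)
    return causal | same_image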

File tree: 4 files changed (+29, -7 lines)


server/text_generation_server/layers/attention/flashinfer.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
     state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
     block_tables: torch.Tensor,
     cu_seqlens: torch.Tensor,
+    custom_mask: Optional[torch.Tensor],
     input_lengths: torch.Tensor,
     num_heads: int,
     num_kv_heads: int,
@@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state(
             paged_kv_indptr=indptr,
             paged_kv_indices=block_tables,
             paged_kv_last_page_len=last_page_len,
+            custom_mask=custom_mask,
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim_qk=head_size,
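
For orientation, flashinfer's prefill wrapper takes the custom mask as a single flattened boolean tensor for the whole ragged batch: as I understand its documented layout, the per-request (q_len x kv_len) masks are concatenated and flattened row-major. A minimal packing sketch under that assumption (the helper name is mine, not from this repo):

import torch
from typing import List

def pack_custom_mask(per_request_masks: List[torch.Tensor]) -> torch.Tensor:
    # per_request_masks[i]: bool tensor of shape (q_len_i, kv_len_i),
    # True where attention is allowed for that request.
    return torch.cat([mask.reshape(-1) for mask in per_request_masks])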

server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py

Lines changed: 15 additions & 7 deletions
@@ -45,6 +45,7 @@
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ def forward(
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
@@ -701,8 +702,16 @@ def __init__(self, prefix, config, weights):
         )
 
     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min
 
@@ -748,9 +757,10 @@ def get_attention_mask(
         )
         full_attention_mask[:, :, :, :sequence_length] = combined_mask
 
-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)
 
     def forward(
         self,
@@ -793,10 +803,8 @@ def forward(
             )
            attention_mask = self.get_attention_mask(
                 input_ids,
-                max_s,
                 cu_seqlen_prefill,
                 inputs_embeds.dtype,
-                image_token_mask,
             )
             # Use flash attention for text-only input
             # else:
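
The new bool_mask flag exists because the two backends want the mask in different forms: flashinfer consumes the raw booleans at plan time, while the SDPA path wants an additive float mask. A minimal sketch of the conversion the else branch above performs (helper name is mine):

import torch

def to_additive_mask(bool_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True -> 0.0 (keep), False -> most negative finite value (mask out).
    zeros = torch.zeros_like(bool_mask, dtype=dtype)
    return torch.where(bool_mask, zeros, torch.full_like(zeros, torch.finfo(dtype).min))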

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 2 additions & 0 deletions
@@ -2434,6 +2434,7 @@ def _forward_context(
         input_lengths_tensor: torch.Tensor,
         cache_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()
@@ -2450,6 +2451,7 @@ def _forward_context(
                 ),
                 block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
+                custom_mask=attention_mask,
                 input_lengths=input_lengths_tensor + cache_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,

server/text_generation_server/models/vlm_causal_lm.py

Lines changed: 10 additions & 0 deletions
@@ -5,6 +5,7 @@
 from opentelemetry import trace
 from typing import Iterable, Optional, Tuple, List, Type, Dict
 
+from torch.nn import attention
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
@@ -485,6 +486,14 @@ def forward(
             )
         batch.position_ids = position_ids
 
+        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
+            # Get the mask, needed for flashinfer.
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            ).reshape(-1)
+        else:
+            attention_mask = None
+
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -508,6 +517,7 @@ def forward(
             cu_seqlen_prefill=cu_seqlen_prefill,
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
+            attention_mask=attention_mask,
         ):
             seqlen = Seqlen(
                 input_lengths=input_lengths,
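
As a reference point, the flattened custom mask that flashinfer plans with is expected to hold one entry per (query, key) pair, summed over the requests in the prefill batch; with no cached tokens, q_len equals kv_len per request. A hypothetical sanity check along those lines (not part of the commit):

import torch

def expected_custom_mask_numel(cu_seqlen_prefill: torch.Tensor) -> int:
    # cu_seqlen_prefill holds cumulative lengths: [0, len_0, len_0 + len_1, ...]
    lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
    return int((lengths * lengths).sum().item())

# e.g.: assert attention_mask.numel() == expected_custom_mask_numel(cu_seqlen_prefill)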
