diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json index 859544c89e4..be8b3882f12 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json @@ -8,61 +8,61 @@ "tokens": [ { "id": 1331, - "logprob": -0.34960938, + "logprob": -0.31835938, "special": false, "text": " people" }, { "id": 8390, - "logprob": -0.14746094, + "logprob": -0.1484375, "special": false, "text": " died" }, { "id": 528, - "logprob": -1.2265625, + "logprob": -1.1171875, "special": false, "text": " in" }, { "id": 506, - "logprob": -0.47070312, + "logprob": -0.45898438, "special": false, "text": " the" }, { "id": 3640, - "logprob": -0.5859375, + "logprob": -0.55859375, "special": false, "text": " United" }, { "id": 4184, - "logprob": -0.0027770996, + "logprob": -0.0026397705, "special": false, "text": " States" }, { "id": 236761, - "logprob": -0.34765625, + "logprob": -0.38085938, "special": false, "text": "." }, { "id": 108, - "logprob": -0.0859375, + "logprob": -0.07421875, "special": false, "text": "\n\n" }, { "id": 818, - "logprob": -1.1640625, + "logprob": -1.0859375, "special": false, "text": "The" }, { "id": 6816, - "logprob": -1.890625, + "logprob": -1.75, "special": false, "text": " generally" }, @@ -74,7 +74,7 @@ }, { "id": 10967, - "logprob": -0.90625, + "logprob": -0.9609375, "special": false, "text": " estimate" }, @@ -86,43 +86,43 @@ }, { "id": 600, - "logprob": -0.65234375, + "logprob": -0.703125, "special": false, "text": " that" }, { "id": 236743, - "logprob": -1.2109375, + "logprob": -1.171875, "special": false, "text": " " }, { "id": 236825, - "logprob": -0.00088119507, + "logprob": -0.0009918213, "special": false, "text": "6" }, { "id": 236832, - "logprob": -6.580353e-05, + "logprob": -6.389618e-05, "special": false, "text": "7" }, { "id": 236810, - "logprob": -5.2690506e-05, + "logprob": -4.7445297e-05, "special": false, "text": "5" }, { "id": 236764, - "logprob": -0.0001745224, + "logprob": -0.00017929077, "special": false, "text": "," }, { "id": 236771, - "logprob": -1.180172e-05, + "logprob": -1.4901161e-05, "special": false, "text": "0" }, @@ -140,7 +140,7 @@ }, { "id": 1331, - "logprob": -0.44921875, + "logprob": -0.45898438, "special": false, "text": " people" }, @@ -158,49 +158,49 @@ }, { "id": 506, - "logprob": -0.00034713745, + "logprob": -0.00032615662, "special": false, "text": " the" }, { "id": 3640, - "logprob": -0.028564453, + "logprob": -0.029785156, "special": false, "text": " United" }, { "id": 4184, - "logprob": -0.00012207031, + "logprob": -0.00012302399, "special": false, "text": " States" }, { "id": 236761, - "logprob": -1.15625, + "logprob": -1.1796875, "special": false, "text": "." }, { "id": 3153, - "logprob": -0.103027344, + "logprob": -0.09667969, "special": false, "text": " However" }, { "id": 236764, - "logprob": -0.009155273, + "logprob": -0.009094238, "special": false, "text": "," }, { "id": 1070, - "logprob": -0.92578125, + "logprob": -0.91015625, "special": false, "text": " some" }, { "id": 61806, - "logprob": -0.91796875, + "logprob": -0.859375, "special": false, "text": " historians" }, @@ -218,79 +218,79 @@ }, { "id": 5396, - "logprob": -0.8046875, + "logprob": -0.765625, "special": false, "text": " actual" }, { "id": 1548, - "logprob": -0.04321289, + "logprob": -0.048339844, "special": false, "text": " number" }, { "id": 1451, - "logprob": -0.66015625, + "logprob": -0.65625, "special": false, "text": " could" }, { "id": 577, - "logprob": -0.091308594, + "logprob": -0.09082031, "special": false, "text": " be" }, { "id": 618, - "logprob": -0.57421875, + "logprob": -0.625, "special": false, "text": " as" }, { "id": 1494, - "logprob": -0.00036239624, + "logprob": -0.00037193298, "special": false, "text": " high" }, { "id": 618, - "logprob": -0.0001335144, + "logprob": -0.0001296997, "special": false, "text": " as" }, { "id": 236743, - "logprob": -0.0009689331, + "logprob": -0.00093460083, "special": false, "text": " " }, { "id": 236770, - "logprob": -0.26367188, + "logprob": -0.21289062, "special": false, "text": "1" }, { "id": 236771, - "logprob": -0.17773438, + "logprob": -0.16796875, "special": false, "text": "0" }, { "id": 3625, - "logprob": -0.012084961, + "logprob": -0.0126953125, "special": false, "text": " million" }, { "id": 236761, - "logprob": -0.21289062, + "logprob": -0.22460938, "special": false, "text": "." }, { "id": 108, - "logprob": -0.37304688, + "logprob": -0.3984375, "special": false, "text": "\n\n" }, @@ -302,13 +302,13 @@ }, { "id": 1006, - "logprob": -1.3203125, + "logprob": -1.359375, "special": false, "text": " am" }, { "id": 3182, - "logprob": -1.078125, + "logprob": -1.0859375, "special": false, "text": " looking" }, @@ -320,85 +320,85 @@ }, { "id": 919, - "logprob": -1.25, + "logprob": -1.2578125, "special": false, "text": " more" }, { "id": 1938, - "logprob": -1.2421875, + "logprob": -1.3046875, "special": false, "text": " information" }, { "id": 580, - "logprob": -0.7734375, + "logprob": -0.7421875, "special": false, "text": " on" }, { "id": 672, - "logprob": -0.73046875, + "logprob": -0.78125, "special": false, "text": " this" }, { "id": 59725, - "logprob": -0.75, + "logprob": -0.7109375, "special": false, "text": " discrepancy" }, { "id": 532, - "logprob": -0.83984375, + "logprob": -0.8046875, "special": false, "text": " and" }, { "id": 506, - "logprob": -0.7109375, + "logprob": -0.71484375, "special": false, "text": " the" }, { "id": 5872, - "logprob": -1.2734375, + "logprob": -1.1640625, "special": false, "text": " factors" }, { "id": 600, - "logprob": -0.22851562, + "logprob": -0.20410156, "special": false, "text": " that" }, { "id": 19263, - "logprob": -1.1640625, + "logprob": -1.1484375, "special": false, "text": " contributed" }, { "id": 531, - "logprob": -0.0010757446, + "logprob": -0.000957489, "special": false, "text": " to" }, { "id": 506, - "logprob": -0.18945312, + "logprob": -0.19921875, "special": false, "text": " the" }, { "id": 5777, - "logprob": -1.2734375, + "logprob": -1.171875, "special": false, "text": " wide" }, { "id": 2644, - "logprob": -0.01940918, + "logprob": -0.020141602, "special": false, "text": " range" }, @@ -410,31 +410,31 @@ }, { "id": 14287, - "logprob": -0.032470703, + "logprob": -0.03564453, "special": false, "text": " estimates" }, { "id": 236761, - "logprob": -0.010375977, + "logprob": -0.010620117, "special": false, "text": "." }, { "id": 108, - "logprob": -0.06591797, + "logprob": -0.060302734, "special": false, "text": "\n\n" }, { "id": 8291, - "logprob": -0.8046875, + "logprob": -0.7421875, "special": false, "text": "Here" }, { "id": 236789, - "logprob": -0.23828125, + "logprob": -0.24023438, "special": false, "text": "'" }, @@ -446,55 +446,55 @@ }, { "id": 496, - "logprob": -0.17480469, + "logprob": -0.16992188, "special": false, "text": " a" }, { "id": 25890, - "logprob": -0.087402344, + "logprob": -0.06933594, "special": false, "text": " breakdown" }, { "id": 529, - "logprob": -0.0021209717, + "logprob": -0.002243042, "special": false, "text": " of" }, { "id": 506, - "logprob": -0.19140625, + "logprob": -0.18554688, "special": false, "text": " the" }, { "id": 5872, - "logprob": -1.0078125, + "logprob": -0.9921875, "special": false, "text": " factors" }, { "id": 20894, - "logprob": -0.26367188, + "logprob": -0.25976562, "special": false, "text": " contributing" }, { "id": 531, - "logprob": -9.250641e-05, + "logprob": -8.440018e-05, "special": false, "text": " to" }, { "id": 506, - "logprob": -0.008666992, + "logprob": -0.009765625, "special": false, "text": " the" }, { "id": 5777, - "logprob": -0.6171875, + "logprob": -0.67578125, "special": false, "text": " wide" }, @@ -506,31 +506,31 @@ }, { "id": 529, - "logprob": -0.016723633, + "logprob": -0.014831543, "special": false, "text": " of" }, { "id": 14287, - "logprob": -0.011352539, + "logprob": -0.012329102, "special": false, "text": " estimates" }, { "id": 573, - "logprob": -0.30664062, + "logprob": -0.3125, "special": false, "text": " for" }, { "id": 506, - "logprob": -0.21386719, + "logprob": -0.21484375, "special": false, "text": " the" }, { "id": 236743, - "logprob": -0.35351562, + "logprob": -0.43359375, "special": false, "text": " " }, @@ -560,43 +560,43 @@ }, { "id": 7745, - "logprob": -0.70703125, + "logprob": -0.703125, "special": false, "text": " flu" }, { "id": 10248, - "logprob": -0.015258789, + "logprob": -0.013427734, "special": false, "text": " pandemic" }, { "id": 4355, - "logprob": -0.83203125, + "logprob": -0.6953125, "special": false, "text": " death" }, { "id": 25363, - "logprob": -7.43866e-05, + "logprob": -6.771088e-05, "special": false, "text": " toll" }, { "id": 528, - "logprob": -0.08496094, + "logprob": -0.076171875, "special": false, "text": " in" }, { "id": 506, - "logprob": -6.67572e-06, + "logprob": -7.2717667e-06, "special": false, "text": " the" }, { "id": 3640, - "logprob": -0.0059509277, + "logprob": -0.0052490234, "special": false, "text": " United" }, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json index afbfba30ab9..cd1a598e6e2 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json @@ -1,11 +1,11 @@ { "choices": [ { - "finish_reason": "stop", + "finish_reason": "length", "index": 0, "logprobs": null, "message": { - "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail or perhaps speculate about the context of the image?", + "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail, or perhaps suggest what this image might represent (e.g", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1741965892, + "created": 1744396706, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.2.3-dev0-native", "usage": { - "completion_tokens": 98, + "completion_tokens": 100, "prompt_tokens": 277, - "total_tokens": 375 + "total_tokens": 377 } } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json index 1b97d2615c4..a1d3ae782d0 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nDo you want me to describe any specific element of the image in more detail?", + "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nIf you'd like, you can give me more details about the image or ask me to focus on a specific aspect of it.", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1741966313, + "created": 1744396703, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.2.3-dev0-native", "usage": { - "completion_tokens": 67, + "completion_tokens": 78, "prompt_tokens": 277, - "total_tokens": 344 + "total_tokens": 355 } } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json index cd786b3ce63..a839d7aac3c 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json @@ -13,11 +13,11 @@ "usage": null } ], - "created": 1741964480, + "created": 1744396699, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.2.3-dev0-native", "usage": { "completion_tokens": 74, "prompt_tokens": 275, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json index 5ed2c4507cb..c7215c930bc 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json @@ -13,11 +13,11 @@ "usage": null } ], - "created": 1741964477, + "created": 1744396697, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.2.3-dev0-native", "usage": { "completion_tokens": 75, "prompt_tokens": 279, diff --git a/launcher/src/main.rs b/launcher/src/main.rs index acff8573012..e481d7eb914 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -260,11 +260,22 @@ struct Config { impl Config { fn get_head_dim(&self) -> Option { - self.head_dim.or_else(|| { - self.text_config - .as_ref() - .and_then(|text_config| text_config.head_dim) - }) + if let Some(head_dim) = self.head_dim { + return Some(head_dim); + } + + let text_config = self.text_config.as_ref()?; + if let Some(head_size) = text_config.head_dim { + return Some(head_size); + } + + match self.model_type.as_deref() { + // We special-case gemma3 here, since we need flashinfer for + // handling bidirectional masks. And flashinfer can only be + // used when the head size is known. + Some("gemma3") => Some(256), + _ => None, + } } fn flop(&self) -> Option { diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index 9479b606717..f78475d5179 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state( state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, block_tables: torch.Tensor, cu_seqlens: torch.Tensor, + custom_mask: Optional[torch.Tensor], input_lengths: torch.Tensor, num_heads: int, num_kv_heads: int, @@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state( paged_kv_indptr=indptr, paged_kv_indices=block_tables, paged_kv_last_page_len=last_page_len, + custom_mask=custom_mask, num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_size, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py index 70fe9a3db5d..58afd643033 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.models.globals import ATTENTION from text_generation_server.utils.weights import UnquantizedWeight from transformers.activations import ACT2FN from text_generation_server.layers.attention import ( @@ -248,7 +249,7 @@ def forward( # Prefill if cu_seqlen_prefill is not None: - if attention_mask is None: + if attention_mask is None or ATTENTION == "flashinfer": # flash attention attn_output = attention( query=query, @@ -701,8 +702,16 @@ def __init__(self, prefix, config, weights): ) def get_attention_mask( - self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask + self, + input_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + dtype: torch.dtype, + bool_mask: bool = False, ): + image_token_mask = (input_ids == self.config.image_token_index).to( + input_ids.device + ) + device = input_ids.device min_dtype = torch.finfo(dtype).min @@ -748,9 +757,10 @@ def get_attention_mask( ) full_attention_mask[:, :, :, :sequence_length] = combined_mask - final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device) - - return final_attention_mask + if bool_mask: + return full_attention_mask + else: + return torch.where(full_attention_mask, 0, min_dtype).to(device) def forward( self, @@ -793,10 +803,8 @@ def forward( ) attention_mask = self.get_attention_mask( input_ids, - max_s, cu_seqlen_prefill, inputs_embeds.dtype, - image_token_mask, ) # Use flash attention for text-only input # else: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c7c5a374bcc..a28ef3810c8 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -2434,6 +2434,7 @@ def _forward_context( input_lengths_tensor: torch.Tensor, cache_lengths_tensor: torch.Tensor, state: Optional[Any] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> ContextManager: if ATTENTION != "flashinfer": return nullcontext() @@ -2450,6 +2451,7 @@ def _forward_context( ), block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, + custom_mask=attention_mask, input_lengths=input_lengths_tensor + cache_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 5f8eb9060dc..2b1e01dfad9 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -485,6 +485,14 @@ def forward( ) batch.position_ids = position_ids + if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None: + # Get the mask, needed for flashinfer. + attention_mask = self.model.get_attention_mask( + input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True + ).reshape(-1) + else: + attention_mask = None + # Try to find an associated cuda graph bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) @@ -508,6 +516,7 @@ def forward( cu_seqlen_prefill=cu_seqlen_prefill, input_lengths_tensor=input_lengths, cache_lengths_tensor=cache_lengths_tensor, + attention_mask=attention_mask, ): seqlen = Seqlen( input_lengths=input_lengths,