From d05a5c3f0a54f2bec60ec80f9aa8c5ebe0438a15 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Sun, 29 Jun 2025 22:15:12 -0700
Subject: [PATCH 1/3] xpu varlen_attention takes an alibi_scope input in ipex
 2.7 while cpu does not, so split the case.

Signed-off-by: Wang, Yi A
---
 .../models/custom_modeling/mllama.py          | 63 ++++++++++---------
 .../models/custom_modeling/qwen2_5_vl.py      | 51 ++++++++++-----
 .../models/custom_modeling/qwen2_vl.py        | 51 ++++++++++-----
 3 files changed, 105 insertions(+), 60 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index be0a4b5d7c3..d4289015bea 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -710,34 +710,41 @@ def forward(
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
-                attn_output,
-                cu_seqlen_q,
-                cu_seqlen_k,
-                max_q,
-                max_k,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query_states,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index e2fc60b198f..a9cfc0653bf 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -460,22 +460,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.dtype == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 75f718bdf97..855eaa6a5f1 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -130,22 +130,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,

From 9f6d1704fe63541264ddc41b4ce29700ee68fd1d Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Sun, 29 Jun 2025 22:47:11 -0700
Subject: [PATCH 2/3] mllama filter crash fix

Signed-off-by: Wang, Yi A
---
 server/text_generation_server/models/mllama_causal_lm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py
index af9a811cf45..a9ecef76488 100644
--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -59,7 +59,7 @@ def concatenate(cls, batches):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         assert self.image_indices is not None
-        batch = super().filter(request_ids)
+        batch = super(VlmCausalLMBatch, self).filter(request_ids)
         assert self.image_indices is not None
         indices = []
         for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ def filter(self, request_ids: List[int]):
             ]
         else:
             batch.cross_attention_states = None
+        batch.pixel_values = None
         return batch
 
     @classmethod

From db11fe1f9547a6a29a1c8b13483658fc6080067e Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Sun, 29 Jun 2025 23:02:10 -0700
Subject: [PATCH 3/3] minor fix

Signed-off-by: Wang, Yi A
---
 .../text_generation_server/models/custom_modeling/qwen2_5_vl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index a9cfc0653bf..231d02b539c 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -460,7 +460,7 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            if query.device.dtype == "xpu":
+            if query.device.type == "xpu":
                 ipex.llm.functional.varlen_attention(
                     query.contiguous(),
                     key.contiguous(),