63 changes: 35 additions & 28 deletions server/text_generation_server/models/custom_modeling/mllama.py
@@ -710,34 +710,41 @@ def forward(
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
-                attn_output,
-                cu_seqlen_q,
-                cu_seqlen_k,
-                max_q,
-                max_k,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query_states,
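The same XPU/CPU branching is repeated verbatim in the two Qwen2-VL attention modules below. As a follow-up (not part of this PR), the pattern could live in a single shared helper; the sketch below assumes only the argument orders visible in this diff, the helper name _ipex_varlen_attention is hypothetical, and the extra positional argument the XPU path passes after cu_seqlen_k is forwarded as None exactly as the diff does, without assuming its meaning.

# Sketch of a shared wrapper for the device-dependent call pattern shown above.
# Assumes the argument orders from this diff; the helper name is hypothetical.
import intel_extension_for_pytorch as ipex
import torch


def _ipex_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    cu_seqlen_q: torch.Tensor,
    cu_seqlen_k: torch.Tensor,
    max_q: int,
    max_k: int,
    softmax_scale: float,
    causal: bool,
) -> torch.Tensor:
    if query.device.type == "xpu":
        # XPU kernels get contiguous inputs and one extra positional argument
        # after the cumulative sequence lengths (forwarded as None, as above).
        ipex.llm.functional.varlen_attention(
            query.contiguous(),
            key.contiguous(),
            value.contiguous(),
            out,
            cu_seqlen_q,
            cu_seqlen_k,
            None,
            max_q,
            max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            None,
        )
    else:
        # CPU path keeps the original argument list.
        ipex.llm.functional.varlen_attention(
            query,
            key,
            value,
            out,
            cu_seqlen_q,
            cu_seqlen_k,
            max_q,
            max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            None,
        )
    return out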
51 changes: 35 additions & 16 deletions server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -460,22 +460,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
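With the hypothetical _ipex_varlen_attention helper sketched after the mllama.py diff, this hunk (and the identical one in qwen2_vl.py below) would collapse to a single call. This is a sketch only, not runnable on its own: self, query, key, value, cu_seqlens, max_seqlen, and causal all come from the surrounding forward method.

# Hypothetical replacement for the branch above, using this hunk's names.
attn_output = torch.empty_like(query)
_ipex_varlen_attention(
    query,
    key,
    value,
    attn_output,
    cu_seqlens,
    cu_seqlens,
    max_seqlen,
    max_seqlen,
    self.softmax_scale,
    causal,
)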
51 changes: 35 additions & 16 deletions server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -130,22 +130,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
3 changes: 2 additions & 1 deletion server/text_generation_server/models/mllama_causal_lm.py
@@ -59,7 +59,7 @@ def concatenate(cls, batches):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         assert self.image_indices is not None
-        batch = super().filter(request_ids)
+        batch = super(VlmCausalLMBatch, self).filter(request_ids)
         assert self.image_indices is not None
         indices = []
         for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ def filter(self, request_ids: List[int]):
             ]
         else:
             batch.cross_attention_states = None
+        batch.pixel_values = None
         return batch

     @classmethod
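The switch from super().filter(request_ids) to super(VlmCausalLMBatch, self).filter(request_ids) makes MllamaCausalLMBatch.filter skip VlmCausalLMBatch's override and start the MRO lookup at that class's own parent (presumably FlashCausalLMBatch). A minimal, self-contained sketch of that two-argument super() behaviour, with stand-in class names:

# Stand-in classes; the real hierarchy is
# MllamaCausalLMBatch -> VlmCausalLMBatch -> (flash causal LM batch).
class FlashBatch:
    def filter(self, request_ids):
        return f"FlashBatch.filter({request_ids})"


class VlmBatch(FlashBatch):
    def filter(self, request_ids):
        return f"VlmBatch.filter({request_ids})"


class MllamaBatch(VlmBatch):
    def filter(self, request_ids):
        # super(VlmBatch, self) starts the MRO lookup *after* VlmBatch,
        # so this resolves to FlashBatch.filter, not VlmBatch.filter.
        return super(VlmBatch, self).filter(request_ids)


assert MllamaBatch().filter([1, 2]) == "FlashBatch.filter([1, 2])"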