
Commit 5284b5c

Multi modality fix (#3283)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent 6a2fa83 commit 5284b5c

File tree: 4 files changed (+107 additions, -61 deletions)

server/text_generation_server/models/custom_modeling/mllama.py

Lines changed: 35 additions & 28 deletions
@@ -710,34 +710,41 @@ def forward(
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
-                attn_output,
-                cu_seqlen_q,
-                cu_seqlen_k,
-                max_q,
-                max_k,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query_states,

server/text_generation_server/models/custom_modeling/qwen2_5_vl.py

Lines changed: 35 additions & 16 deletions
@@ -460,22 +460,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,

server/text_generation_server/models/custom_modeling/qwen2_vl.py

Lines changed: 35 additions & 16 deletions
@@ -130,22 +130,41 @@ def forward(
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
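
All three model files above receive the same change: rather than conditionally calling .contiguous() inside a single ipex.llm.functional.varlen_attention call, the code now branches on the device type, and the XPU branch passes an extra positional None between the cu_seqlen arguments and the max sequence lengths. The helper below is not part of the repository; it is a minimal sketch of that dispatch pattern, and the purpose of the extra None argument is inferred from the diff rather than stated in the commit.

# Hypothetical helper, not in the repository: a sketch of the dispatch pattern applied above.
# Assumption read from the diff: on XPU, ipex.llm.functional.varlen_attention takes an
# extra positional argument after the cu_seqlen pair (passed as None) and requires
# contiguous q/k/v; on other ipex devices the previous argument order still applies.
import intel_extension_for_pytorch as ipex  # assumed importable when SYSTEM == "ipex"


def ipex_varlen_attention(q, k, v, out, cu_seqlen_q, cu_seqlen_k, max_q, max_k, softmax_scale, causal):
    if q.device.type == "xpu":
        ipex.llm.functional.varlen_attention(
            q.contiguous(), k.contiguous(), v.contiguous(),
            out, cu_seqlen_q, cu_seqlen_k,
            None,  # extra positional argument in the XPU signature, per this diff
            max_q, max_k,
            0.0, softmax_scale,
            False, causal, False, None,
        )
    else:
        ipex.llm.functional.varlen_attention(
            q, k, v,
            out, cu_seqlen_q, cu_seqlen_k,
            max_q, max_k,
            0.0, softmax_scale,
            False, causal, False, None,
        )
    return out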

server/text_generation_server/models/mllama_causal_lm.py

Lines changed: 2 additions & 1 deletion
@@ -59,7 +59,7 @@ def concatenate(cls, batches):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         assert self.image_indices is not None
-        batch = super().filter(request_ids)
+        batch = super(VlmCausalLMBatch, self).filter(request_ids)
         assert self.image_indices is not None
         indices = []
         for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ def filter(self, request_ids: List[int]):
             ]
         else:
             batch.cross_attention_states = None
+        batch.pixel_values = None
         return batch
 
     @classmethod
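
Two details in this file are easy to miss. First, super(VlmCausalLMBatch, self).filter(request_ids) is not the same as super().filter(request_ids): the two-argument form starts the method lookup after VlmCausalLMBatch in the MRO, so the call bypasses VlmCausalLMBatch's own filter override and lands on its base class's implementation instead. Second, the filtered batch now also clears pixel_values, presumably so stale image inputs are not reused after requests are removed. The snippet below is a toy illustration of the super() semantics only; apart from VlmCausalLMBatch, the class names are stand-ins and the real hierarchy is not shown in this diff.

# Toy classes, not the real TGI hierarchy: they only demonstrate how the
# two-argument super() changes which filter() gets called.
class GrandparentBatch:  # stand-in for VlmCausalLMBatch's own base class
    def filter(self, request_ids):
        return "GrandparentBatch.filter"


class VlmCausalLMBatch(GrandparentBatch):
    def filter(self, request_ids):
        return "VlmCausalLMBatch.filter"


class MllamaBatch(VlmCausalLMBatch):  # stand-in for the mllama batch class
    def plain_super(self, request_ids):
        # resolves to VlmCausalLMBatch.filter (the immediate parent)
        return super().filter(request_ids)

    def skip_vlm(self, request_ids):
        # starts the MRO search after VlmCausalLMBatch, so GrandparentBatch.filter runs
        return super(VlmCausalLMBatch, self).filter(request_ids)


batch = MllamaBatch()
assert batch.plain_super([0]) == "VlmCausalLMBatch.filter"
assert batch.skip_vlm([0]) == "GrandparentBatch.filter"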
