4 files changed (+107 −61 lines), all under server/text_generation_server/models.

@@ -710,34 +710,41 @@ def forward(
 # )
 if SYSTEM == "ipex":
     attn_output = torch.empty_like(query_states)
-    ipex.llm.functional.varlen_attention(
-        (
-            query_states.contiguous()
-            if query_states.device.type == "xpu"
-            else query_states
-        ),
-        (
-            key_states.contiguous()
-            if key_states.device.type == "xpu"
-            else key_states
-        ),
-        (
-            value_states.contiguous()
-            if value_states.device.type == "xpu"
-            else value_states
-        ),
-        attn_output,
-        cu_seqlen_q,
-        cu_seqlen_k,
-        max_q,
-        max_k,
-        0.0,
-        self.softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if query_states.device.type == "xpu":
+        ipex.llm.functional.varlen_attention(
+            query_states.contiguous(),
+            key_states.contiguous(),
+            value_states.contiguous(),
+            attn_output,
+            cu_seqlen_q,
+            cu_seqlen_k,
+            None,
+            max_q,
+            max_k,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_output,
+            cu_seqlen_q,
+            cu_seqlen_k,
+            max_q,
+            max_k,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 else:
     attn_output = flash_attn_2_cuda.varlen_fwd(
         query_states,
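Note: all three attention implementations in this change receive the same edit. Previously a single varlen_attention call made the inputs contiguous only when they lived on xpu; the diff splits this into an explicit device branch because, judging from the new call sites, the XPU variant of ipex.llm.functional.varlen_attention takes one extra positional argument (passed here as None) between cu_seqlen_k and max_q that the CPU variant does not. A minimal sketch of the shared dispatch pattern; the wrapper name varlen_attention_compat is hypothetical, and the argument order is copied from the diff rather than from IPEX documentation:

import torch
import intel_extension_for_pytorch as ipex

def varlen_attention_compat(
    query, key, value, out,
    cu_seqlen_q, cu_seqlen_k, max_q, max_k,
    softmax_scale, causal,
):
    # Hypothetical helper mirroring the branch this diff adds at each call site.
    if query.device.type == "xpu":
        # XPU kernel: contiguous inputs, plus one extra positional
        # argument (None here) ahead of max_q, as in the new call sites.
        ipex.llm.functional.varlen_attention(
            query.contiguous(), key.contiguous(), value.contiguous(),
            out, cu_seqlen_q, cu_seqlen_k, None,
            max_q, max_k, 0.0, softmax_scale,
            False, causal, False, None,
        )
    else:
        # CPU path: original signature, no extra argument, no copies.
        ipex.llm.functional.varlen_attention(
            query, key, value,
            out, cu_seqlen_q, cu_seqlen_k,
            max_q, max_k, 0.0, softmax_scale,
            False, causal, False, None,
        )
    return out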
@@ -460,22 +460,41 @@ def forward(
 # execute flash attention
 if SYSTEM == "ipex":
     attn_output = torch.empty_like(query)
-    ipex.llm.functional.varlen_attention(
-        (query.contiguous() if query.device.type == "xpu" else query),
-        (key.contiguous() if key.device.type == "xpu" else key),
-        (value.contiguous() if value.device.type == "xpu" else value),
-        attn_output,
-        cu_seqlens,
-        cu_seqlens,
-        max_seqlen,
-        max_seqlen,
-        0.0,
-        self.softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if query.device.type == "xpu":
+        ipex.llm.functional.varlen_attention(
+            query.contiguous(),
+            key.contiguous(),
+            value.contiguous(),
+            attn_output,
+            cu_seqlens,
+            cu_seqlens,
+            None,
+            max_seqlen,
+            max_seqlen,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query,
+            key,
+            value,
+            attn_output,
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 else:
     attn_output = flash_attn_2_cuda.varlen_fwd(
         query,
@@ -130,22 +130,41 @@ def forward(
 # execute flash attention
 if SYSTEM == "ipex":
     attn_output = torch.empty_like(query)
-    ipex.llm.functional.varlen_attention(
-        (query.contiguous() if query.device.type == "xpu" else query),
-        (key.contiguous() if key.device.type == "xpu" else key),
-        (value.contiguous() if value.device.type == "xpu" else value),
-        attn_output,
-        cu_seqlens,
-        cu_seqlens,
-        max_seqlen,
-        max_seqlen,
-        0.0,
-        self.softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if query.device.type == "xpu":
+        ipex.llm.functional.varlen_attention(
+            query.contiguous(),
+            key.contiguous(),
+            value.contiguous(),
+            attn_output,
+            cu_seqlens,
+            cu_seqlens,
+            None,
+            max_seqlen,
+            max_seqlen,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query,
+            key,
+            value,
+            attn_output,
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            0.0,
+            self.softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 else:
     attn_output = flash_attn_2_cuda.varlen_fwd(
         query,
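The .contiguous() calls are kept on the XPU path in the new branch; presumably the fused kernel indexes raw memory and assumes dense strides, which views produced by earlier reshape or transpose steps do not guarantee. A plain-PyTorch illustration of the property being enforced (no IPEX required):

import torch

x = torch.randn(4, 8, 16)
t = x.transpose(0, 1)        # a view with permuted strides, no copy
print(t.is_contiguous())     # False: memory is not laid out densely
c = t.contiguous()           # materializes a dense copy
print(c.is_contiguous())     # True
assert torch.equal(t, c)     # identical values, different layout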
@@ -59,7 +59,7 @@ def concatenate(cls, batches):
 @tracer.start_as_current_span("filter")
 def filter(self, request_ids: List[int]):
     assert self.image_indices is not None
-    batch = super().filter(request_ids)
+    batch = super(VlmCausalLMBatch, self).filter(request_ids)
     assert self.image_indices is not None
     indices = []
     for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ def filter(self, request_ids: List[int]):
         ]
     else:
         batch.cross_attention_states = None
+    batch.pixel_values = None
    return batch

 @classmethod
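The batch-filtering fix does two things. The explicit two-argument super(VlmCausalLMBatch, self) replaces the implicit super(); the zero-argument form resolves through the compiler-provided __class__ cell and can bind incorrectly when a method is reused across dynamically created classes, whereas the explicit form pins the lookup to VlmCausalLMBatch. And pixel_values is now reset alongside cross_attention_states, so a filtered batch does not carry stale image tensors into the next forward pass. A simplified sketch of the resulting flow; ParentBatch and VlmBatch are hypothetical stand-ins, not the real class names, and the conditional handling of cross_attention_states is flattened here:

from typing import List, Optional

class ParentBatch:
    # Hypothetical stand-in for the real base batch class.
    def filter(self, request_ids: List[int]) -> "ParentBatch":
        return self  # real implementation keeps only the requested ids

class VlmBatch(ParentBatch):
    # Stand-in for VlmCausalLMBatch; only the fields touched by the diff.
    image_indices: Optional[List[int]] = None
    cross_attention_states = None
    pixel_values = None

    def filter(self, request_ids: List[int]) -> "VlmBatch":
        assert self.image_indices is not None
        # Explicit two-argument super(), as in the diff.
        batch = super(VlmBatch, self).filter(request_ids)
        # Clear image state that no longer matches the filtered requests.
        batch.cross_attention_states = None
        batch.pixel_values = None
        return batch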