@@ -223,11 +223,6 @@ def _prepare_from_posids(query, key, value, position_ids):
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))

-    cu_seqlens_k = torch.cat(
-        [torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
-    )
-    max_k = torch.max(position_ids, dim=1).values.max().item() + 1
-
     position_ids = position_ids.flatten()
     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

@@ -246,7 +241,7 @@ def _prepare_from_posids(query, key, value, position_ids):
     # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
     # for some models (e.g. qwen2-vl).
     max_length = cu_seq_lens.diff().max().item()
-    return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k))
+    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


 def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
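
For context, here is a minimal sketch (hypothetical values, plain PyTorch, no transformers imports) of why a single cu_seq_lens tensor can now serve both the query and key sides: in this packed self-attention layout both sides share the same sequence boundaries, and the deleted cu_seqlens_k/max_k math assumed position ids that increase monotonically to the end of each row, which the in-code comment notes does not hold for models like qwen2-vl. The boundary detection below assumes cu_seq_lens is built from the zero positions of the flattened ids, as in the unchanged middle of this function.

import torch

# Hypothetical packed row holding two sequences of lengths 3 and 2, so the
# position ids reset mid-row instead of increasing monotonically -- the
# qwen2-vl-style case the in-code comment warns about.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

flat = position_ids.flatten()
indices = torch.arange(flat.size(0), dtype=torch.int32)

# Every position where the id resets to 0 starts a new sequence; appending
# the total token count yields the cumulative sequence lengths.
cu_seq_lens = torch.cat(
    (indices[flat == 0], torch.tensor([flat.numel()], dtype=torch.int32))
)
print(cu_seq_lens)  # tensor([0, 3, 5], dtype=torch.int32)

# Max length derived from the boundaries, as the patched return value does.
max_length = cu_seq_lens.diff().max().item()
print(max_length)  # 3

# The deleted cu_seqlens_k computation trusted the last id of each row, so
# here it reports a single sequence of length 2 and misses the real split.
cu_seqlens_k = torch.cat(
    [torch.tensor([0]), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
)
print(cu_seqlens_k)  # tensor([0, 2])

Returning (cu_seq_lens, cu_seq_lens) and (max_length, max_length) keeps the varlen flash-attention call signature, which expects separate query/key cumulative lengths, while guaranteeing the two sides can never disagree.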