@@ -223,11 +223,6 @@ def _prepare_from_posids(query, key, value, position_ids):
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
 
-    cu_seqlens_k = torch.cat(
-        [torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
-    )
-    max_k = torch.max(position_ids, dim=1).values.max().item() + 1
-
     position_ids = position_ids.flatten()
     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
 
@@ -246,7 +241,7 @@ def _prepare_from_posids(query, key, value, position_ids):
     # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
     # for some models (e.g. qwen2-vl).
     max_length = cu_seq_lens.diff().max().item()
-    return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k))
+    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
 
 
 def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
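
The comment kept in the second hunk carries the reasoning behind the change: `cu_seq_lens` is the reliable source for the maximum sequence length, because `position_ids` is not guaranteed to be increasing for some models (e.g. qwen2-vl). A minimal sketch of that failure mode, using made-up tensors (the specific position-id pattern is an assumption chosen only for illustration, not taken from the diff):

```python
# Illustrative sketch: why max length comes from cu_seq_lens rather than position_ids.
# With packed sequences whose position ids repeat or restart, max(position_ids) + 1 can
# underestimate the true longest sequence, while cu_seq_lens.diff() always recovers it.
import torch

# Two packed sequences: length 3 with positions [0, 1, 2], and length 5 whose positions
# repeat as [0, 0, 1, 1, 2] (hypothetical non-increasing ids, qwen2-vl-style).
position_ids = torch.tensor([0, 1, 2, 0, 0, 1, 1, 2], dtype=torch.int32)
cu_seq_lens = torch.tensor([0, 3, 8], dtype=torch.int32)  # cumulative sequence boundaries

max_from_cu_seq_lens = cu_seq_lens.diff().max().item()  # 5 -> correct longest sequence
max_from_position_ids = position_ids.max().item() + 1   # 3 -> wrong for this layout
print(max_from_cu_seq_lens, max_from_position_ids)
```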