Commit c46c17d

revert change to cu_seqlen_k and max_k when preparing from position_ids (#39653)
1 parent 4600c27 commit c46c17d

File tree

1 file changed: 1 addition, 6 deletions

src/transformers/modeling_flash_attention_utils.py

Lines changed: 1 addition & 6 deletions
@@ -223,11 +223,6 @@ def _prepare_from_posids(query, key, value, position_ids):
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
 
-    cu_seqlens_k = torch.cat(
-        [torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
-    )
-    max_k = torch.max(position_ids, dim=1).values.max().item() + 1
-
     position_ids = position_ids.flatten()
     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
 
@@ -246,7 +241,7 @@ def _prepare_from_posids(query, key, value, position_ids):
     # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
     # for some models (e.g. qwen2-vl).
     max_length = cu_seq_lens.diff().max().item()
-    return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k))
+    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
 
 
 def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
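
For context, a minimal sketch of the idea behind the revert: the deleted lines computed cu_seqlens_k and max_k from position_ids[:, -1], which breaks when position_ids are not strictly increasing (as the in-code comment notes for models like qwen2-vl). Since query and key come from the same packed batch here, one set of cumulative sequence lengths and one max length can be returned for both sides. The snippet below is an illustration, not the library's exact code; the helper name cu_seq_lens_from_position_ids is hypothetical and it assumes each packed sequence's position_ids start at 0.

import torch

def cu_seq_lens_from_position_ids(position_ids):
    # Hypothetical helper: position_ids is a flattened 1-D tensor for a packed
    # batch, e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3] for three packed sequences.
    flat = position_ids.flatten()
    # A new sequence starts wherever the position counter resets to 0.
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten().to(torch.int32)
    total = torch.tensor([flat.numel()], dtype=torch.int32, device=flat.device)
    return torch.cat([starts, total])  # cumulative boundaries, e.g. [0, 3, 5, 9]

position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
cu_seq_lens = cu_seq_lens_from_position_ids(position_ids)
max_length = cu_seq_lens.diff().max().item()  # longest packed sequence: 4
# With q and k sharing the same packing, the same tensors serve both sides,
# mirroring the reverted return value: (cu_seq_lens, cu_seq_lens), (max_length, max_length).
print(cu_seq_lens, max_length)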
