
Commit 05b47e2

winglian authored and zaristei committed
revert behavior of _prepare_from_posids (huggingface#39622)
* revert behavior of _prepare_from_posids
* add back cu_seqlens_k and max_k for inference
1 parent fc2502b commit 05b47e2
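
For context, `_prepare_from_posids` flattens padding-free (packed) inputs for FlashAttention's variable-length kernels. The revert makes `cu_seq_lens` mark a boundary at every position where `position_ids` resets to zero, i.e. one entry per packed sequence, instead of treating the whole flattened batch as a single sequence; `cu_seqlens_k` and `max_k` are kept, presumably so the key side still covers all cached positions during inference. A minimal sketch of that key-side computation on a hypothetical single-row decode input (toy values, not part of the commit):

```python
import torch

# Hypothetical decode step: the cache already holds positions 0..4 and the
# two new query tokens sit at absolute positions 5 and 6.
position_ids = torch.tensor([[5, 6]])

# Same expressions as in _prepare_from_posids: the key-side length is the
# last position + 1, and max_k is the largest position seen + 1.
cu_seqlens_k = torch.cat(
    [torch.tensor([0], dtype=torch.int32), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
)
max_k = torch.max(position_ids, dim=1).values.max().item() + 1

print(cu_seqlens_k)  # tensor([0, 7])
print(max_k)         # 7
```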

File tree

1 file changed: 3 additions, 1 deletion


src/transformers/modeling_flash_attention_utils.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -222,16 +222,18 @@ def _prepare_from_posids(query, key, value, position_ids):
     query = query.contiguous().view(-1, query.size(-2), query.size(-1))
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
+
     cu_seqlens_k = torch.cat(
         [torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
     )
     max_k = torch.max(position_ids, dim=1).values.max().item() + 1
+
     position_ids = position_ids.flatten()
     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
 
     cu_seq_lens = torch.cat(
         (
-            torch.tensor([0], device=position_ids.device, dtype=torch.int32),
+            indices_q[position_ids == 0],
             torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
         )
     )
```
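
The behavioral difference is easiest to see on a packed batch. Below is a minimal sketch of the reverted `cu_seq_lens` computation, assuming a toy `position_ids` that packs two sequences into the first row (the tensor values are illustrative, not from the repository):

```python
import torch

# Toy packed input: row 0 packs two sequences (lengths 3 and 2),
# row 1 holds one sequence of length 5.
position_ids = torch.tensor([[0, 1, 2, 0, 1],
                             [0, 1, 2, 3, 4]])

flat = position_ids.flatten()
indices_q = torch.arange(flat.size(0), dtype=torch.int32)

# Reverted behavior: every zero in position_ids marks a sequence start,
# so packed sequences inside a row get their own boundary.
cu_seq_lens = torch.cat((indices_q[flat == 0], torch.tensor(flat.size(), dtype=torch.int32)))
print(cu_seq_lens)  # tensor([ 0,  3,  5, 10], dtype=torch.int32)

# The removed line, torch.tensor([0], ...), would instead have produced
# tensor([ 0, 10]), i.e. the whole flattened batch as one sequence.
```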
