78 changes: 25 additions & 53 deletions src/transformers/modeling_flash_attention_utils.py
@@ -313,17 +313,13 @@ def _upad_input(
     )


-def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = True):
+def prepare_fa_kwargs_from_position_ids(position_ids):
     """
-    This function returns all the necessary kwargs to call `flash_attn_varlen_func`
-    extracted from position_ids. The `position_ids` can be either packed sequence or
-    the usual padded position ids, for example in inference time.
+    This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids.

     Arguments:
         position_ids (`torch.Tensor`):
             Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
-        is_packed_sequence (`bool`, *optional*, defaults to `True`):
-            Whether the input position ids are a packed sequence or not.

     Return:
         (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
@@ -333,52 +329,35 @@ def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool =
             Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
             `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
     """
-    # If the lengths are not equal, most probably we are in decoding stage with cache
-    # In that case the position ids will not always start with `0` and we need a better way to infer
-    # cumulative seq lengths.
     tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
-    if not is_packed_sequence:
[Inline review comment from the Contributor Author on the `if not is_packed_sequence:` branch removed below: This path was meant for generation and is no longer valid; it's a dead path that is no longer used as of #40161]
-        last_position_ids = position_ids[:, -1]
-        q_len = (
-            torch.ones(position_ids.size(0), **tensor_kwargs)
-            if position_ids.shape[-1] == 1
-            else last_position_ids.add(1)
-        )
-        cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kwargs), q_len.cumsum(0).to(torch.int32)], 0)
-        cu_seq_lens_k = torch.cat(
-            [torch.zeros(1, **tensor_kwargs), last_position_ids.add(1).cumsum(0).to(torch.int32)], 0
-        )
-
-        max_length_q = int(q_len.max())
-        max_length_k = int(last_position_ids.max()) + 1
-    else:
-        position_ids = position_ids.view(-1)
-        indices_q = (position_ids == 0).nonzero().view(-1)
-
-        cu_seq_lens_q = torch.cat(
-            (
-                indices_q.to(**tensor_kwargs),
-                torch.tensor(position_ids.size(), **tensor_kwargs),
-            )
-        )
-        cu_seq_lens_k = cu_seq_lens_q
-
-        # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
-        # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
-        # for some models (e.g. qwen2-vl).
-        max_length_q = cu_seq_lens_q.diff().max()
-        # NOTE: With torch compile, this will cause a graph break if you don't set
-        # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
-        # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
-        # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
-        # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
-        max_length_q = max_length_q.item()
-        max_length_k = max_length_q
+    position_ids = position_ids.view(-1)
+    indices_q = (position_ids == 0).nonzero().view(-1)
+
+    cu_seq_lens_q = torch.cat(
+        (
+            indices_q.to(**tensor_kwargs),
+            torch.tensor(position_ids.size(), **tensor_kwargs),
+        )
+    )
+    cu_seq_lens_k = cu_seq_lens_q
+
+    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
+    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
+    # for some models (e.g. qwen2-vl).
+    max_length_q = cu_seq_lens_q.diff().max()
+    # NOTE: With torch compile, this will cause a graph break if you don't set
+    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
+    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
+    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
+    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
+    max_length_q = max_length_q.item()
+    max_length_k = max_length_q

     return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)


-def _prepare_from_posids(query, key, value, position_ids, query_length):
+def _prepare_from_posids(query, key, value, position_ids):
     """
     This function returns necessary arguments to call `flash_attn_varlen_func`.
     All three query, key, value states will be flattened.
@@ -394,8 +373,6 @@ def _prepare_from_posids(query, key, value, position_ids, query_length):
             Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
         position_ids (`torch.Tensor`):
             Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
-        query_length (`int`):
-            Sequence length of the input queries.

     Return:
         query (`torch.Tensor`):
@@ -409,16 +386,11 @@ def _prepare_from_posids(query, key, value, position_ids, query_length):
         (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
             Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
     """
-    kv_length = key.shape[1]
-    is_packed_sequence = query_length == kv_length
-
     query = query.contiguous().view(-1, query.size(-2), query.size(-1))
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))

-    (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
-        position_ids, is_packed_sequence=is_packed_sequence
-    )
+    (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(position_ids)

     return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))

@@ -660,7 +632,7 @@ def _flash_attention_forward(
     elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
         if cu_seq_lens_q is None or cu_seq_lens_k is None:
             q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
-                query_states, key_states, value_states, position_ids, query_length=query_length
+                query_states, key_states, value_states, position_ids
             )
         else:
             q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
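For context, here is a minimal sketch (not part of the diff) of what the retained packed-sequence path in `prepare_fa_kwargs_from_position_ids` computes. The input values are made up for illustration: two sequences of lengths 3 and 2 packed into a single row.

import torch

# Hypothetical packed position ids: sequences of lengths 3 and 2 in one row.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
flat = position_ids.view(-1)
# Each 0 marks the start of a packed sequence; appending the total length
# yields the cumulative sequence boundaries, as in the diff above.
starts = (flat == 0).nonzero().view(-1)
cu_seq_lens = torch.cat((starts.to(**tensor_kwargs), torch.tensor(flat.size(), **tensor_kwargs)))
max_len = int(cu_seq_lens.diff().max())

print(cu_seq_lens)  # tensor([0, 3, 5], dtype=torch.int32)
print(max_len)      # 3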
6 changes: 4 additions & 2 deletions tests/test_modeling_common.py
@@ -4313,8 +4313,10 @@ def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa
     @mark.flash_attn_test
     def test_flash_attention_2_continue_generate_with_position_ids(self):
         """
-        Tests that the given attention implementation can work with packed sequences and infers the mask
-        from position ids. This test requires the model to use new attention mask API which handles packing.
+        Tests whether flash attention can continue its generation from given position ids.
+
+        NOTE: This serves as regression check as we had instances where flash attention entered the varlen
+        path here. It should now always enter the base `flash_fn`.
         """

         max_new_tokens = 2
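The NOTE retained in the modeling diff above points out that `max_length_q.item()` causes a torch.compile graph break unless scalar-output capture is enabled. A small, self-contained illustration of that setting (the helper name and input values are made up for the example, not part of the library):

import torch

# Allow torch.compile to capture `.item()` as a scalar output so it does not
# force a graph break, as the retained NOTE in the diff suggests.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def max_seqlen_from_cu(cu_seq_lens: torch.Tensor) -> int:
    # Same tensor -> Python int pattern used for `max_length_q` in the diff.
    return cu_seq_lens.diff().max().item()

print(max_seqlen_from_cu(torch.tensor([0, 3, 5], dtype=torch.int32)))  # 3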