Commit efceeaf

Authored by ArthurZucker, vasqu, zucchini-nlp, and stevhliu
Kernels flash attn (#39474)
* use partial to wrap around `transformers` utils!
* try to refactor?
* revert one wrong change
* just a nit
* push
* reverter watever was wrong!
* some nits
* fixes when there is no attention mask
* bring the licence back
* some fixes
* nit
* style
* remove prints
* correct dtype
* fa flags for testing
* update
* use paged attention if requested!
* updates
* a clone was needed, not sure why
* automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute
* simplify and improve?
* flash attention is kinda broken on recent cuda version so allow the opportunity to use something else
* fix!
* protect kernels import
* update
* properly parse generation config being passed
* revert and update
* add two tests
* some fixes
* fix test FA2
* takes comment into account
* fixup
* revert changes
* revert the clone, it is only needed because the metal kernel is not doing it?
* [docs] update attention implementation and cache docs (#39547)
  * update docs
  * Apply suggestions from code review
  * applu suggestions
* fix mps on our side for now
* Update src/transformers/integrations/flash_paged.py
* no qa

Co-authored-by: Vasqu <[email protected]>
Co-authored-by: Raushan Turganbay <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
1 parent: b62557e
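
For readers landing on this commit, here is a minimal usage sketch of the feature area it touches (flash-attention backends and the paged/varlen plumbing used during generation). It is not taken from this diff: the checkpoint id is a placeholder, and `attn_implementation="flash_attention_2"` is the long-standing flash backend string; the kernels-based and paged variants this commit wires up may be selected differently.

# Hedged sketch: loading a model with a flash-attention backend and generating.
# Assumptions: a CUDA GPU, flash-attn 2.x installed, and a placeholder checkpoint id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # hypothetical checkpoint; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # flash backend; this commit adds kernel/paged plumbing around it
).to("cuda")

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))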

File tree: 9 files changed, +330 −415 lines

src/transformers/generation/continuous_batching.py
Lines changed: 2 additions & 1 deletion

@@ -1119,7 +1119,8 @@ def __init__(
         self._request_lock = threading.Lock()
         self.model.generation_config.top_p = None
         self.do_sample = getattr(generation_config, "do_sample", True)
-        self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
+        generation_config = model.generation_config if generation_config is None else generation_config
+        self.logit_processor = self.model._get_logits_processor(generation_config)
         self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
         self.profile = getattr(generation_config, "profile", False)
         self.manual_eviction = manual_eviction

src/transformers/generation/utils.py
Lines changed: 18 additions & 0 deletions

@@ -677,6 +677,24 @@ def prepare_inputs_for_generation(
         if encoder_attention_mask is not None:
             model_inputs["attention_mask"] = encoder_attention_mask

+        if "flash" in self.config._attn_implementation and self._supports_attention_backend:
+            tensor_kws = {"dtype": torch.int32, "device": self.device}
+            pos = model_inputs["position_ids"][:, -1]
+
+            cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
+            max_length_k = int(pos.max()) + 1
+
+            bs, seq_len = input_ids.size()
+            q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
+            cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
+            max_length_q = int(q_len.max())
+
+            model_inputs.update(
+                cu_seq_lens_q=cu_seq_lens_q.to(self.device),
+                cu_seq_lens_k=cu_seq_lens_k.to(self.device),
+                max_length_q=max_length_q,
+                max_length_k=max_length_k,
+            )
         # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
             if key not in model_inputs:
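
To make the intent of the added block concrete, here is a standalone sketch of the same cu_seq_lens pattern in plain PyTorch. The batch of last position ids is made up for illustration; only the construction mirrors the code above (a leading zero plus a cumulative sum of per-sequence lengths, with one query token per sequence while decoding).

# Standalone illustration of the cu_seq_lens construction used above.
# Assumption: a decode step, i.e. exactly one new query token per sequence.
import torch

kws = {"dtype": torch.int32, "device": "cpu"}
last_pos = torch.tensor([6, 3, 9], **kws)   # last position id per sequence (0-indexed)

# Keys: leading 0 + cumulative sum of key lengths -> packed sequence boundaries.
seq_lens_k = last_pos + 1
cu_seq_lens_k = torch.cat([torch.zeros(1, **kws), seq_lens_k.cumsum(0).to(torch.int32)], 0)
max_length_k = int(last_pos.max()) + 1

# Queries: one token per sequence during decoding.
q_len = torch.ones(last_pos.numel(), **kws)
cu_seq_lens_q = torch.cat([torch.zeros(1, **kws), q_len.cumsum(0).to(torch.int32)], 0)
max_length_q = int(q_len.max())

print(cu_seq_lens_k.tolist(), max_length_k)  # [0, 7, 11, 21] 10
print(cu_seq_lens_q.tolist(), max_length_q)  # [0, 1, 2, 3] 1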

src/transformers/integrations/flash_attention.py
Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,6 @@ def flash_attention_forward(
             "FlashAttention does not support inputs with dim=0.\n"
             "Please check your input shapes or use SDPA instead."
         )
-
     # FA2 uses non-transposed inputs
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)

@@ -76,6 +75,7 @@ def flash_attention_forward(
         use_top_left_mask=_use_top_left_mask,
         target_dtype=target_dtype,
         attn_implementation=module.config._attn_implementation,
+        layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
         **kwargs,
     )

src/transformers/integrations/flash_paged.py
Lines changed: 8 additions & 5 deletions

@@ -5,7 +5,7 @@


 if is_flash_attn_2_available():
-    from flash_attn import flash_attn_varlen_func
+    from flash_attn import flash_attn_varlen_func  # noqa: F401


 def paged_attention_forward(

@@ -20,6 +20,7 @@ def paged_attention_forward(
     max_seqlen_q=None,
     max_seqlen_k=None,
     block_tables=None,
+    implementation=None,
     **kwargs,
 ) -> torch.Tensor:
     r"""Perform the forward pass of attention with paged key-value cache.

@@ -46,12 +47,14 @@
     """
     k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)

+    if implementation is not None:
+        flash_attn_varlen_func = implementation.flash_attn_varlen_func
     attn_output = flash_attn_varlen_func(
-        q.transpose(1, 2).squeeze(0),
-        k.transpose(1, 2).squeeze(0),
-        v.transpose(1, 2).squeeze(0),
+        q.transpose(1, 2).squeeze(0).contiguous(),
+        k.transpose(1, 2).squeeze(0).contiguous(),
+        v.transpose(1, 2).squeeze(0).contiguous(),
         cumulative_seqlens_q.to(torch.int32),
-        cumulative_seqlens_k.to(torch.int32),
+        cumulative_seqlens_k.to(torch.int32).clone(),
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale=module.scaling,
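
For reference, a direct call to flash-attn's varlen API has roughly the shape sketched below. This is an illustration only, assuming flash-attn 2.x on a CUDA device with fp16 inputs; head counts, sequence lengths, and the scaling value are invented for the example and are not taken from this file.

# Hedged sketch of a direct flash_attn_varlen_func call on packed inputs.
# Assumptions: flash-attn 2.x installed, CUDA device, fp16 tensors; shapes are made up.
import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 8, 64
seq_lens = [7, 4, 10]                       # per-sequence lengths in the packed batch
total = sum(seq_lens)

q = torch.randn(total, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Boundaries of each sequence in the packed layout: [0, 7, 11, 21].
cu_seqlens = torch.tensor([0, 7, 11, 21], dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lens),
    max_seqlen_k=max(seq_lens),
    softmax_scale=head_dim ** -0.5,         # plays the role of module.scaling above
    causal=True,
)
print(out.shape)  # torch.Size([21, 8, 64])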
