
Conversation

@vasqu (Contributor) commented Aug 25, 2025

Parts of the flash attention (varlen) kwargs handling were moved into the generate input preparation and were reverted in #40161.

The preparation during generate had more ripple effects, though.

This PR reverts these changes completely so everything is aligned with #40161.

Additional context

The long version explanation: see the detailed history in the comments below.

The short version explanation:

  • Flash attention uses the base flash_fn when we have no padding, making anything varlen-related in combination with generate obsolete
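Roughly, the resulting dispatch looks like this (a minimal sketch of the decision only, not the actual transformers code; it just names which path a forward call ends up in):

import torch

def fa_path(attention_mask=None, cu_seq_lens_q=None, cu_seq_lens_k=None,
            max_length_q=None, max_length_k=None):
    """Minimal sketch: which FA path a forward call takes after this PR."""
    if attention_mask is not None:
        # Padded batch: unpad + flash_varlen_fn (unchanged by this PR).
        return "varlen (unpad via attention_mask)"
    if all(x is not None for x in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)):
        # Pre-computed cu_seq_lens passed in, e.g. a genuinely packed sequence.
        return "varlen (user-provided cu_seq_lens)"
    # No padding and no packing, which is what generate hits: base flash_fn.
    return "base flash_fn"

# Plain decoding without padding no longer prepares any varlen kwargs:
assert fa_path() == "base flash_fn"
# Padded inputs still go through the varlen path:
assert fa_path(attention_mask=torch.ones(2, 5, dtype=torch.bool)) == "varlen (unpad via attention_mask)"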

Fixes #40399
Closes #40412

cc @ArthurZucker @zucchini-nlp @Cyrilvallez

tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
if not is_packed_sequence:
@vasqu (Contributor Author) commented on the diff above:

This path was meant for generation and is no longer valid; it has been a dead path since #40161.
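For context, a rough sketch of the case the kept branch (is_packed_sequence) handles; this is my own illustration of the idea, not the repo's helper:

import torch

# Several sequences concatenated into one row ("packed"); the boundaries can be
# recovered from where position_ids reset to 0.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])    # 3 sequences packed into one row

flat = position_ids.flatten()
starts = torch.nonzero(flat == 0).flatten()                    # tensor([0, 3, 5])
cu_seq_lens = torch.cat([starts, torch.tensor([flat.numel()])]).to(torch.int32)
print(cu_seq_lens.tolist())                                    # [0, 3, 5, 9]
# With these cu_seq_lens the varlen kernel treats the row as 3 separate sequences,
# which is the case that genuinely needs varlen, unlike plain generation.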

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu (Contributor Author) commented Aug 25, 2025

Can you check whether this works for you with regard to #39814? @maxjeblick @alessiodevoto

I ran the reproducer and can no longer get the error, so I assume it's not an issue even with the changes here.

@zucchini-nlp (Member) left a comment:

The deleted path was meant to help Nvidia KVPress, or users who run a manual generation loop, to use FA2 (#39814). Otherwise the cu_seq_lens always assume that the input is a packed sequence. We need to keep it somewhere and allow users to do custom generation with FA2.
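For instance, something like this (a rough sketch of the kind of custom loop meant here; the checkpoint is just a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint, just for illustration
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to("cuda")

inputs = tok("The capital of France is", return_tensors="pt").to("cuda")
input_ids = inputs.input_ids
past = None

# Manual decoding loop: forward is called directly, without an attention mask and
# without generate() preparing anything. KVPress additionally prunes the KV cache,
# so positions no longer line up with a "full" mask.
for _ in range(10):
    with torch.no_grad():
        out = model(input_ids=input_ids, past_key_values=past, use_cache=True)
    past = out.past_key_values
    input_ids = out.logits[:, -1:].argmax(-1)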

@vasqu (Contributor Author) commented Aug 25, 2025

Trying to give a history of what happened and why we ultimately don't need that path anymore:

  • Kernels flash attn #39474 introduces a refactor for FA integrating kernels
    • Preparation during generation is introduced:

      if "flash" in self.config._attn_implementation and self._supports_attention_backend:
          tensor_kws = {"dtype": torch.int32, "device": self.device}

          pos = model_inputs["position_ids"][:, -1]
          cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
          max_length_k = int(pos.max()) + 1

          bs, seq_len = input_ids.size()
          q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
          cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
          max_length_q = int(q_len.max())

          model_inputs.update(
              cu_seq_lens_q=cu_seq_lens_q.to(self.device),
              cu_seq_lens_k=cu_seq_lens_k.to(self.device),
              max_length_q=max_length_q,
              max_length_k=max_length_k,
          )
    • We always enter this path (when there is no mask) as the fa kwargs are prepared:

      use_mask = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])

    • This has not changed, but it is important that, if we provide the kwargs ourselves, everything must be correctly prepared, i.e.

      # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
      # use `flash_varlen_fn` knowing we already have all necessary the kwargs.
  • [FA2] Fix it finally - revert fa kwargs preparation #40161 removes this, as it has no real benefit and causes more breakage instead
  • My own rundown of what happens in our FA forward:
    • If we have an attention mask then we always enter the first path at

      if attention_mask is not None:
          q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
              query_states, key_states, value_states, attention_mask, query_length, unpad_fn
          )

          # TODO for now this is required to work with
          # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
          if "mps" in str(q.device):
              cu_seq_lens_k = cu_seq_lens_k.clone()

          out_unpad = flash_varlen_fn(
              q,
              k,
              v,
              cu_seqlens_q=cu_seq_lens_q,
              cu_seqlens_k=cu_seq_lens_k,
              max_seqlen_q=max_length_q,
              max_seqlen_k=max_length_k,
              **flash_kwargs,
          )
          if isinstance(out_unpad, tuple):
              out_unpad = out_unpad[0]

          out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)

      (this did not change with the refactors)
    • If there is no padding, or all sequences in the batch have the same length (i.e. no attention mask is needed), we take the base flash_fn path instead.
  • This PR removes these artefacts that were based on Kernels flash attn #39474 and Flash Attention fails with non aligned position_ids #39814, as varlen during generation is no longer valid (see the worked example below).
    • We should not enter the varlen path when there is no padding.
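To make the last point concrete, a quick worked example (my own illustration, reusing the formulas from the quoted prep above) of what the deleted generate-time prep computed for a plain, un-padded decode step:

import torch

# Batch size 1, 7 cached tokens and 1 new query token.
pos = torch.tensor([7])                                   # position_ids[:, -1] at the decode step
zero = torch.zeros(1, dtype=torch.int64)

cu_seq_lens_k = torch.cat([zero, pos.cumsum(0).add(1)])   # -> tensor([0, 8])
q_len = torch.ones(1, dtype=torch.int64)                  # seq_len == 1 while decoding
cu_seq_lens_q = torch.cat([zero, q_len.cumsum(0)])        # -> tensor([0, 1])

# One query token attending to 8 keys: exactly the rectangular case the base
# flash_fn already handles, so routing a decode step through varlen was pure overhead.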

@vasqu (Contributor Author) commented Aug 25, 2025

@zucchini-nlp Tried to explain things above ^

tl;dr: We no longer enter varlen during generate (except when we use an attention mask, which hasn't changed). Entering it was inefficient and broke more things; this cleans up the rest connected to #40161.

@zucchini-nlp (Member) commented Aug 25, 2025

> We no longer enter varlen during generate

Sorry, I am a bit lazy to read through all PRs 🙃

Just to make sure: if users have a custom generation loop where they call forward several times, which FA2 path does that lead to? Do we still take the non-varlen path if no attention mask is provided in the forward call? AFAIK it all depended on the presence of the attention mask, which was the reason the above linked issue failed in KVPress.

If the linked issue isn't reproducible anymore, it should be fine. But I'd like us to have a test to avoid regression

@vasqu (Contributor Author) commented Aug 25, 2025

> Sorry, I am a bit lazy to read through all PRs 🙃

No worries 😆

> Just to make sure: if users have a custom generation loop where they call forward several times, which FA2 path does that lead to?

Depends on the input: a) no padding, then we enter the basic fa function path (the last branch in the if/else); b) with an attention mask, the varlen path where we manually unpad (this hasn't changed).

> Do we still take the non-varlen path if no attention mask is provided in the forward call? AFAIK it all depended on the presence of the attention mask, which was the reason the above linked issue failed in KVPress.

We circumvent that issue entirely by not going down the varlen path here. It wasn't necessary to enter varlen, as we have no padding, and it made our lives significantly harder 😓 When we entered the varlen path (for input with no padding), we needed the workaround you made; with the removal of the prep during generate, we no longer need it.

> If the linked issue isn't reproducible anymore, it should be fine. But I'd like us to have a test to avoid regression

I can revert the removal of the test; it should catch that error in case we decide to change things up there (roughly along the lines of the sketch below).
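For reference, roughly what such a regression test could look like (a sketch only; the decorators and checkpoint are placeholders and the actual test in the repo may differ):

import torch
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_flash_attn, require_torch_gpu

@require_flash_attn
@require_torch_gpu
def test_fa2_non_aligned_position_ids_without_mask():
    # Mirrors the #39814 / KVPress scenario: forward is called directly with
    # position_ids that do not start at 0 and with no attention mask given.
    model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
    ).to("cuda")

    input_ids = torch.randint(0, model.config.vocab_size, (1, 4), device="cuda")
    position_ids = torch.arange(10, 14, device="cuda").unsqueeze(0)  # non-aligned: starts at 10

    with torch.no_grad():
        out = model(input_ids=input_ids, position_ids=position_ids)  # must not raise

    assert out.logits.shape[:2] == (1, 4)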

@ArthurZucker (Collaborator) left a comment:

LGTM, but:

  • let's make sure we don't break ulysses as well
  • let's check the run-slow on this PR; the docker should have fa2 now
  • happy to have a small TLDR, because I did not know either that no padding -> flash_fn now (did not know it made a difference)

@vasqu (Contributor Author) commented Aug 25, 2025

run-slow: llama,mistral,bart

Contributor (CI bot):
This comment contains run-slow, running the specified jobs:

models: ['models/bart', 'models/llama', 'models/mistral']
quantizations: [] ...

@vasqu (Contributor Author) commented Aug 25, 2025

Can you check if this PR works for you @ETOgaosion @kisseternity? (ulysses-sp)

@vasqu (Contributor Author) commented Aug 25, 2025

run-slow: llama,mistral,bart

Contributor (CI bot):
This comment contains run-slow, running the specified jobs:

models: ['models/bart', 'models/llama', 'models/mistral']
quantizations: [] ...

@vasqu (Contributor Author) commented Aug 25, 2025

Identical failures to main (+ the dola tests which are known)

Waiting for feedback on ulysses and kvpress, then I'll merge.

@maxjeblick commented:
Thanks for the heads up, I'll report back later today.

@ETOgaosion commented:
I think it works for verl's ulysses patch. We use a special handling method for the current 4.55 query-length API; luckily _flash_attn_forward still has this API unchanged, so it's compatible.

@vasqu (Contributor Author) commented Aug 25, 2025

Gotcha @ETOgaosion. I'd be interested whether this also works without the 4.55 patch? I assume this PR is safe on your side then.

@maxjeblick commented:
Thanks a lot for the PR; from the kvpress side, there are no issues!

@kisseternity commented:
> Can you check if this PR works for you @ETOgaosion @kisseternity? (ulysses-sp)

Indeed, I'm using https://github.com/huggingface/transformers/pull/40412/files as a quick fix and it looks good so far; when the training is done I'll give this PR a try.

@vasqu (Contributor Author) commented Aug 26, 2025

@kisseternity The problem is that #40412 will use logic that is no longer valid for us, so this PR will supersede #40412. Just want to make sure that I don't break things here instead as well 👀

@vasqu (Contributor Author) commented Aug 26, 2025

Thanks for checking @maxjeblick 🤗

@Cyrilvallez (Member) commented Aug 27, 2025

Thanks for reverting the remaining dead code @vasqu! Indeed, we should NEVER take the varlen path when we don't have an attention mask or a native packed format! It was a mistake that this was ever added.

@vasqu vasqu merged commit 7e1aee4 into huggingface:main Aug 28, 2025
24 checks passed
@vasqu vasqu deleted the fa-cleanup branch August 28, 2025 13:01
Successfully merging this pull request may close these issues.

[API] Current query_length API can break ulysses-sp patch