
Conversation

@EduardDurech (Contributor) commented Jun 23, 2025

Supports Flash Attention 3 for `_flash_attention_forward`

Previous #36190 @ArthurZucker

Parity test for Flash Attention {2,3}, based on https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py

$ RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
> ============================================================================================================ test session starts ============================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /workspace/transformers
configfile: pyproject.toml
plugins: hydra-core-1.3.2, xdist-3.6.1, rerunfailures-15.1, hypothesis-6.130.8, shard-0.1.2, xdoctest-1.0.2, flakefinder-1.1.0, anyio-4.9.0, typeguard-4.3.0
collected 1 item                                                                                                                                                                                                                            
Running 1 items in this shard

tests/generation/test_flash_attention_parity.py::FlashAttentionParityTest::test_flash_attention_2_3_parity You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.

--- Flash Attention (2, 3) Parity Test on meta-llama/Llama-3.2-1B-Instruct ---
Prompt: 'The ETH AI Center is'
Generated text with Flash Attention 2: The ETH AI Center is a research center that focuses on the development of artificial intelligence and its applications in various fields. The center
Generated text with Flash Attention 3: The ETH AI Center is a research center that focuses on the development of artificial intelligence and its applications in various fields. The center
ROUGE-L: 1.0
Max absolute difference in logprobs: 0.00000e+00
Flash Attention 2 latency: 287.42 ms
Flash Attention 3 latency: 272.10 ms
Speed-up: 1.06x
---
PASSED

============================================================================================================= warnings summary ==============================================================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: Type google._upb._message.MessageMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14.

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: Type google._upb._message.ScalarMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14.

../../usr/local/lib/python3.12/dist-packages/google/protobuf/internal/well_known_types.py:93
  /usr/local/lib/python3.12/dist-packages/google/protobuf/internal/well_known_types.py:93: DeprecationWarning: datetime.datetime.utcfromtimestamp() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.fromtimestamp(timestamp, datetime.UTC).
    _EPOCH_DATETIME_NAIVE = datetime.datetime.utcfromtimestamp(0)

../../usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1439
  /usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1439: PytestConfigWarning: Unknown config option: asyncio_default_fixture_loop_scope
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================= 1 passed, 4 warnings in 8.18s =======================================================================================================
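
The test is roughly shaped like the sketch below (illustrative only, not the contents of `tests/generation/test_flash_attention_parity.py`; the real test also compares per-token logprobs and ROUGE-L and times both runs):

# Sketch of the FA2/FA3 parity check (illustrative, not the merged test file).
# Assumes a Hopper GPU with both flash-attn 2 and the flash-attn 3 hopper build installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"
prompt = "The ETH AI Center is"

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

sequences = {}
for impl in ("flash_attention_2", "flash_attention_3"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation=impl, torch_dtype=torch.bfloat16
    ).to("cuda")
    with torch.no_grad():
        sequences[impl] = model.generate(**inputs, max_new_tokens=24, do_sample=False)

# Greedy decoding should produce identical generations for the two backends.
assert torch.equal(sequences["flash_attention_2"], sequences["flash_attention_3"])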

Closes #32219, #33373

Implements the forward pass and tests for Flash Attention 3 (https://github.com/Dao-AILab/flash-attention/commits/main/hopper)

- Includes checks for dropout > 0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3`; a rough sketch of the intent is below. (Dropout will likely be supported soon, so this will need to be updated, along with `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` branch.)
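
A rough sketch of what that gate boils down to (intent only, not the exact code in `modeling_utils.py`; attribute names are illustrative):

# Sketch of the intent behind _check_and_enable_flash_attn_3 (not the exact implementation).
# FA3 currently supports neither attention dropout nor ALiBi, so both are rejected up front.
def _fa3_compatibility_check_sketch(config):
    if getattr(config, "attention_dropout", 0.0) > 0.0:
        raise ValueError("Flash Attention 3 does not (yet) support attention dropout.")
    if getattr(config, "alibi", False):
        raise ValueError("Flash Attention 3 does not support ALiBi positional biases.")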

An example Llama implementation is included in `modeling_llama.py`, but other models still need to be updated

Based on huggingface#36190, which has model implementations and examples that could be merged
@EduardDurech marked this pull request as draft on June 23, 2025 00:14
@EduardDurech force-pushed the FA3 branch 3 times, most recently from 629deca to 230f64f on June 23, 2025 01:28
@EduardDurech marked this pull request as ready for review on June 23, 2025 01:51
@github-actions bot requested review from ArthurZucker and ydshieh on June 23, 2025 01:51
@ArthurZucker (Collaborator) left a comment

thanks for taking the time!

- `_prepare_flash_attention_from_position_ids` -> `prepare_fa2_from_position_ids`
- Remove bettertransformer check in Flash Attention 3
- Merge tests
- Add licensing
@EduardDurech (Contributor, Author)

@ArthurZucker all comments resolved

@EduardDurech (Contributor, Author)

Re: @tridao, you mentioned dropout may be supported (pytorch/pytorch#148891 (comment)); if I could be pinged when that's done, I can submit a new PR

@ArthurZucker (Collaborator) left a comment

Just deprecate the one method you are renaming and good to go!

@ArthurZucker merged commit a2eb75c into huggingface:main on Jun 25, 2025
18 checks passed
@ArthurZucker (Collaborator)

Thanks a lot for the contribution! 🤗

@1ytic (Contributor) commented Jun 27, 2025

@EduardDurech thank you for the contribution! I'm trying to use it with FSDP2, but get this error:

NotImplementedError: flash_attn_3::fwd: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered.

It's triggered from here. Any ideas how to fix this? Is it even possible?

@1ytic (Contributor) commented Jun 28, 2025

Feel free to ignore my previous question. I missed that my tp plan uses ColwiseParallel(use_local_output=False) for the query/key/value states. Switching to standard torch tensors works fine.
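
In other words, with use_local_output=True, ColwiseParallel hands plain sharded torch.Tensors to the next op instead of DTensors, so the flash_attn_3 custom op never sees a DTensor/meta input. A minimal illustrative fragment (not the exact plan used here; import path assumes a recent PyTorch):

# Illustrative fragment: use_local_output=True makes ColwiseParallel return plain
# torch.Tensor shards, which the flash_attn_3 kernel can consume directly.
from torch.distributed.tensor.parallel import ColwiseParallel

attn_tp_fragment = {
    "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=True),
    "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=True),
    "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=True),
}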

@EduardDurech (Contributor, Author)

Haven't tested with FSDP2 but glad you got it sorted out 😃

@ArthurZucker (Collaborator)

@1ytic We actually want it to work with ColwiseParallel; if you have a reproducer, can you open an issue?

@EduardDurech (Contributor, Author)

@1ytic @ArthurZucker this shouldn't be too difficult to fix. I won't have the time, but if anyone wants to pick it up, it seems you need to:

* Register a fake/meta kernel, see https://gist.github.com/a-r-r-o-w/d08c37e8bd3e9c26b4ce80360be148c6#file-benchmark_kontext_cp-py-L169
* Create a flash_fwd DTensor wrapper, see https://dev-discuss.pytorch.org/t/dtensor-status-design-and-looking-forward/2749
* Include the original and the DTensor variant within the dispatcher

This is more of a low-level flash_attn_3 and PyTorch thing, but it seems possible to patch in Transformers.

@EduardDurech (Contributor, Author)

btw, anyone using Ascend NPU see #39166, thanks @FightingZhen

@EduardDurech (Contributor, Author) commented Jul 7, 2025

> @1ytic @ArthurZucker this shouldn't be too difficult to fix. I won't have the time, but if anyone wants to pick it up, it seems you need to:
>
> * Register a fake/meta kernel, see https://gist.github.com/a-r-r-o-w/d08c37e8bd3e9c26b4ce80360be148c6#file-benchmark_kontext_cp-py-L169
> * Create a flash_fwd DTensor wrapper, see https://dev-discuss.pytorch.org/t/dtensor-status-design-and-looking-forward/2749
> * Include the original and the DTensor variant within the dispatcher
>
> This is more of a low-level flash_attn_3 and PyTorch thing, but it seems possible to patch in Transformers.

Following up, @1ytic: what PyTorch version are you using? Does ColwiseParallel(use_local_output=True) work?

Maybe try the first point; it may be enough.

# Rough sketch: register a fake/meta implementation so the FA3 custom op can be
# traced with meta tensors (FSDP2, torch.compile). The op name below is assumed;
# it must match whatever flash_attn_interface actually registers (the error above
# mentions flash_attn_3::fwd).
import torch
from flash_attn_interface import flash_attn_func as flash_attn_3_func  # importing registers the custom op

@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _fake_fa3(q, k, v, *, is_causal=False):
    # Return empty tensors with the expected output shapes: an attention output
    # shaped like q, plus a (batch, seqlen, heads) softmax LSE.
    B, S, H, D = q.shape
    return torch.empty_like(q), q.new_empty((B, S, H))
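
And for the second point, a purely hypothetical shape of the DTensor wrapper (names and the FA3 return signature are assumptions, not the real transformers or flash_attn API):

# Hypothetical sketch of the DTensor wrapper idea (point 2 above); fa3_func and the
# return handling are assumptions, not the actual dispatcher code.
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

def fa3_with_dtensor(fa3_func, q, k, v, **kwargs):
    # If q/k/v arrive as DTensors (e.g. ColwiseParallel(use_local_output=False)),
    # run the kernel on the local shards and re-wrap the attention output.
    if isinstance(q, DTensor):
        mesh, placements = q.device_mesh, q.placements
        out = fa3_func(q.to_local(), k.to_local(), v.to_local(), **kwargs)
        # FA3 may also return the softmax LSE; only the attention output is re-wrapped here.
        if isinstance(out, tuple):
            return (DTensor.from_local(out[0], mesh, placements), *out[1:])
        return DTensor.from_local(out, mesh, placements)
    return fa3_func(q, k, v, **kwargs)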

@EduardDurech (Contributor, Author)

Maybe @a-r-r-o-w could update

@a-r-r-o-w (Contributor)

I'm not sure of the approach to follow since I haven't tried FA3 with PyTorch TP and DTensor. I think what @EduardDurech mentioned in his comment sounds good. We might not need anything DTensor-specific here; if my memory from similar tests with SageAttention serves right, the meta registration alone may allow it to work with FSDP2.

Fake/meta registrations should probably live within flash-attn (there's a PR: Dao-AILab/flash-attention#1590), but for the time being they could maybe be added to transformers if the maintainers are okay with it. Without the registration, torch.compile should also fail with FA3, so it's important to have.

@1ytic (Contributor) commented Jul 7, 2025

> This is more of a low-level flash_attn_3 and PyTorch thing, but it seems possible to patch in Transformers.

Agree, it should be done on the flash_attn side.

> Does ColwiseParallel(use_local_output=True) work?

Yes, it works.

Just for context, I tried to use NeMo-RL for the Qwen3 model with this tp plan, but with flash_attention_3 I changed it to this:

# Standard parallel styles come from torch; RotaryEmbedParallel and Qwen3QKNorm
# are custom plans from the NeMo-RL setup, not part of torch.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel
from torch.distributed.tensor import Replicate, Shard

base_model_tp_plan = {
    "lm_head": ColwiseParallel(
        input_layouts=Shard(1),
        output_layouts=Shard(-1),
        use_local_output=False,
    ),
    "model.embed_tokens": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    "model.rotary_emb": RotaryEmbedParallel(use_local_output=True),
    "model.norm": SequenceParallel(),
    "model.layers.*.input_layernorm": SequenceParallel(),
    "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False),
    "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False),
    "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=True),
    "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
    "model.layers.*.self_attn.q_norm": Qwen3QKNorm(use_local_output=True),
    "model.layers.*.self_attn.k_norm": Qwen3QKNorm(use_local_output=True),
    "model.layers.*.post_attention_layernorm": SequenceParallel(),
    "model.layers.*.mlp.up_proj": ColwiseParallel(),
    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
    "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
}

@ArthurZucker (Collaborator)

This can probably be added easily to the flash attention kernel on the Hub if you need it fast: https://huggingface.co/kernels-community/flash-attn3. Feel free to open a PR there!
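
For reference, pulling that kernel from the Hub looks roughly like this (a sketch; the exposed function name and its return value are assumptions about the kernels-community/flash-attn3 repo):

# Sketch of using the Hub kernel via the `kernels` library (function name assumed).
import torch
from kernels import get_kernel

flash_attn3 = get_kernel("kernels-community/flash-attn3")

q = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
# Depending on the version, this may return just the output or (output, softmax_lse).
outputs = flash_attn3.flash_attn_func(q, k, v, causal=True)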

@kisseternity

Thanks for supporting FA3. I'm now using FA3 with Ulysses SP, but the forward logits come out as NaN most of the time. Could you please check whether FA3 works with Ulysses?

@ArthurZucker (Collaborator)

Yep, we have a PR for that: #40412. Tell us if it fixes it!

@kisseternity

> Yep, we have a PR for that: #40412. Tell us if it fixes it!

Impressive! It works, thanks!
