Support for Flash Attention 3 #38972
Conversation
Implements the forward pass and tests for Flash Attention 3 (https://github.com/Dao-AILab/flash-attention/commits/main/hopper).

- Includes checks for dropout > 0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3`. Dropout will likely be supported soon, so this check will need to be updated, along with `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` branch.
- An example Llama implementation is included in `modeling_llama.py`, but other models would still need to be updated.
- Based on huggingface#36190, which has model implementations and examples that could be merged.
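For reference, a minimal usage sketch once this is merged (the model id is only an example; assumes an FA3-capable Hopper GPU with the FA3 "hopper" build installed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # example; the PR ships a Llama reference implementation
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",  # routed through _flash_attention_forward
).to("cuda")

inputs = tokenizer("Hello from FA3", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```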
thanks for taking the time!
- `_prepare_flash_attention_from_position_ids` -> `prepare_fa2_from_position_ids`
- Remove bettertransformer check in Flash Attention 3
- Merge tests
- Add licensing
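For context on the position-ids helper above: for packed sequences it has to derive the varlen metadata (cumulative sequence lengths and max length) that FA2/FA3 expect. A rough sketch of that idea, with a hypothetical helper name rather than the actual transformers implementation:

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    """Derive varlen-attention metadata from packed position ids.

    Assumes position_ids of shape (1, total_tokens) where each new sequence
    in the packed batch restarts at 0.
    """
    pos = position_ids.flatten()
    seq_starts = torch.nonzero(pos == 0, as_tuple=True)[0].to(torch.int32)
    total = torch.tensor([pos.numel()], dtype=torch.int32, device=pos.device)
    cu_seqlens = torch.cat([seq_starts, total])  # e.g. [0, 5, 9, 12]
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen
```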
@ArthurZucker all comments resolved.

Re: @tridao, you mentioned dropout may be supported (pytorch/pytorch#148891 (comment)); if I could be pinged when that's done, I can submit a new PR.
Just deprecate the one method you are renaming and good to go!
Thanks a lot for the contribution! 🤗
@EduardDurech thank you for the contribution! I'm trying to use it with FSDP2, but get this error:
It's triggered from here. Any ideas how to fix this? Is it even possible?
Feel free to ignore my previous question. I missed that my tp plan uses
Haven't tested with FSDP2, but glad you got it sorted out 😃
@1ytic We want it to work with
@1ytic @ArthurZucker this shouldn't be too difficult to fix. I won't have the time, but if anyone wants to pick it up, it seems one would need to:

This is more a low-level flash_attn_3 and PyTorch thing, but it seems possible to patch in Transformers (rough sketch below).
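If someone picks this up, here is a rough sketch of the kind of patch being suggested (the op name, signature, and tuple handling are hypothetical, not what FA3 or Transformers actually ship):

```python
import torch
from flash_attn_interface import flash_attn_func as flash_attn_3_func  # FA3 "hopper" build

# Hypothetical namespace/op name, for illustration only: wrap the FA3 forward as a
# registered custom op so FSDP2/DTensor tracing has a schema plus a fake (meta) kernel.
@torch.library.custom_op("fa3_patch::fwd", mutates_args=())
def fa3_fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
    out = flash_attn_3_func(q, k, v, causal=is_causal)
    # some FA3 versions return (out, softmax_lse); keep only the attention output
    return out[0] if isinstance(out, tuple) else out

@fa3_fwd.register_fake
def _(q, k, v, is_causal=False):
    # shape/dtype-only implementation, used on fake/meta tensors during tracing
    return torch.empty_like(q)
```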
btw, anyone using Ascend NPU, see #39166, thanks @FightingZhen
Following up, @1ytic: what PyTorch version are you using? Maybe try the first point, it may be enough:

```python
import torch
# imported so the FA3 extension (and its ops) are loaded
from flash_attn_interface import flash_attn_func as flash_attn_3_func


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _fake_fa3(q, k, v, *, is_causal=False):
    # fake (meta) kernel: only shapes/dtypes matter, never runs on real data
    B, S, H, D = q.shape
    return torch.empty_like(q), q.new_empty((B, S, H))
```
Maybe @a-r-r-o-w could update
I'm not sure of the approach to follow since I haven't tried FA3 with PyTorch TP and DTensor. I think what @EduardDurech mentioned in his comment sounds good. We might not need anything DTensor-specific here; if my memory serves right from similar tests with SageAttention, just the meta registration may allow it to work with FSDP2. Fake/meta registrations should probably live within
Agree, it should be done on the flash_attn side.
Yes, it works. Just for context, I tried to use NeMo-RL for a Qwen3 model with this tp plan. But with flash_attention_3 I changed it to this:

```python
# standard parallel styles; RotaryEmbedParallel and Qwen3QKNorm are custom
# ParallelStyle subclasses (not part of torch)
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel

base_model_tp_plan = {
    "lm_head": ColwiseParallel(
        input_layouts=Shard(1),
        output_layouts=Shard(-1),
        use_local_output=False,
    ),
    "model.embed_tokens": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    "model.rotary_emb": RotaryEmbedParallel(use_local_output=True),
    "model.norm": SequenceParallel(),
    "model.layers.*.input_layernorm": SequenceParallel(),
    "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False),
    "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False),
    "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=True),
    "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
    "model.layers.*.self_attn.q_norm": Qwen3QKNorm(use_local_output=True),
    "model.layers.*.self_attn.k_norm": Qwen3QKNorm(use_local_output=True),
    "model.layers.*.post_attention_layernorm": SequenceParallel(),
    "model.layers.*.mlp.up_proj": ColwiseParallel(),
    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
    "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
}
```
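For anyone trying to reproduce this, a sketch of how a plan like the above is typically applied with PyTorch TP (assumes `torch.distributed` is already initialized and `model` is the Hugging Face Qwen3 model; NeMo-RL wires this up differently):

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# one-dimensional tensor-parallel mesh over all ranks
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))

# expand the "model.layers.*" wildcard entries per decoder layer
layer_plan = {
    key.removeprefix("model.layers.*."): style
    for key, style in base_model_tp_plan.items()
    if key.startswith("model.layers.*.")
}
for layer in model.model.layers:
    parallelize_module(layer, tp_mesh, layer_plan)

# apply the remaining (non-wildcard) entries relative to the top-level model
top_level_plan = {k: v for k, v in base_model_tp_plan.items() if "*" not in k}
parallelize_module(model, tp_mesh, top_level_plan)
```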
Can probably be added easily to the flash attention kernel on the Hub if you need it fast: https://huggingface.co/kernels-community/flash-attn3, if you want to open a PR there!
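For reference, a rough sketch of pulling that Hub kernel with the `kernels` library (assuming the repo exposes a `flash_attn_func`-style entry point; check the repo for the exact API):

```python
import torch
from kernels import get_kernel

# downloads/loads the prebuilt FA3 kernel from the Hub
flash_attn3 = get_kernel("kernels-community/flash-attn3")

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

out = flash_attn3.flash_attn_func(q, k, v, causal=True)
# depending on the build, this may return (out, softmax_lse)
if isinstance(out, tuple):
    out = out[0]
```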
Thanks for supporting FA3. I'm now using FA3 with Ulysses sequence parallelism, but the forward logits are NaN most of the time. Could you please check whether FA3 works with Ulysses?
Yep, we have a PR for that: #40412, tell us if it fixes it!
Impressive! It works, thanks!
Supports Flash Attention 3 for `_flash_attention_forward`
Previous: #36190 @ArthurZucker
Parity test for Flash Attention {2,3} based on https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py
Closes #32219, #33373
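For reference, a minimal parity sketch in the spirit of the test above (model id, prompt, and setup are illustrative only, not the PR's actual test):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # example; any FA-compatible checkpoint works
tok = AutoTokenizer.from_pretrained(model_id)
inputs = tok("Flash Attention parity check", return_tensors="pt").to("cuda")

logits = {}
for impl in ("flash_attention_2", "flash_attention_3"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, attn_implementation=impl
    ).to("cuda").eval()
    with torch.no_grad():
        logits[impl] = model(**inputs).logits.float().cpu()
    del model
    torch.cuda.empty_cache()

max_diff = (logits["flash_attention_2"] - logits["flash_attention_3"]).abs().max()
print(f"max |logit diff| = {max_diff:.4e}")
```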