Conversation

@ArthurZucker (Collaborator) commented Jul 17, 2025

What does this PR do?

pip install transformers[torch] kernels

from transformers import AutoModelForCausalLM, AutoTokenizer
import time

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    device_map="auto",
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
).eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(
    ["Hello, how are you?", "is this life?"],
    padding=True,
    padding_side="left",
    return_tensors="pt",
).to(model.device)


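# Baseline run: generate with the flash_attention_2 implementation loaded above.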
start = time.time()
outputs = model.generate(**inputs, max_new_tokens=50)
print(f"Generation time: {time.time() - start:.2f} seconds")
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

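# Swap the attention implementation for the flash-attn3 kernel from the Hub, then time generation again.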
model.set_attn_implementation("kernels-community/flash-attn3")
start = time.time()
outputs = model.generate(**inputs, max_new_tokens=50)
print(f"Generation time: {time.time() - start:.2f} seconds")
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kadirnar (Contributor)

transformers install:

uv pip install git+https://github.com/huggingface/transformers.git@kernels-flash-attn

env:

- `transformers` version: 4.54.0.dev0
- Platform: Linux-6.11.0-29-generic-x86_64-with-glibc2.40
- Python version: 3.10.16
- Huggingface_hub version: 0.33.4
- Safetensors version: 0.5.3
- Accelerate version: 1.9.0
- Accelerate config:    - compute_environment: LOCAL_MACHINE

Error Message:

[rank0]: ValueError: Specified `attn_implementation="https://huggingface.co/kernels-community/flash-attn3:flash_attention"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation), `"attn_implementation=flash_attention_3"` (implementation using flash attention 3), `"attn_implementation=flash_attention_2"` (implementation using flash attention 2), `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention), `"attn_implementation=flex_attention"` (implementation using torch's flex_attention).

@ArthurZucker (Collaborator, Author)

You can't pass the full URL! You need to pass kernels-community/flash-attn3:flash_attention
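
For reference, a minimal sketch of the corrected call, assuming this PR's branch is installed (the checkpoint is just the one from the PR description):

from transformers import AutoModelForCausalLM

# Sketch only: pass the Hub repo plus the kernel function name, not the full URL.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    device_map="auto",
    torch_dtype="auto",
    attn_implementation="kernels-community/flash-attn3:flash_attention",
)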

@kadirnar (Contributor)

You can't pass the full URL! You need to pass kernels-community/flash-attn3:flash_attention

I also thought passing the full URL seemed odd, but since you shared it that way on Twitter, I gave it a try.
https://x.com/art_zucker/status/1945821883858915695

New Error Message:

[rank1]:     cache_position = torch.arange(
[rank1]: RuntimeError: CUDA error: device-side assert triggered
[rank1]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank1]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank1]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [...], thread: [...] Assertion `srcIndex < srcSelectDimSize` failed. (repeated for many blocks/threads)

Should I wait for you to finish your development?

@ArthurZucker (Collaborator, Author)

Ah, that's weird. Can you share a small reproducer?

@ArthurZucker (Collaborator, Author)

run-slow: llama,mistral,gemma

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/gemma', 'models/llama', 'models/mistral']
quantizations: [] ...

@kadirnar (Contributor)

@ArthurZucker I tried it with a different LLM and it worked. It seems the dataset I was feeding the Qwen model was faulty. I will fix that and report back on the performance.
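
For anyone else hitting the `srcIndex < srcSelectDimSize` assert above: it typically means an index passed to an embedding lookup is out of range, e.g. token IDs outside the model's vocabulary. A minimal sanity-check sketch (the checkpoint name is just a placeholder, not the exact setup used here):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

batch = tokenizer(["one sample from the dataset"], return_tensors="pt")
vocab_size = model.get_input_embeddings().num_embeddings

# IDs outside [0, vocab_size) are what trigger the device-side assert on CUDA.
assert int(batch["input_ids"].min()) >= 0
assert int(batch["input_ids"].max()) < vocab_size

Running with CUDA_LAUNCH_BLOCKING=1 (as the traceback suggests) or on CPU also makes the failing lookup easier to pinpoint.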

@ArthurZucker (Collaborator, Author)

Thanks @kadirnar !

@ArthurZucker (Collaborator, Author)

TRANSFORMERS_TEST_DEVICE="mps" RUN_SLOW=1 pytest tests/models/llama/test_modeling_llama.py -k kernels_mps -s

Added a new test for MPS!

@kadirnar (Contributor) commented Aug 2, 2025

@ArthurZucker This method only supports LLM models, right? What should we do to add kernel support for speech models?

Example: https://huggingface.co/docs/transformers/main/en/model_doc/dia

@ArthurZucker (Collaborator, Author)

This should be supported by all models, as long as they have the ALL_ATTENTION_FUNCTIONS refactor, and you can set the attention implementation for sub-modules!
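
A hedged sketch of what that could look like for a speech model such as Dia, assuming its attention goes through the same registry (the checkpoint id and sub-module name below are illustrative, not verified in this thread):

from transformers import AutoModel

model = AutoModel.from_pretrained("nari-labs/Dia-1.6B")  # illustrative checkpoint id

# Same call as for LLMs: swap in a kernel from the Hub.
model.set_attn_implementation("kernels-community/flash-attn3")

# Per the comment above, the implementation can also be set on a sub-module, e.g.:
# model.decoder.set_attn_implementation("kernels-community/flash-attn3")  # attribute name is hypothetical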
