Fix: Correct/Improve the triton attention kernel #196
This PR introduces several enhancements to the attention kernel, including the implementation of a backward pass, memory optimization for grouped query attention, and a bug fix.
1. Bug Fix: Incorrect Attention with Query Offset: Fixed a bug where the attention kernel produced incorrect results when the query offset (`start_q`) was non-zero. The kernel's starting loop bound (`lo`) was incorrectly initialized to `start_q`, causing the computation to skip the initial keys in the KV cache (a reference sketch follows this list).
2. GQA Memory Optimization: Previously, the K and V tensors were explicitly expanded using `torch.repeat_interleave`, which materialized large tensors in memory. The kernel now maps each query head to its corresponding KV head by manipulating pointers instead, so no expanded copies are created (sketched after this list).
3. Backward Pass Implementation: Implemented a custom backward pass, making the module fully differentiable and usable for end-to-end model training (see the `autograd.Function` sketch below).
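For reference, here is a minimal pure-PyTorch sketch (not the kernel code itself) of what attention with a non-zero query offset should compute: a query at absolute position `start_q + i` still attends to every cached key at positions `0 .. start_q + i`, which is why the key loop must not start at `start_q`. The function name and shapes are illustrative only.

```python
import torch


def ref_attention_with_offset(q, k, v, start_q: int):
    """Reference (pure PyTorch) attention for a query block starting at
    offset `start_q` inside the KV cache. Shapes are illustrative:
    q is [n_q, d]; k and v are [n_kv, d] with n_kv >= start_q + n_q."""
    n_q, d = q.shape
    n_kv = k.shape[0]
    scores = (q @ k.T) / (d ** 0.5)                    # [n_q, n_kv]
    # Causal mask in absolute positions: query i sits at start_q + i and
    # attends to keys 0 .. start_q + i, so keys before start_q are included.
    q_pos = torch.arange(n_q).unsqueeze(1) + start_q   # [n_q, 1]
    k_pos = torch.arange(n_kv).unsqueeze(0)            # [1, n_kv]
    scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```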
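The grouped-query mapping can be sketched the same way: rather than materializing repeated K/V with `torch.repeat_interleave`, each query head `h` simply reads KV head `h // group_size`, the same index arithmetic the kernel can apply to its K/V base pointers. The function below is illustrative PyTorch, not the Triton implementation.

```python
import torch


def gqa_attention_no_expand(q, k, v):
    """Grouped-query attention without materializing repeated K/V.
    q is [n_q_heads, seq, d]; k and v are [n_kv_heads, seq, d].
    Query head h reads KV head h // group_size, so no expanded copies
    of K or V are ever created."""
    n_q_heads, _, d = q.shape
    n_kv_heads = k.shape[0]
    group_size = n_q_heads // n_kv_heads
    out = torch.empty_like(q)
    for h in range(n_q_heads):
        kv_h = h // group_size                         # no repeat_interleave
        scores = (q[h] @ k[kv_h].T) / (d ** 0.5)       # [seq, seq]
        out[h] = torch.softmax(scores, dim=-1) @ v[kv_h]
    return out
```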
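The backward pass is wired up through a `torch.autograd.Function`. The sketch below shows only the structure (what `forward` saves and which gradients `backward` returns), with plain PyTorch ops standing in for the Triton forward and backward kernels; the class and variable names are placeholders, not the PR's actual symbols.

```python
import torch


class AttentionFn(torch.autograd.Function):
    """Structural sketch of the custom autograd wrapper. Plain PyTorch ops
    stand in for the Triton forward/backward kernels."""

    @staticmethod
    def forward(ctx, q, k, v):
        scale = q.shape[-1] ** -0.5
        p = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        out = p @ v
        ctx.save_for_backward(q, k, v, p)   # tensors the backward pass needs
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, p = ctx.saved_tensors
        grad_v = p.transpose(-2, -1) @ grad_out
        grad_p = grad_out @ v.transpose(-2, -1)
        # Softmax backward: dS = P * (dP - sum(dP * P) over the last dim).
        grad_s = p * (grad_p - (grad_p * p).sum(dim=-1, keepdim=True))
        grad_q = grad_s @ k * ctx.scale
        grad_k = grad_s.transpose(-2, -1) @ q * ctx.scale
        return grad_q, grad_k, grad_v
```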
Testing: All `pytest` cases have been updated to validate both the forward and backward passes against a reference PyTorch implementation.
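As a rough illustration of the test shape only, the case below compares forward outputs and input gradients against `torch.nn.functional.scaled_dot_product_attention`. It reuses the `AttentionFn` sketch above as a stand-in for the Triton module, so the names, shapes, and tolerances are assumptions rather than the PR's actual test code.

```python
import pytest
import torch
import torch.nn.functional as F


@pytest.mark.parametrize("seq_len", [16, 64])
def test_attention_matches_reference(seq_len):
    torch.manual_seed(0)
    batch, heads, d = 2, 4, 32
    q = torch.randn(batch, heads, seq_len, d, requires_grad=True)
    k = torch.randn(batch, heads, seq_len, d, requires_grad=True)
    v = torch.randn(batch, heads, seq_len, d, requires_grad=True)
    q_ref, k_ref, v_ref = (t.detach().clone().requires_grad_(True) for t in (q, k, v))

    # Forward: custom path vs. PyTorch reference (both non-causal here).
    out = AttentionFn.apply(q, k, v)
    ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref)
    torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)

    # Backward: push the same upstream gradient through both paths.
    grad = torch.randn_like(out)
    out.backward(grad)
    ref.backward(grad)
    torch.testing.assert_close(q.grad, q_ref.grad, atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(k.grad, k_ref.grad, atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(v.grad, v_ref.grad, atol=1e-5, rtol=1e-5)
```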