Fix: several bugs/issues with trtllm-gen attention kernels. #2062
Conversation
Walkthrough
Updates the TRTLLM FMHA artifact path and checksum constants; extends the FMHA kernel hash encoding to include a new sparseMla flag, with an adjusted bit-field layout and stricter head-dimension checks; and adds paged-KV / sparse-related fields to KernelParams, with zero-initialization and a log2 computation for numTokensPerPage.
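The log2 computation only works if the page size is a power of two, which is presumably why the stricter check exists. A minimal Python sketch of the idea (the actual implementation is in the C++ kernel-params code; the name below is illustrative):

```python
def log2_num_tokens_per_page(num_tokens_per_page: int) -> int:
    """Return log2 of the page size, rejecting values that are not powers of 2.

    Storing the exponent instead of the raw page size lets the kernel turn
    page-index arithmetic (divide/modulo by the page size) into shifts and masks.
    """
    # A positive power of 2 has exactly one bit set: n & (n - 1) == 0.
    if num_tokens_per_page <= 0 or num_tokens_per_page & (num_tokens_per_page - 1):
        raise ValueError(f"numTokensPerPage must be a power of 2, got {num_tokens_per_page}")
    return num_tokens_per_page.bit_length() - 1
```

For example, log2_num_tokens_per_page(64) returns 6, while 48 raises.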
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Runner as Runner / Dispatch
participant Selector as Kernel Selector
participant Meta as KernelMeta
participant Loader as Kernel Loader
Note over Runner,Selector: Build selection key from runtime params
Runner->>Selector: hashFromRunnerParams(params, /* sparseMla */ false)
Selector->>Meta: select candidate KernelMeta
Note right of Meta: KernelMeta includes mSparseMla
Selector->>Loader: hashID(kernelMeta, sparseMla=Meta.mSparseMla)
Loader->>Loader: assemble 64-bit hash (includes sparseMla bit, log2(numTokensPerPage))
Loader->>Runner: return selected kernel / load artifacts (uses updated artifact checksum)
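To make the "assemble 64-bit hash" step concrete, here is a hedged sketch of how such a selection key could pack its fields. The field set, widths, and offsets below are illustrative assumptions, not the actual bit layout of the C++ hashID:

```python
def hash_id(dtype_q: int, dtype_kv: int, head_dim_qk: int, head_dim_v: int,
            num_tokens_per_page: int, sparse_mla: bool) -> int:
    """Illustrative 64-bit selection key; widths and offsets are assumptions."""
    # Stricter head-dimension checks: reject anything the bit-fields cannot hold.
    assert 0 < head_dim_qk < (1 << 11) and 0 < head_dim_v < (1 << 11)
    assert num_tokens_per_page > 0 and num_tokens_per_page & (num_tokens_per_page - 1) == 0
    h = dtype_q & 0xF                                  # bits 0-3:   query dtype
    h |= (dtype_kv & 0xF) << 4                         # bits 4-7:   KV dtype
    h |= head_dim_qk << 8                              # bits 8-18:  QK head dim
    h |= head_dim_v << 19                              # bits 19-29: V head dim
    h |= (num_tokens_per_page.bit_length() - 1) << 30  # bits 30-35: log2(page size)
    h |= int(sparse_mla) << 36                         # bit 36:     new sparseMla flag
    return h
```

Packing a boolean like sparseMla into a previously unused bit only keeps existing keys stable if the old fields retain their offsets, which is presumably why the walkthrough notes an adjusted bit-field layout.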
Estimated code review effort: 🎯 4 (Complex), ⏱️ ~45 minutes
Pre-merge checks: ✅ Passed (3 passed)
Code Review
This pull request updates artifact hashes and refines the kernel selection logic for trtllm-gen attention kernels. Key changes include adding a sparseMla parameter to the hashID function, adjusting bit shifts for head dimensions, and enforcing that numTokensPerPage must be a power of 2. New members have been added to the KernelParams struct to support these changes, and the struct is now explicitly zero-initialized using memset for improved safety. These modifications appear to address the reported CUDA launch errors and masking bugs, enhancing the robustness and correctness of the attention kernels.
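The memset-based zero-initialization mentioned above can be mimicked in Python with ctypes; the struct below is a hypothetical three-field stand-in for the much larger C++ KernelParams, just to show the pattern:

```python
import ctypes

class KernelParams(ctypes.Structure):
    # Hypothetical subset of fields; the real C++ struct has many more.
    _fields_ = [
        ("mNumTokensPerPageLog2", ctypes.c_int32),
        ("mMaxNumPagesPerSeq", ctypes.c_int32),
        ("mSparseMla", ctypes.c_bool),
    ]

params = KernelParams()
# Explicit zero-fill of the whole struct, mirroring the memset added in the PR,
# so newly added fields can never be read uninitialized.
ctypes.memset(ctypes.addressof(params), 0, ctypes.sizeof(params))
```

The appeal of a whole-struct memset is that it stays correct as fields are added, instead of relying on each new member getting its own initializer.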
@PerkzZheng would you mind rebasing onto the main branch? It seems there are some merge conflicts.
Signed-off-by: Perkz Zheng <[email protected]>
Force-pushed from 8dc0a1b to e4d7f46.
It was rebased onto the wrong remote. It should be good now. Thanks.
pavanimajety left a comment
Thanks for the PR
nvmbreughe left a comment
LGTM. Just wondering: for what config did we get failures without this fix? I think it would be good to have a test. I can add it after this PR.
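A regression test along those lines might look like the following skeleton, where run_trtllm_gen_attention is a hypothetical helper standing in for the real flashinfer entry point and a reference implementation:

```python
import pytest

from my_test_utils import run_trtllm_gen_attention  # hypothetical helper module


@pytest.mark.parametrize("num_tokens_per_page", [16, 32, 64])
@pytest.mark.parametrize("sparse_mla", [False, True])
def test_trtllm_gen_attention(num_tokens_per_page, sparse_mla):
    # The helper is assumed to run the trtllm-gen kernel and a reference
    # implementation, returning both outputs as tensors.
    out, ref = run_trtllm_gen_attention(
        num_tokens_per_page=num_tokens_per_page, sparse_mla=sparse_mla
    )
    assert (out - ref).abs().max().item() < 2e-2  # loose tolerance for fp16/bf16
```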
flashinfer/artifacts.py (outdated)
 )
 TRTLLM_GEN_BMM: str = (
-    "46ccf0492e3ed10135c2861a4f4ef9bb45846610f9a9d2ccaf2d5bf01d2006fd"
+    "1ebace613389a4f2e10b14315da5d522642c5dcaae23f01213d56c59068f148b"
Why do we need to update the BMM hash in this PR?
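For context on what these constants do: they are sha256 digests used to verify downloaded kernel artifacts, so any rebuilt artifact needs its pinned checksum refreshed. A minimal sketch of that kind of check (not the actual artifacts.py download logic):

```python
import hashlib

# Pinned digest taken from the diff above.
TRTLLM_GEN_BMM_SHA256 = "1ebace613389a4f2e10b14315da5d522642c5dcaae23f01213d56c59068f148b"


def verify_artifact(path: str, expected_sha256: str = TRTLLM_GEN_BMM_SHA256) -> None:
    """Recompute the file's sha256 and fail loudly on a mismatch."""
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    if digest != expected_sha256:
        raise RuntimeError(f"checksum mismatch for {path}: {digest} != {expected_sha256}")
```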
/bot run

[FAILED] Pipeline #38107936: 7/17 passed

/bot run

[FAILED] Pipeline #38135771: 14/17 passed
📌 Description
This MR fixes:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- All tests are passing (unittest, etc.).

Reviewer Notes