[IR] Loosen tensor memory encoding checks added in #7748 #7784
Conversation
Narrowed this down a bit as I think some of the verification failures were genuine.
| return emitError() << "bitwidth must be 16 or 32"; | ||
| if (bitwidth > 32) { | ||
| return emitError() << "bitwidth must be <= 32"; | ||
| } |
Transmitting smaller dtypes through tmem seems to work fine, so removing this restriction.
it should error out ATM in ld/st?
triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
Lines 254 to 259 in dd58234
if (auto attr = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
        memType.getEncoding())) {
  info.blockM = attr.getBlockM();
  info.blockN = attr.getBlockN();
  assert((!attr.getUnpacked() || info.numElementsPer32B <= 2) &&
         "unsupported unpacked layout");
In this case the tmem layout is packed, so that assert doesn't apply.
ah, I see, we can have packed for bitwidth=8
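For reference, a minimal standalone sketch (not the Triton sources; the helper below is hypothetical) of why that assert is a no-op for packed layouts: when getUnpacked() is false, the left-hand side of the || is already true, so numElementsPer32B is never examined even for 8-bit elements.

#include <cassert>
#include <cstdio>

// Hypothetical stand-ins for attr.getUnpacked() and info.numElementsPer32B.
static bool assertCondition(bool unpacked, unsigned numElementsPer32B) {
  return !unpacked || numElementsPer32B <= 2;
}

int main() {
  unsigned bitwidth = 8;
  unsigned numElementsPer32B = 32 / bitwidth; // 4 fp8 elements per 32-bit cell
  std::printf("packed fp8:   %d\n", assertCondition(false, numElementsPer32B)); // holds
  std::printf("unpacked fp8: %d\n", assertCondition(true, numElementsPer32B));  // would fire
  return 0;
}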
Force-pushed from 4cd6aba to b3f7ce4 (Compare)
if (!enc.getUnpacked() && bitwidth > 16) {
  return emitError() << "bitwidth must be <= 16 for packed tensor memory";
This was something I wanted to make more strict regardless. Any chance we could change the use case to leave it as it is?
I'm looking at the use case and it's passing fp8 as the lhs in tmem for a tcgen05_mma so there is no other way to do it. We should just fix load and store if they really are broken.
Not sure about the two changes in the verifier... I did mean to write those. In particular, the bitwidth == 8 case was failing in a test.
Can we then just support bitwidth=8 for the packed case? Otherwise lgtm
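A minimal sketch of what that could look like, assuming the verifier keeps a packed-layout restriction but also admits 8-bit element types; the function and names below are hypothetical, modeled on the emitError-style checks quoted above.

#include <optional>
#include <string>

// Returns an error message on failure, std::nullopt on success.
std::optional<std::string> verifyTmemBitwidth(bool unpacked, unsigned bitwidth) {
  if (bitwidth > 32)
    return "bitwidth must be <= 32";
  // Packed layouts place several elements in one 32-bit tensor-memory cell;
  // the suggestion here is to allow 8 as well as 16.
  if (!unpacked && bitwidth != 8 && bitwidth != 16)
    return "bitwidth must be 8 or 16 for packed tensor memory";
  return std::nullopt;
}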
| return emitError() << "bitwidth must be 16 or 32"; | ||
| if (bitwidth > 32) { | ||
| return emitError() << "bitwidth must be <= 32"; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, I see, we can have packed for bitwidth=8
| return emitError() << "blockM must be 64 or 128"; | ||
| return emitError() << "blockM must be 64 or 128 but got " << blockM; | ||
| } | ||
| if (!llvm::isPowerOf2_32(blockN) || blockN > 512) { |
Unrelated, but I just realised that this should be blockN < 512 * (isUnpacked ? (32 / bitwidth) : 1). I can add it to another PR if you don't want to fix it in this one tho.
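A small worked example of that bound, 512 * (isUnpacked ? (32 / bitwidth) : 1), for a few element widths; this only restates the formula from the comment as an upper bound on blockN and is not the verifier code.

#include <cstdio>

// Hypothetical helper restating the suggested limit on blockN.
static unsigned suggestedMaxBlockN(bool isUnpacked, unsigned bitwidth) {
  return 512 * (isUnpacked ? (32 / bitwidth) : 1);
}

int main() {
  std::printf("packed,   bitwidth=8 : %u\n", suggestedMaxBlockN(false, 8));  // 512
  std::printf("unpacked, bitwidth=16: %u\n", suggestedMaxBlockN(true, 16));  // 1024
  std::printf("unpacked, bitwidth=32: %u\n", suggestedMaxBlockN(true, 32));  // 512
  return 0;
}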
we could probably remove this restriction as it is more a restriction on allocation size
| return emitError() << "blockM must be 64 or 128"; | ||
| return emitError() << "blockM must be 64 or 128 but got " << blockM; | ||
| } | ||
| if (!llvm::isPowerOf2_32(blockN) || blockN > 512) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we could probably remove this restriction as it is more a restriction on allocation size
Commits in this PR
Revert "[LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr ([LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr #7748)"
This reverts commit 40335eb.
Reapply "[LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr ([LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr #7748)"
This reverts commit e6eb871.
Re-enable float8 tensor memory
Make shape errors more informative
Respond to comments
PR chain