
Conversation

@peterbell10 (Contributor) commented on Aug 5, 2025:

Commits in this PR

  1. Revert "[LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr ([LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr #7748)"

    This reverts commit 40335eb.

  2. Reapply "[LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr ([LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr #7748)"

    This reverts commit e6eb871.

  3. Re-enable float8 tensor memory

  4. Make shape errors more informative

  5. Respond to comments

PR chain

  1. 👉 [IR] Loosen tensor memory encoding checks added in #7748 (#7784) 👈 YOU ARE HERE

…tr (#7748)"

This reverts commit 40335eb.

git-pr-chain: revert_layouts_implement_tolinearlayout__9b92
@peterbell10 (Contributor, Author) commented:

Narrowed this down a bit as I think some of the verification failures were genuine.

return emitError() << "bitwidth must be 16 or 32";
if (bitwidth > 32) {
return emitError() << "bitwidth must be <= 32";
}
@peterbell10 (Contributor, Author) commented:

Transmitting smaller dtypes through tmem seems to work fine, so removing this restriction.
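
For reference, a minimal sketch of what the relaxation amounts to (the helper names are hypothetical; the behaviour follows the diff above):

```cpp
// Old check: only 16- and 32-bit element types passed the verifier.
// Relaxed check: anything up to 32 bits passes, so fp8 (bitwidth=8)
// tensor memory is no longer rejected at this point.
bool oldCheckAccepts(unsigned bitwidth) {
  return bitwidth == 16 || bitwidth == 32;
}
bool newCheckAccepts(unsigned bitwidth) { return bitwidth <= 32; }
```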

A contributor replied:

Shouldn't this currently error out in ld/st?

```cpp
if (auto attr = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
        memType.getEncoding())) {
  info.blockM = attr.getBlockM();
  info.blockN = attr.getBlockN();
  assert((!attr.getUnpacked() || info.numElementsPer32B <= 2) &&
         "unsupported unpacked layout");
```

@peterbell10 (Contributor, Author) replied:

In this case the tmem layout is packed, so that assert doesn't apply.

The contributor replied:

Ah, I see: we can have a packed layout for bitwidth=8.
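
A minimal sketch of the packing arithmetic under discussion, assuming `numElementsPer32B = 32 / bitwidth` (a hypothetical helper mirroring the assert quoted above, not the actual Triton code):

```cpp
// Unpacked layouts are only supported when at most two elements map
// to one 32-bit tmem cell (i.e. bitwidth >= 16); a packed bitwidth=8
// layout never reaches the assert, which is why it is accepted here.
bool ldStAssertHolds(unsigned bitwidth, bool unpacked) {
  unsigned numElementsPer32B = 32 / bitwidth;
  return !unpacked || numElementsPer32B <= 2;
}
```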

@peterbell10 changed the title from Revert "[LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr (#7748)" to [IR] Loosen tensor memory encoding checks added in #7748 on Aug 5, 2025.
@peterbell10 enabled auto-merge (squash) on August 5, 2025 at 20:58.
@peterbell10 force-pushed the pb/pr-chain/revert_layouts_implement_tolinearlayout__9b92 branch from 4cd6aba to b3f7ce4 on August 5, 2025 at 21:24.
Comment on lines +123 to +124
```diff
+  if (!enc.getUnpacked() && bitwidth > 16) {
+    return emitError() << "bitwidth must be <= 16 for packed tensor memory";
```
A contributor commented:

This was something I wanted to make more strict regardless. Any chance we could change the use case to leave it as it is?

@peterbell10 (Contributor, Author) replied on Aug 6, 2025:

I'm looking at the use case, and it's passing fp8 as the lhs in tmem for a tcgen05_mma, so there is no other way to do it. We should just fix load and store if they really are broken.

@lezcano (Contributor) left a comment:

Not sure about the two changes in the verifier... I did mean to write those. In particular, the bitwidth == 8 case was failing in a test.

@lezcano (Contributor) left another comment:

Can we then just support bitwidth=8 for the packed case? Otherwise LGTM.
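
A minimal sketch of what restricting the packed case to 8- and 16-bit types could look like (hypothetical; the check actually merged, shown in the diff above, uses `bitwidth > 16`):

```cpp
// Hypothetical tightening discussed above: packed tensor memory would
// admit exactly bitwidth 8 or 16 rather than anything <= 16.
if (!enc.getUnpacked() && bitwidth != 8 && bitwidth != 16) {
  return emitError() << "packed tensor memory requires bitwidth 8 or 16"
                     << " but got " << bitwidth;
}
```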

return emitError() << "bitwidth must be 16 or 32";
if (bitwidth > 32) {
return emitError() << "bitwidth must be <= 32";
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I see, we can have packed for bitwidth=8

return emitError() << "blockM must be 64 or 128";
return emitError() << "blockM must be 64 or 128 but got " << blockM;
}
if (!llvm::isPowerOf2_32(blockN) || blockN > 512) {
A contributor commented:

Unrelated, but I just realised that this should be `blockN < 512 * (isUnpacked ? (32 / bitwidth) : 1)`. I can add it to another PR if you don't want to fix it in this one, though.
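
A sketch of that suggested bound, keeping the inequality exactly as written in the comment above (`isUnpacked` is assumed to come from the encoding, e.g. `enc.getUnpacked()`):

```cpp
// Unpacked layouts occupy a full 32-bit cell per element, so the
// proposed limit on blockN scales by 32 / bitwidth in that case.
unsigned maxBlockN = 512 * (isUnpacked ? (32 / bitwidth) : 1);
if (!llvm::isPowerOf2_32(blockN) || !(blockN < maxBlockN)) {
  return emitError() << "blockN must be a power of 2 less than " << maxBlockN
                     << " but got " << blockN;
}
```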

A collaborator replied:

We could probably remove this restriction, as it is more a restriction on allocation size.

return emitError() << "blockM must be 64 or 128";
return emitError() << "blockM must be 64 or 128 but got " << blockM;
}
if (!llvm::isPowerOf2_32(blockN) || blockN > 512) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could probably remove this restriction as it is more a restriction on allocation size

@peterbell10 merged commit bbdbbd1 into main on Aug 6, 2025.
9 checks passed
@peterbell10 deleted the pb/pr-chain/revert_layouts_implement_tolinearlayout__9b92 branch on August 6, 2025 at 02:36.