Skip to content

Commit b3f7ce4

Browse files
committed
Re-enable float8 tensor memory
1 parent 701d0a2 commit b3f7ce4

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
120120
return emitError() << "rank must be 2 or 3";
121121
}
122122
auto bitwidth = elementType.getIntOrFloatBitWidth();
123-
if (!enc.getUnpacked() && bitwidth != 16) {
124-
return emitError() << "bitwidth must be 16 for packed tensor memory";
123+
if (!enc.getUnpacked() && bitwidth > 16) {
124+
return emitError() << "bitwidth must be <= 16 for packed tensor memory";
125125
}
126-
if (bitwidth != 16 && bitwidth != 32) {
127-
return emitError() << "bitwidth must be 16 or 32";
126+
if (bitwidth > 32) {
127+
return emitError() << "bitwidth must be <= 32";
128128
}
129129
shape = shape.take_back(2);
130130
allocShape = allocShape.take_back(2);

0 commit comments

Comments
 (0)