-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][arith.constant]Fix element type of the dense attributes in target attributes to be consistent with result type in LLVM::detail::oneToOneRewrite() #149787
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Mengmeng Sun (MengmSun) ChangesAs I described in [MLIR][Vector]Add constraints to vector.shape_cast(constant) -> constant, we have a case as below after ...
%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
%8 = vector.shape_cast %4 : vector<192xi8> to vector<1x192xi8>
%10 = vector.extract %8[0] : vector<192xi8> from vector<1x192xi8>
... Our next pass is mlir::DenseElementsAttr mlir::DenseElementsAttr::reshape(mlir::ShapedType): Assertion `newType.getElementType() == curType.getElementType() && "expected the same element type"' failed. That's because Before we try to just WAR in So in this MR, we tried to fix element type of the dense attributes in target attributes to be consistent with result type. In this fix, the Full diff: https://github.com/llvm/llvm-project/pull/149787.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c5f72f7e10b8c..329703e4f054d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -331,10 +331,35 @@ LogicalResult LLVM::detail::oneToOneRewrite(
return failure();
}
+ // If the targetAttrs contains DenseElementsAttr,
+ // and the element type of the DenseElementsAttr and result type is
+ // inconsistent after the conversion of result types, we need to convert the
+ // element type of the DenseElementsAttr to the target type by creating a new
+ // DenseElementsAttr with the converted element type, and use the new
+ // DenseElementsAttr to replace the old one in the targetAttrs
+ SmallVector<NamedAttribute> convertedAttrs;
+ for (auto attr : targetAttrs) {
+ if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
+ VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
+ if (vectorType) {
+ auto convertedElementType =
+ typeConverter.convertType(vectorType.getElementType());
+ VectorType convertedVectorType =
+ VectorType::get(vectorType.getShape(), convertedElementType,
+ vectorType.getScalableDims());
+ convertedAttrs.emplace_back(
+ attr.getName(), DenseElementsAttr::getFromRawBuffer(
+ convertedVectorType, denseAttr.getRawData()));
+ }
+ } else {
+ convertedAttrs.push_back(attr);
+ }
+ }
+
// Create the operation through state since we don't know its C++ type.
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
- resultTypes, targetAttrs);
+ resultTypes, convertedAttrs);
setNativeProperties(newOp, overflowFlags);
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 83bdbe1f67118..299cc32351bdb 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () {
// CHECK-LABEL: @index_vector
func.func @index_vector(%arg0: vector<4xindex>) {
- // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
+ // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64>
%0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64>
%1 = arith.addi %arg0, %0 : vector<4xindex>
@@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) {
// -----
+// CHECK-LABEL: @f8E4M3FN_vector
+func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> {
+ // CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8>
+ %0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN>
+ // CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+ %1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN>
+ // CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+ %2 = arith.addf %0, %1 : vector<4xf8E4M3FN>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN>
+ // CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN>
+ func.return %2 : vector<4xf8E4M3FN>
+}
+
+// -----
+
// CHECK-LABEL: @bitcast_1d
func.func @bitcast_1d(%arg0: vector<2xf32>) {
// CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a general comment, I don't think mixing llvm.mlir.constant
and operations like vector.shape_cast
makes sense, and since llvm.mlir.constant
allows for a type/attribute mismatch (to my knowledge), there might not be anything to fix here.
convertedVectorType, denseAttr.getRawData())); | ||
} | ||
} else { | ||
convertedAttrs.push_back(attr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this also be done for scalar attributes? Ex. a index constant becoming an i64 constant?
I don't think mixing is illegal. // shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
return splatAttr.reshape(getType()); So I think modifying this part is more reasonable. It's not maintaining more codes for incorrect codes. Instead, it categorizes the possible cases that may be encountered with |
Well, no, the point I'm making is that LLVM lowerings should be done in one shot without intervening canonicalizations, because llvm.mlir.constant is one of the few cases where this sort of type/attribute mismatch happens ... and is expected to only be an argument to LLVM-level operations |
I see. However the mixing sequence is just lowered from like %0 = arith.constant dense<0.> : vector<192xf8E4M3FN>
our defined op %0, %args in one pass That means, even if we do %4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
%8 = vector.shape_cast %4 : vector<192xi8> to vector<1x192xi8> The same problem will expose. We have tried locally. Actually under current mechanism any subsequent operations on it will expose this problem. |
I think the lowering and dialect mix-up is orthogonal to the actual problem here. For me, the question is: why the attribute type is not being converted when converting an
instead of:
? |
There are cases where there is not really a matching attribute type for the result type of a constant. For example, a constant that defines an llvm.array uses an attribute of tensor type. For the 8-bit floats the mismatch with the i8 result type is intended. In many other cases there may not be a specific reason and the verifier could possibly be further strengthened. @newling landed an improvement today (#148975). |
Ok, I can see the use cases now. Thanks!
Given that the type mismatch in |
That is, I'm claiming that I'm willing to be convinced that I'm wrong though |
My preference would be to make the element type change illegal in the constant op, i.e. that
is illegal (along the lines of this PR). One downside would be the textual form numerical values would no longer make sense. I don't have the context to know if this is a feasible restriction to place on the llvm constant op for all users.
This suggests that the problem isn't specific to ShapeCast::fold, but could be a problem for any fold operation. Is that correct @MengmSun ?
@krzysz00 it is the folder. The original solution in #147691 actually does a good job -- it uses the element type of the original attribute instead of the result. But maybe not scalable to all fold methods. In my opinion, if the API for reshaping the attribute was rather
instead of
As outlined in #149947 then this issue would not have arisen in #147691 |
Ok, higher-level question: Why do you need the lowering to your op to happen alongside the to-llvm rewrites? Can't you lower to shape_cast before you lower to LLVM? |
I agree. But I think it should be changed to use the result type instead of the original attribute element type because the type conversion is intended. @newling
Good point. Anyway after looking at these comments I think we still should solve this problem in #147691 by updating the |
Good point. We'll think about this deeply. |
@gysit , could you share a code example? I'm having difficult time visualising that. |
Sure this is an example for the array case:
For string we use a string attribute though:
|
Thanks! Diving into this a bit deeper, it sounds like it would be ok to require the numeric element types to match.
Here, the input has non-numeric element type, so that could be an exception. Does this make sense? |
How do we represent the constant attribute for a type that is supported in MLIR but is not in LLVM (e.g., f8E4M3FN)? |
I think there is the floating point exception that @dcaballe mentions. Apart from that, enforcing matching element types seems reasonable (however, I would not claim I understand all use cases). |
The way those unsupported float types work is that we make the result type of the |
We could probably change the behavior for these unsupported floats to do the bitcasts during convert-to-llvm, but I've got the rough sense that this has a readability impact ... though I wouldn't be too opposed to that change (which is, IIRC, this PR) |
1d5572a
to
38b4227
Compare
I have updated my local changes to this branch. 7 UTs failed, and I have attached my local log below ********************
Testing: 0.. 10.. 20.. 30.. 40.. 50.. 60.. 70.. 80.. 90..
********************
Failed Tests (7):
MLIR :: Conversion/FuncToLLVM/func-to-llvm.mlir
MLIR :: Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
MLIR :: Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
MLIR :: Conversion/GPUToROCDL/gpu-to-rocdl.mlir
MLIR :: Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
MLIR :: Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
MLIR :: Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
Testing Time: 57.32s
Total Discovered Tests: 3332
Unsupported : 458 (13.75%)
Passed : 2866 (86.01%)
Expectedly Failed: 1 (0.03%)
Failed : 7 (0.21%)
These all about the conversion between index->i64. I'm confused because it looks like many places intend to treat @dcaballe @banach-space thx. |
Yes, What I can see in the logs makes sense:
|
If you're fixing the element type, you'll want to ensure the |
As I described in [MLIR][Vector]Add constraints to vector.shape_cast(constant) -> constant, we have a case as below after
convert-to-llvm
.Our next pass is
Canonocalizer
. And after #133988 moved the canonicalizer to a folder merged seve months ago we met the problem in theCanonicalizer
pass:That's because
llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
lowered fromarith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>
. The element typef8E4M3FN
of the result type is converted toi8
withtypeConverter
. However, the element type of the dense attributes has not been converted. And the target attributes kept the same and passed to the new replaced opllvm.mlir.constant
. Then our problem exposed.Before we try to just WAR in
ShapeCastOp::fold()
. This can solve our problem. However as @dcaballe and @banach-space pointed out it's better to solve problems on root instead of maintaining other incorrect code. Theoretically, the target attributes inLLVM::detail::oneToOneRewrite()
maybe the same as source attributes as the current implementation, but not for all cases.So in this MR, we tried to fix element type of the dense attributes in target attributes to be consistent with result type. In this fix, the
arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>
will be converted tollvm.mlir.constant(dense<0> : vector<192xi8>) : vector<192xi8>
. It will not cause any accuracy loss of the dense value and as my UT shows it just reinterprets.