
[mlir][arith.constant] Fix the element type of dense attributes in target attributes to be consistent with the result type in LLVM::detail::oneToOneRewrite() #149787


Open
wants to merge 4 commits into main

Conversation

MengmSun

As I described in [MLIR][Vector]Add constraints to vector.shape_cast(constant) -> constant, we hit the case below after convert-to-llvm:

...
%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
%8 = vector.shape_cast %4 : vector<192xi8> to vector<1x192xi8>
%10 = vector.extract %8[0] : vector<192xi8> from vector<1x192xi8>
...

Our next pass is the Canonicalizer. After #133988, which moved this canonicalization into a folder and was merged several months ago, we hit the following assertion in the Canonicalizer pass:

mlir::DenseElementsAttr mlir::DenseElementsAttr::reshape(mlir::ShapedType): Assertion `newType.getElementType() == curType.getElementType() && "expected the same element type"' failed.

That's because llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8> is lowered from arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>. The element type f8E4M3FN of the result type is converted to i8 by the typeConverter, but the element type of the dense attribute is not converted: the target attributes are passed unchanged to the replacement llvm.mlir.constant op. That is where the problem surfaces.

We previously tried to just work around this in ShapeCastOp::fold(), which does solve our problem. However, as @dcaballe and @banach-space pointed out, it is better to fix the root cause than to keep patching downstream code. In the current implementation, the target attributes in LLVM::detail::oneToOneRewrite() happen to match the source attributes, but that does not hold in all cases.

So in this PR, we fix the element type of dense attributes in the target attributes to be consistent with the result type. With this fix, arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN> is converted to llvm.mlir.constant(dense<0> : vector<192xi8>) : vector<192xi8>. This causes no accuracy loss for the dense values; as my unit test shows, it is purely a bit reinterpretation.
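The "reinterpretation" claim can be checked by hand: the f8E4M3FN constants in the new test (1.0, 2.0, 3.0, 4.0) encode to exactly the i8 bytes 56, 64, 68, 72 that appear in the converted dense attribute. Below is a minimal sketch of the f8E4M3FN bit layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits); the helper name is hypothetical and this is not the MLIR implementation, just an illustration:

```python
def f8e4m3fn_bits(x: float) -> int:
    """Encode a non-negative, exactly-representable normal value
    as f8E4M3FN bits: 1 sign bit, 4 exponent bits (bias 7),
    3 mantissa bits. Only handles the cases this example needs."""
    if x == 0.0:
        return 0
    # Normalize x to m * 2**e with 1.0 <= m < 2.0.
    e, m = 0, x
    while m >= 2.0:
        m /= 2.0
        e += 1
    while m < 1.0:
        m *= 2.0
        e -= 1
    mantissa = round((m - 1.0) * 8)     # 3 mantissa bits
    return ((e + 7) << 3) | mantissa    # biased exponent, bias = 7

# The constants from the new test case in this PR:
print([f8e4m3fn_bits(v) for v in [1.0, 2.0, 3.0, 4.0]])  # [56, 64, 68, 72]
print(f8e4m3fn_bits(0.0))  # 0, matching dense<0> : vector<4xi8>
```

These are the same byte values the converted llvm.mlir.constant carries, which is why reinterpreting the raw buffer with getFromRawBuffer loses no information.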


@llvmbot
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Mengmeng Sun (MengmSun)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/149787.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+26-1)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+16-1)
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c5f72f7e10b8c..329703e4f054d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -331,10 +331,35 @@ LogicalResult LLVM::detail::oneToOneRewrite(
       return failure();
   }
 
+  // If the targetAttrs contains DenseElementsAttr,
+  // and the element type of the DenseElementsAttr and result type is
+  // inconsistent after the conversion of result types, we need to convert the
+  // element type of the DenseElementsAttr to the target type by creating a new
+  // DenseElementsAttr with the converted element type, and use the new
+  // DenseElementsAttr to replace the old one in the targetAttrs
+  SmallVector<NamedAttribute> convertedAttrs;
+  for (auto attr : targetAttrs) {
+    if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
+      VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
+      if (vectorType) {
+        auto convertedElementType =
+            typeConverter.convertType(vectorType.getElementType());
+        VectorType convertedVectorType =
+            VectorType::get(vectorType.getShape(), convertedElementType,
+                            vectorType.getScalableDims());
+        convertedAttrs.emplace_back(
+            attr.getName(), DenseElementsAttr::getFromRawBuffer(
+                                convertedVectorType, denseAttr.getRawData()));
+      }
+    } else {
+      convertedAttrs.push_back(attr);
+    }
+  }
+
   // Create the operation through state since we don't know its C++ type.
   Operation *newOp =
       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                      resultTypes, targetAttrs);
+                      resultTypes, convertedAttrs);
 
   setNativeProperties(newOp, overflowFlags);
 
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 83bdbe1f67118..299cc32351bdb 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () {
 
 // CHECK-LABEL: @index_vector
 func.func @index_vector(%arg0: vector<4xindex>) {
-  // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
+  // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64>
   %0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
   // CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64>
   %1 = arith.addi %arg0, %0 : vector<4xindex>
@@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) {
 
 // -----
 
+// CHECK-LABEL: @f8E4M3FN_vector
+func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8>
+  %0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN>
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+  %1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN>
+  // CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
+  %2 = arith.addf %0, %1 : vector<4xf8E4M3FN>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN>
+  // CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN>
+  func.return %2 : vector<4xf8E4M3FN>
+}
+
+// -----
+
 // CHECK-LABEL: @bitcast_1d
 func.func @bitcast_1d(%arg0: vector<2xf32>) {
   // CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32>

Contributor

@krzysz00 krzysz00 left a comment


As a general comment, I don't think mixing llvm.mlir.constant and operations like vector.shape_cast makes sense, and since llvm.mlir.constant allows for a type/attribute mismatch (to my knowledge), there might not be anything to fix here.

convertedVectorType, denseAttr.getRawData()));
}
} else {
convertedAttrs.push_back(attr);

Shouldn't this also be done for scalar attributes? E.g., an index constant becoming an i64 constant?

@MengmSun
Author

MengmSun commented Jul 21, 2025

As a general comment, I don't think mixing llvm.mlir.constant and operations like vector.shape_cast makes sense, and since llvm.mlir.constant allows for a type/attribute mismatch (to my knowledge), there might not be anything to fix here.

I don't think the mixing is illegal. vector.shape_cast is lowered from ops we define downstream.
In the earlier #133988 discussion, the point is whether llvm.mlir.constant allows a type/attribute mismatch. If it does, then in my understanding vector.shape_cast should handle this case instead of assuming by default that the element type of the dense attribute is consistent with the result type:

// shape_cast(constant) -> constant
  if (auto splatAttr =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return splatAttr.reshape(getType());

So I think modifying this part is more reasonable. It is not about maintaining more code for incorrect code; rather, it categorizes the cases that shape_cast(constant) -> constant may encounter.

@krzysz00
Contributor

Well, no, the point I'm making is that LLVM lowerings should be done in one shot without intervening canonicalizations, because llvm.mlir.constant is one of the few cases where this sort of type/attribute mismatch happens ... and is expected to only be an argument to LLVM-level operations

@MengmSun
Author

MengmSun commented Jul 21, 2025

Well, no, the point I'm making is that LLVM lowerings should be done in one shot without intervening canonicalizations, because llvm.mlir.constant is one of the few cases where this sort of type/attribute mismatch happens ... and is expected to only be an argument to LLVM-level operations

I see. However, the mixed sequence is simply lowered from something like

%0 = arith.constant dense<0.> : vector<192xf8E4M3FN>
our defined op %0, %args

in the single -convert-to-llvm pass. Yes, we emit vector.shape_cast when lowering our own op in the convert-to-llvm pass. However, even if we lower vector.shape_cast to the LLVM dialect before the canonicalizer runs, the problem still appears: ShapeCastOp::fold() is called during legalization, before pattern matching and conversion.

That means, even if we do convert-to-llvm for

%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
%8 = vector.shape_cast %4 : vector<192xi8> to vector<1x192xi8>

the same problem appears; we have verified this locally. In fact, under the current mechanism, any subsequent operation on it will expose this problem.

@dcaballe dcaballe requested review from newling and joker-eph July 21, 2025 17:39
@dcaballe
Contributor

I think the lowering and dialect mix-up is orthogonal to the actual problem here. For me, the question is: why is the attribute type not converted when an arith.constant is converted to llvm.mlir.constant? That is, why do we allow something like:

%4 = llvm.mlir.constant(dense<0> : index) : i32

instead of:

%4 = llvm.mlir.constant(dense<0> : i32) : i32

?
Are there scenarios where the attribute/return-type mismatch is needed? Is this something that was needed in the past but is no longer needed?

@gysit
Contributor

gysit commented Jul 21, 2025

Are there scenarios where the attribute/return-type mismatch is needed? Is this something that was needed in the past but is no longer needed?

There are cases where there is not really a matching attribute type for the result type of a constant. For example, a constant that defines an llvm.array uses an attribute of tensor type. For the 8-bit floats the mismatch with the i8 result type is intended. In many other cases there may not be a specific reason and the verifier could possibly be further strengthened. @newling landed an improvement today (#148975).

@dcaballe
Contributor

Ok, I can see the use cases now. Thanks!

%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
%8 = vector.shape_cast %4 : vector<192xi8> to vector<1x192xi8>

Given that the type mismatch in llvm.mlir.constant is required, could we constrain the vector.shape_cast canonicalization pattern so that it only matches against arith.constant and not llvm.mlir.constant? I think supporting llvm.mlir.constant is not worth the complexity and could be seen as a layering violation, as mentioned before. It has been accidentally working so far but it looks like a bug more than a feature to me at this point...

@krzysz00
Contributor

krzysz00 commented Jul 21, 2025

  1. If the vector::ShapeCast folder (not canonicalizer) is hitting this bug, then we need to add guards to the folder to catch this sort of thing
  2. The canonicalization patterns shouldn't be mixed in with convert-to-llvm. You should do the lowering from your op to vector.shape_cast and then do a convert-to-llvm, IMO.

That is, I'm claiming that vector.shape_cast(llvm.mlir.constant) should only transiently appear during lowering to LLVM dialects, and if it's showing up long enough to see the canonicalization patterns there's a layering violation

I'm willing to be convinced that I'm wrong though

@newling
Contributor

newling commented Jul 21, 2025

My preference would be to make the element type change illegal in the constant op, i.e. that

%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>

is illegal (along the lines of this PR). One downside is that the numerical values in the textual form would no longer make sense. I don't have the context to know whether this is a feasible restriction to place on the llvm constant op for all users.

In fact, under the current mechanism, any subsequent operation on it will expose this problem.

This suggests that the problem isn't specific to ShapeCast::fold, but could be a problem for any fold operation. Is that correct @MengmSun ?

If the vector::ShapeCast folder (not canonicalizer) is hitting this bug, then we need to add guards to the folder to catch this sort of thing

@krzysz00 it is the folder. The original solution in #147691 actually does a good job -- it uses the element type of the original attribute instead of the result. But that may not scale to all fold methods. In my opinion, if the API for reshaping the attribute was rather

 DenseElementsAttr::reshape(ArrayRef<int64_t> newShape)

instead of

 DenseElementsAttr::reshape(ShapedType newType)

as outlined in #149947, then this issue would not have arisen in #147691.

@krzysz00
Contributor

Ok, higher-level question: Why do you need the lowering to your op to happen alongside the to-llvm rewrites? Can't you lower to shape_cast before you lower to LLVM?

@MengmSun
Author

MengmSun commented Jul 22, 2025

The original solution in #147691 actually does a good job -- it uses the element type of the original attribute instead of the result.

I agree. But I think it should be changed to use the result type instead of the original attribute element type because the type conversion is intended. @newling

if the API for reshaping the attribute was rather

DenseElementsAttr::reshape(ArrayRef<int64_t> newShape)
instead of

DenseElementsAttr::reshape(ShapedType newType)
as outlined in #149947, then this issue would not have arisen in #147691.

Good point.

Anyway, after reading these comments, I think we should still solve this problem in #147691 by updating ShapeCastOp::fold(). @dcaballe @banach-space

@MengmSun
Author

Ok, higher-level question: Why do you need the lowering to your op to happen alongside the to-llvm rewrites? Can't you lower to shape_cast before you lower to LLVM?

Good point. We'll think about this deeply.

@banach-space
Contributor

For example, a constant that defines an llvm.array uses an attribute of tensor type. For the 8-bit floats the mismatch with the i8 result type is intended.

@gysit, could you share a code example? I'm having a difficult time visualising that.

@gysit
Contributor

gysit commented Jul 23, 2025

@gysit , could you share a code example? I'm having difficult time visualising that.

Sure this is an example for the array case:

llvm.mlir.constant(dense<-8.900000e+01> : tensor<2xf64>) : !llvm.array<2 x f64>

For string we use a string attribute though:

llvm.mlir.constant("hello") : !llvm.array<5 x i8>

@banach-space
Contributor

@gysit , could you share a code example? I'm having difficult time visualising that.

Sure this is an example for the array case:

llvm.mlir.constant(dense<-8.900000e+01> : tensor<2xf64>) : !llvm.array<2 x f64>

Thanks! Diving into this a bit deeper, it sounds like it would be ok to require the numeric element types to match.

For string we use a string attribute though:

llvm.mlir.constant("hello") : !llvm.array<5 x i8>

Here, the input has non-numeric element type, so that could be an exception.

Does this make sense?

@dcaballe
Contributor

How do we represent the constant attribute for a type that is supported in MLIR but is not in LLVM (e.g., f8E4M3FN)?

@gysit
Contributor

gysit commented Jul 23, 2025

Thanks! Diving into this a bit deeper, it sounds like it would be ok to require the numeric element types to match.

I think there is the floating point exception that @dcaballe mentions. Apart from that, enforcing matching element types seems reasonable (however, I would not claim I understand all use cases).

@krzysz00
Contributor

The way those unsupported float types work is that we make the result type of the llvm.mlir.constant the same-width integer but keep the attribute as float-typed. Then, during translation, we'll do APFloat-to-APInt bitcasting to generate the LLVM-level constant
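The translation-time bitcast described here is an ordinary same-width reinterpretation of the float's bit pattern as an integer. As an analogy using a standard 32-bit float (a sketch with Python's struct module; the MLIR implementation uses APFloat/APInt, not this code, and the helper name is invented for illustration):

```python
import struct

def float_bits(x: float) -> int:
    """Reinterpret an f32 value as its 32-bit integer pattern,
    analogous to the APFloat-to-APInt bitcast done during translation."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

print(hex(float_bits(1.0)))    # 0x3f800000
print(hex(float_bits(-89.0)))  # 0xc2b20000 (cf. the -8.9e+01 tensor example above)
```

The same idea applied to an 8-bit float yields the i8 bytes seen in this PR's converted constants.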

@krzysz00
Contributor

We could probably change the behavior for these unsupported floats to do the bitcasts during convert-to-llvm, but I've got the rough sense that this has a readability impact ... though I wouldn't be too opposed to that change (which is, IIRC, this PR)

@MengmSun MengmSun force-pushed the mlir/fix/arith_constant_target_attr branch from 1d5572a to 38b4227 Compare July 30, 2025 06:54
@MengmSun
Author

MengmSun commented Jul 30, 2025

I have updated my local changes to this branch.
However, I ran into some confusing failures locally with make check-mlir (these failures will show up in this CI as well):

7 UTs failed, and I have attached my local log below

check.log

********************
Testing:  0.. 10.. 20.. 30.. 40.. 50.. 60.. 70.. 80.. 90..
********************
Failed Tests (7):
  MLIR :: Conversion/FuncToLLVM/func-to-llvm.mlir
  MLIR :: Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
  MLIR :: Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
  MLIR :: Conversion/GPUToROCDL/gpu-to-rocdl.mlir
  MLIR :: Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
  MLIR :: Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
  MLIR :: Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir


Testing Time: 57.32s

Total Discovered Tests: 3332
  Unsupported      :  458 (13.75%)
  Passed           : 2866 (86.01%)
  Expectedly Failed:    1 (0.03%)
  Failed           :    7 (0.21%)

These are all about the index -> i64 conversion. I'm confused because many places seem to intend to treat index as i32. Is that intended? Should we treat index specially (i.e., not convert the index element type to i64 in the type converter)?

@dcaballe @banach-space thx.

@dcaballe
Contributor

These are all about the index -> i64 conversion. I'm confused because many places seem to intend to treat index as i32. Is that intended? Should we treat index specially (i.e., not convert the index element type to i64 in the type converter)?

Yes, index type can be converted to i64 or i32 (or something else, I guess). We get index size information from DLTI. Some patterns may also take an argument with this information, IIRC. So, I guess you have to just fix the tests...

What I can see in the logs makes sense:

  # | /home/gha/actions-runner/_work/llvm-project/llvm-project/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir:33:16: error: CHECK-NEXT: expected string not found in input
  # | // CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : i64
  # |                ^
  # | <stdin>:8:21: note: scanning from here
  # |  ^bb1: // pred: ^bb0
  # |                     ^
  # | <stdin>:9:2: note: possible intended match here
  # |  %0 = llvm.mlir.constant(1 : i64) : i64

@krzysz00
Contributor

If you're fixing the element type, you'll want to ensure the index attribute gets replaced by either an i64 or an i32 (or perhaps an i16 if someone asks) - whatever type index maps to.

@nikic nikic removed request for a team, nikic, ayermolo, andykaylor and yozhu July 30, 2025 20:00