From d4e6a3ff33dc82e574f92b01e7613a007f63a7d9 Mon Sep 17 00:00:00 2001 From: MengmengSun Date: Mon, 21 Jul 2025 02:01:35 -0700 Subject: [PATCH 1/2] Fix element type of target attributes in oneToOneRewrite when converting to llvm --- mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 27 ++++++++++++++++++- .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 17 +++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index c5f72f7e10b8c..329703e4f054d 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -331,10 +331,35 @@ LogicalResult LLVM::detail::oneToOneRewrite( return failure(); } + // If the targetAttrs contains DenseElementsAttr, + // and the element type of the DenseElementsAttr and result type is + // inconsistent after the conversion of result types, we need to convert the + // element type of the DenseElementsAttr to the target type by creating a new + // DenseElementsAttr with the converted element type, and use the new + // DenseElementsAttr to replace the old one in the targetAttrs + SmallVector convertedAttrs; + for (auto attr : targetAttrs) { + if (auto denseAttr = dyn_cast(attr.getValue())) { + VectorType vectorType = dyn_cast(denseAttr.getType()); + if (vectorType) { + auto convertedElementType = + typeConverter.convertType(vectorType.getElementType()); + VectorType convertedVectorType = + VectorType::get(vectorType.getShape(), convertedElementType, + vectorType.getScalableDims()); + convertedAttrs.emplace_back( + attr.getName(), DenseElementsAttr::getFromRawBuffer( + convertedVectorType, denseAttr.getRawData())); + } + } else { + convertedAttrs.push_back(attr); + } + } + // Create the operation through state since we don't know its C++ type. Operation *newOp = rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, targetAttrs); + resultTypes, convertedAttrs); setNativeProperties(newOp, overflowFlags); diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 83bdbe1f67118..299cc32351bdb 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () { // CHECK-LABEL: @index_vector func.func @index_vector(%arg0: vector<4xindex>) { - // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64> + // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64> %0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> // CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64> %1 = arith.addi %arg0, %0 : vector<4xindex> @@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) { // ----- +// CHECK-LABEL: @f8E4M3FN_vector +func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8> + %0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN> + // CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8> + %1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN> + // CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8> + %2 = arith.addf %0, %1 : vector<4xf8E4M3FN> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN> + // CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN> + func.return %2 : vector<4xf8E4M3FN> +} + +// ----- + // CHECK-LABEL: @bitcast_1d func.func @bitcast_1d(%arg0: vector<2xf32>) { // CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32> From 38b4227db3dc450caf396c7429c67b677112d3df Mon Sep 17 00:00:00 2001 From: MengmengSun Date: Tue, 29 Jul 2025 23:53:18 -0700 Subject: [PATCH 2/2] Spread elements type conversion to all valid type attr --- mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 72 +++++++++++++++++----- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 8546989fe8e2e..2c02db4b0da16 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -330,25 +330,65 @@ LogicalResult LLVM::detail::oneToOneRewrite( return failure(); } - // If the targetAttrs contains DenseElementsAttr, - // and the element type of the DenseElementsAttr and result type is - // inconsistent after the conversion of result types, we need to convert the - // element type of the DenseElementsAttr to the target type by creating a new - // DenseElementsAttr with the converted element type, and use the new - // DenseElementsAttr to replace the old one in the targetAttrs + // Convert attribute element types to match the converted result types. + // This ensures that attributes like + // dense<0.0> : vector<4xf8E4M3FN> become + // dense<0> : vector<4xi8> + // when the result type is converted to i8. SmallVector convertedAttrs; for (auto attr : targetAttrs) { - if (auto denseAttr = dyn_cast(attr.getValue())) { - VectorType vectorType = dyn_cast(denseAttr.getType()); - if (vectorType) { - auto convertedElementType = - typeConverter.convertType(vectorType.getElementType()); - VectorType convertedVectorType = - VectorType::get(vectorType.getShape(), convertedElementType, - vectorType.getScalableDims()); + if (auto floatAttr = dyn_cast(attr.getValue())) { + auto convertedElementType = + typeConverter.convertType(floatAttr.getType()); + if (convertedElementType != floatAttr.getType()) { + // Currently, only 1-byte or sub-byte float types will be converted and + // converted to integer types. + convertedAttrs.emplace_back( + attr.getName(), + IntegerAttr::get(convertedElementType, + floatAttr.getValue().bitcastToAPInt())); + } else { + convertedAttrs.emplace_back(attr); + } + } else if (auto intAttr = dyn_cast(attr.getValue())) { + auto convertedElementType = typeConverter.convertType(intAttr.getType()); + if (convertedElementType != intAttr.getType()) { convertedAttrs.emplace_back( - attr.getName(), DenseElementsAttr::getFromRawBuffer( - convertedVectorType, denseAttr.getRawData())); + attr.getName(), + IntegerAttr::get(convertedElementType, intAttr.getValue())); + } else { + convertedAttrs.emplace_back(attr); + } + } else if (auto denseAttr = dyn_cast(attr.getValue())) { + if (auto shapedType = dyn_cast(denseAttr.getType())) { + auto convertedElementType = + typeConverter.convertType(shapedType.getElementType()); + if (convertedElementType != shapedType.getElementType()) { + ShapedType convertedShapedType = + shapedType.cloneWith(std::nullopt, convertedElementType); + convertedAttrs.emplace_back( + attr.getName(), DenseElementsAttr::getFromRawBuffer( + convertedShapedType, denseAttr.getRawData())); + } else { + convertedAttrs.emplace_back(attr); + } + } + } else if (auto sparseAttr = + dyn_cast(attr.getValue())) { + if (auto shapedType = dyn_cast(sparseAttr.getType())) { + auto convertedElementType = + typeConverter.convertType(shapedType.getElementType()); + if (convertedElementType != shapedType.getElementType()) { + ShapedType convertedShapedType = + shapedType.cloneWith(std::nullopt, convertedElementType); + convertedAttrs.emplace_back( + attr.getName(), + SparseElementsAttr::get( + convertedShapedType, sparseAttr.getIndices(), + sparseAttr.getValues().bitcast(convertedElementType))); + } else { + convertedAttrs.emplace_back(attr); + } } } else { convertedAttrs.push_back(attr);