diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 2568044f1fd32..2c02db4b0da16 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -330,10 +330,75 @@ LogicalResult LLVM::detail::oneToOneRewrite( return failure(); } + // Convert attribute element types to match the converted result types. + // This ensures that attributes like + // dense<0.0> : vector<4xf8E4M3FN> become + // dense<0> : vector<4xi8> + // when the result type is converted to i8. + SmallVector convertedAttrs; + for (auto attr : targetAttrs) { + if (auto floatAttr = dyn_cast(attr.getValue())) { + auto convertedElementType = + typeConverter.convertType(floatAttr.getType()); + if (convertedElementType != floatAttr.getType()) { + // Currently, only 1-byte or sub-byte float types will be converted and + // converted to integer types. + convertedAttrs.emplace_back( + attr.getName(), + IntegerAttr::get(convertedElementType, + floatAttr.getValue().bitcastToAPInt())); + } else { + convertedAttrs.emplace_back(attr); + } + } else if (auto intAttr = dyn_cast(attr.getValue())) { + auto convertedElementType = typeConverter.convertType(intAttr.getType()); + if (convertedElementType != intAttr.getType()) { + convertedAttrs.emplace_back( + attr.getName(), + IntegerAttr::get(convertedElementType, intAttr.getValue())); + } else { + convertedAttrs.emplace_back(attr); + } + } else if (auto denseAttr = dyn_cast(attr.getValue())) { + if (auto shapedType = dyn_cast(denseAttr.getType())) { + auto convertedElementType = + typeConverter.convertType(shapedType.getElementType()); + if (convertedElementType != shapedType.getElementType()) { + ShapedType convertedShapedType = + shapedType.cloneWith(std::nullopt, convertedElementType); + convertedAttrs.emplace_back( + attr.getName(), DenseElementsAttr::getFromRawBuffer( + convertedShapedType, denseAttr.getRawData())); + } else { + convertedAttrs.emplace_back(attr); + } + } + } else if (auto sparseAttr = + dyn_cast(attr.getValue())) { + if (auto shapedType = dyn_cast(sparseAttr.getType())) { + auto convertedElementType = + typeConverter.convertType(shapedType.getElementType()); + if (convertedElementType != shapedType.getElementType()) { + ShapedType convertedShapedType = + shapedType.cloneWith(std::nullopt, convertedElementType); + convertedAttrs.emplace_back( + attr.getName(), + SparseElementsAttr::get( + convertedShapedType, sparseAttr.getIndices(), + sparseAttr.getValues().bitcast(convertedElementType))); + } else { + convertedAttrs.emplace_back(attr); + } + } + } else { + convertedAttrs.push_back(attr); + } + } + // Create the operation through state since we don't know its C++ type. Operation *newOp = rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, targetAttrs); + resultTypes, convertedAttrs); setNativeProperties(newOp, overflowFlags); diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 83bdbe1f67118..299cc32351bdb 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () { // CHECK-LABEL: @index_vector func.func @index_vector(%arg0: vector<4xindex>) { - // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64> + // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64> %0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> // CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64> %1 = arith.addi %arg0, %0 : vector<4xindex> @@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) { // ----- +// CHECK-LABEL: @f8E4M3FN_vector +func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8> + %0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN> + // CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8> + %1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN> + // CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8> + %2 = arith.addf %0, %1 : vector<4xf8E4M3FN> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN> + // CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN> + func.return %2 : vector<4xf8E4M3FN> +} + +// ----- + // CHECK-LABEL: @bitcast_1d func.func @bitcast_1d(%arg0: vector<2xf32>) { // CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32>