Skip to content

Commit d4e6a3f

Browse files
author
MengmengSun
committed
Fix element type of target attributes in oneToOneRewrite when converting to llvm
1 parent 6193dd5 commit d4e6a3f

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,10 +331,35 @@ LogicalResult LLVM::detail::oneToOneRewrite(
331331
return failure();
332332
}
333333

334+
// If the targetAttrs contains DenseElementsAttr,
335+
// and the element type of the DenseElementsAttr and result type is
336+
// inconsistent after the conversion of result types, we need to convert the
337+
// element type of the DenseElementsAttr to the target type by creating a new
338+
// DenseElementsAttr with the converted element type, and use the new
339+
// DenseElementsAttr to replace the old one in the targetAttrs
340+
SmallVector<NamedAttribute> convertedAttrs;
341+
for (auto attr : targetAttrs) {
342+
if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
343+
VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
344+
if (vectorType) {
345+
auto convertedElementType =
346+
typeConverter.convertType(vectorType.getElementType());
347+
VectorType convertedVectorType =
348+
VectorType::get(vectorType.getShape(), convertedElementType,
349+
vectorType.getScalableDims());
350+
convertedAttrs.emplace_back(
351+
attr.getName(), DenseElementsAttr::getFromRawBuffer(
352+
convertedVectorType, denseAttr.getRawData()));
353+
}
354+
} else {
355+
convertedAttrs.push_back(attr);
356+
}
357+
}
358+
334359
// Create the operation through state since we don't know its C++ type.
335360
Operation *newOp =
336361
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
337-
resultTypes, targetAttrs);
362+
resultTypes, convertedAttrs);
338363

339364
setNativeProperties(newOp, overflowFlags);
340365

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () {
428428

429429
// CHECK-LABEL: @index_vector
430430
func.func @index_vector(%arg0: vector<4xindex>) {
431-
// CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
431+
// CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64>
432432
%0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
433433
// CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64>
434434
%1 = arith.addi %arg0, %0 : vector<4xindex>
@@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) {
437437

438438
// -----
439439

440+
// CHECK-LABEL: @f8E4M3FN_vector
441+
func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> {
442+
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8>
443+
%0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN>
444+
// CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
445+
%1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN>
446+
// CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
447+
%2 = arith.addf %0, %1 : vector<4xf8E4M3FN>
448+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN>
449+
// CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN>
450+
func.return %2 : vector<4xf8E4M3FN>
451+
}
452+
453+
// -----
454+
440455
// CHECK-LABEL: @bitcast_1d
441456
func.func @bitcast_1d(%arg0: vector<2xf32>) {
442457
// CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32>

0 commit comments

Comments
 (0)