Skip to content

Commit 38b4227

Browse files
author
MengmengSun
committed
Spread elements type conversion to all valid type attr
1 parent 891ecaa commit 38b4227

File tree

1 file changed

+56
-16
lines changed

1 file changed

+56
-16
lines changed

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -330,25 +330,65 @@ LogicalResult LLVM::detail::oneToOneRewrite(
330330
return failure();
331331
}
332332

333-
// If the targetAttrs contains DenseElementsAttr,
334-
// and the element type of the DenseElementsAttr and result type is
335-
// inconsistent after the conversion of result types, we need to convert the
336-
// element type of the DenseElementsAttr to the target type by creating a new
337-
// DenseElementsAttr with the converted element type, and use the new
338-
// DenseElementsAttr to replace the old one in the targetAttrs
333+
// Convert attribute element types to match the converted result types.
334+
// This ensures that attributes like
335+
// dense<0.0> : vector<4xf8E4M3FN> become
336+
// dense<0> : vector<4xi8>
337+
// when the result type is converted to i8.
339338
SmallVector<NamedAttribute> convertedAttrs;
340339
for (auto attr : targetAttrs) {
341-
if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
342-
VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
343-
if (vectorType) {
344-
auto convertedElementType =
345-
typeConverter.convertType(vectorType.getElementType());
346-
VectorType convertedVectorType =
347-
VectorType::get(vectorType.getShape(), convertedElementType,
348-
vectorType.getScalableDims());
340+
if (auto floatAttr = dyn_cast<FloatAttr>(attr.getValue())) {
341+
auto convertedElementType =
342+
typeConverter.convertType(floatAttr.getType());
343+
if (convertedElementType != floatAttr.getType()) {
344+
// Currently, only 1-byte or sub-byte float types will be converted and
345+
// converted to integer types.
346+
convertedAttrs.emplace_back(
347+
attr.getName(),
348+
IntegerAttr::get(convertedElementType,
349+
floatAttr.getValue().bitcastToAPInt()));
350+
} else {
351+
convertedAttrs.emplace_back(attr);
352+
}
353+
} else if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
354+
auto convertedElementType = typeConverter.convertType(intAttr.getType());
355+
if (convertedElementType != intAttr.getType()) {
349356
convertedAttrs.emplace_back(
350-
attr.getName(), DenseElementsAttr::getFromRawBuffer(
351-
convertedVectorType, denseAttr.getRawData()));
357+
attr.getName(),
358+
IntegerAttr::get(convertedElementType, intAttr.getValue()));
359+
} else {
360+
convertedAttrs.emplace_back(attr);
361+
}
362+
} else if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
363+
if (auto shapedType = dyn_cast<ShapedType>(denseAttr.getType())) {
364+
auto convertedElementType =
365+
typeConverter.convertType(shapedType.getElementType());
366+
if (convertedElementType != shapedType.getElementType()) {
367+
ShapedType convertedShapedType =
368+
shapedType.cloneWith(std::nullopt, convertedElementType);
369+
convertedAttrs.emplace_back(
370+
attr.getName(), DenseElementsAttr::getFromRawBuffer(
371+
convertedShapedType, denseAttr.getRawData()));
372+
} else {
373+
convertedAttrs.emplace_back(attr);
374+
}
375+
}
376+
} else if (auto sparseAttr =
377+
dyn_cast<SparseElementsAttr>(attr.getValue())) {
378+
if (auto shapedType = dyn_cast<ShapedType>(sparseAttr.getType())) {
379+
auto convertedElementType =
380+
typeConverter.convertType(shapedType.getElementType());
381+
if (convertedElementType != shapedType.getElementType()) {
382+
ShapedType convertedShapedType =
383+
shapedType.cloneWith(std::nullopt, convertedElementType);
384+
convertedAttrs.emplace_back(
385+
attr.getName(),
386+
SparseElementsAttr::get(
387+
convertedShapedType, sparseAttr.getIndices(),
388+
sparseAttr.getValues().bitcast(convertedElementType)));
389+
} else {
390+
convertedAttrs.emplace_back(attr);
391+
}
352392
}
353393
} else {
354394
convertedAttrs.push_back(attr);

0 commit comments

Comments
 (0)