@@ -330,25 +330,65 @@ LogicalResult LLVM::detail::oneToOneRewrite(
330
330
return failure ();
331
331
}
332
332
333
- // If the targetAttrs contains DenseElementsAttr,
334
- // and the element type of the DenseElementsAttr and result type is
335
- // inconsistent after the conversion of result types, we need to convert the
336
- // element type of the DenseElementsAttr to the target type by creating a new
337
- // DenseElementsAttr with the converted element type, and use the new
338
- // DenseElementsAttr to replace the old one in the targetAttrs
333
+ // Convert attribute element types to match the converted result types.
334
+ // This ensures that attributes like
335
+ // dense<0.0> : vector<4xf8E4M3FN> become
336
+ // dense<0> : vector<4xi8>
337
+ // when the result type is converted to i8.
339
338
SmallVector<NamedAttribute> convertedAttrs;
340
339
for (auto attr : targetAttrs) {
341
- if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue ())) {
342
- VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType ());
343
- if (vectorType) {
344
- auto convertedElementType =
345
- typeConverter.convertType (vectorType.getElementType ());
346
- VectorType convertedVectorType =
347
- VectorType::get (vectorType.getShape (), convertedElementType,
348
- vectorType.getScalableDims ());
340
+ if (auto floatAttr = dyn_cast<FloatAttr>(attr.getValue ())) {
341
+ auto convertedElementType =
342
+ typeConverter.convertType (floatAttr.getType ());
343
+ if (convertedElementType != floatAttr.getType ()) {
344
+ // Currently, only 1-byte or sub-byte float types will be converted and
345
+ // converted to integer types.
346
+ convertedAttrs.emplace_back (
347
+ attr.getName (),
348
+ IntegerAttr::get (convertedElementType,
349
+ floatAttr.getValue ().bitcastToAPInt ()));
350
+ } else {
351
+ convertedAttrs.emplace_back (attr);
352
+ }
353
+ } else if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue ())) {
354
+ auto convertedElementType = typeConverter.convertType (intAttr.getType ());
355
+ if (convertedElementType != intAttr.getType ()) {
349
356
convertedAttrs.emplace_back (
350
- attr.getName (), DenseElementsAttr::getFromRawBuffer (
351
- convertedVectorType, denseAttr.getRawData ()));
357
+ attr.getName (),
358
+ IntegerAttr::get (convertedElementType, intAttr.getValue ()));
359
+ } else {
360
+ convertedAttrs.emplace_back (attr);
361
+ }
362
+ } else if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue ())) {
363
+ if (auto shapedType = dyn_cast<ShapedType>(denseAttr.getType ())) {
364
+ auto convertedElementType =
365
+ typeConverter.convertType (shapedType.getElementType ());
366
+ if (convertedElementType != shapedType.getElementType ()) {
367
+ ShapedType convertedShapedType =
368
+ shapedType.cloneWith (std::nullopt, convertedElementType);
369
+ convertedAttrs.emplace_back (
370
+ attr.getName (), DenseElementsAttr::getFromRawBuffer (
371
+ convertedShapedType, denseAttr.getRawData ()));
372
+ } else {
373
+ convertedAttrs.emplace_back (attr);
374
+ }
375
+ }
376
+ } else if (auto sparseAttr =
377
+ dyn_cast<SparseElementsAttr>(attr.getValue ())) {
378
+ if (auto shapedType = dyn_cast<ShapedType>(sparseAttr.getType ())) {
379
+ auto convertedElementType =
380
+ typeConverter.convertType (shapedType.getElementType ());
381
+ if (convertedElementType != shapedType.getElementType ()) {
382
+ ShapedType convertedShapedType =
383
+ shapedType.cloneWith (std::nullopt, convertedElementType);
384
+ convertedAttrs.emplace_back (
385
+ attr.getName (),
386
+ SparseElementsAttr::get (
387
+ convertedShapedType, sparseAttr.getIndices (),
388
+ sparseAttr.getValues ().bitcast (convertedElementType)));
389
+ } else {
390
+ convertedAttrs.emplace_back (attr);
391
+ }
352
392
}
353
393
} else {
354
394
convertedAttrs.push_back (attr);
0 commit comments