@@ -110,16 +110,34 @@ DenseElementsAttr applyElementWise(
110110 // We already know the amount of values we will insert, reserve space for
111111 // all of them to avoid dynamic resizing
112112 transformedValues.reserve (toTransform.getNumElements ());
113- for (auto val : toTransform.getValues <SrcValType>()) {
114- auto transformedVal = toApply (val, targetType);
115- transformedValues.push_back (transformedVal);
113+ if constexpr (std::is_same_v<SrcValType, APSInt>) {
114+ for (auto val : toTransform.getValues <APInt>()) {
115+ auto transformedVal =
116+ toApply (APSInt (val, toTransform.getElementType ().isUnsignedInteger ()),
117+ targetType);
118+ transformedValues.push_back (transformedVal);
119+ }
120+ } else {
121+ for (auto val : toTransform.getValues <SrcValType>()) {
122+ auto transformedVal = toApply (val, targetType);
123+ transformedValues.push_back (transformedVal);
124+ }
116125 }
117126
118127 // Make sure that the output tensor has the expected output type
119128 auto inShape = toTransform.getType ();
120129 auto outTy = inShape.cloneWith ({}, targetType);
121130
122- return DenseElementsAttr::get (outTy, transformedValues);
131+ if constexpr (std::is_same_v<TargetValType, APSInt>) {
132+ SmallVector<APInt> transformedValuesAPInt;
133+ transformedValuesAPInt.reserve (transformedValues.size ());
134+ for (APSInt val : transformedValues) {
135+ transformedValuesAPInt.emplace_back (val);
136+ }
137+ return DenseElementsAttr::get (outTy, transformedValuesAPInt);
138+ } else {
139+ return DenseElementsAttr::get (outTy, transformedValues);
140+ }
123141}
124142
125143template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
@@ -881,10 +899,10 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
881899
882900 using TosaFoldConstantBase::TosaFoldConstantBase;
883901
884- static APFloat convertIntToFloat (const APInt &toConvert,
902+ static APFloat convertIntToFloat (const APSInt &toConvert,
885903 FloatType targetType) {
886904 APFloat res (targetType.getFloatSemantics ());
887- res.convertFromAPInt (toConvert, true /* isSigned */ , tosaRoundingMode);
905+ res.convertFromAPInt (toConvert, toConvert. isSigned () , tosaRoundingMode);
888906 return res;
889907 }
890908
@@ -928,15 +946,14 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
928946 return converted;
929947 }
930948
931- static APInt convertIntToInt (const APInt &toConvert, IntegerType targetType) {
949+ static APSInt convertIntToInt (const APSInt &toConvert,
950+ IntegerType targetType) {
932951 // Make sure to properly translate booleans
933952 if (targetType.getWidth () == 1 ) {
934- return toConvert.isZero () ? APInt::getZero (1 ) : APInt::getAllOnes (1 );
935- }
936- if (targetType.isUnsigned ()) {
937- return toConvert.zextOrTrunc (targetType.getIntOrFloatBitWidth ());
953+ return APSInt (toConvert.isZero () ? APInt::getZero (1 )
954+ : APInt::getAllOnes (1 ));
938955 }
939- return toConvert.sextOrTrunc (targetType.getIntOrFloatBitWidth ());
956+ return toConvert.extOrTrunc (targetType.getIntOrFloatBitWidth ());
940957 }
941958
942959 static void warnAboutNaNToIntCast (DenseElementsAttr elements, CastOp location,
@@ -981,11 +998,11 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
981998 warnAboutNaNToIntCast (elements, tosaCast, rewriter);
982999
9831000 // Only fold splat tensors and those used only once to avoid duplicating
984- // them.
1001+ // them and increasing memory consumption .
9851002 if (!inputTensor.hasOneUse () && !isa<SplatElementsAttr>(elements)) {
986- return rewriter.notifyMatchFailure (tosaCast,
987- " Currently, casts will only be folded "
988- " if its input only has a single user" );
1003+ return rewriter.notifyMatchFailure (
1004+ tosaCast, " Currently, casts will only be folded "
1005+ " if its input only has a single user or is a splat value. " );
9891006 }
9901007
9911008 // Report a match failure for unexpected types
@@ -994,28 +1011,25 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
9941011 tosaCast, " Only casts from/to int/float are supported." );
9951012 }
9961013
997- auto isUnsigned = [](Type toCheck) {
998- return isa<IntegerType>(toCheck) &&
999- cast<IntegerType>(toCheck).isUnsigned ();
1000- };
1001- auto typesToCheck = {toType, fromType};
1002- if (llvm::any_of (typesToCheck, isUnsigned)) {
1014+ // TOSA spec does not allow casts from/to unsigned, but we partially do, to
1015+ // enable the folding of lowered qdq nodes
1016+ if (isa<FloatType>(fromType) && isa<IntegerType>(toType) &&
1017+ cast<IntegerType>(toType).isUnsigned ()) {
10031018 // TOSA casts currently don't support unsigned integers.
1004- // To support them by here, one could use APSInt instead of APInts,
1005- // however, this causes trouble with `getValues` which does not support
1006- // APSInts currently.
1019+ // Casting float to unsigned int would need a decision about how to handle
1020+ // negative floats
10071021 return rewriter.notifyMatchFailure (
1008- tosaCast, " Cast folding from/to unsigned integers is not supported." );
1022+ tosaCast,
1023+ " Cast folding from float to unsigned integers is not supported." );
10091024 }
1010-
10111025 DenseElementsAttr res;
10121026 if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
10131027 if (isa<FloatType>(fromType)) {
10141028 res = applyElementWise<APFloat, APInt, IntegerType>(
10151029 elements, &convertFloatToInt, intOutTy);
10161030 } else {
10171031 assert (isa<IntegerType>(fromType));
1018- res = applyElementWise<APInt, APInt , IntegerType>(
1032+ res = applyElementWise<APSInt, APSInt , IntegerType>(
10191033 elements, &convertIntToInt, intOutTy);
10201034 }
10211035 } else {
@@ -1026,7 +1040,7 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
10261040 elements, &convertFloatToFloat, floatOutTy);
10271041 } else {
10281042 assert (isa<IntegerType>(fromType));
1029- res = applyElementWise<APInt , APFloat, FloatType>(
1043+ res = applyElementWise<APSInt , APFloat, FloatType>(
10301044 elements, &convertIntToFloat, floatOutTy);
10311045 }
10321046 }
0 commit comments