-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][tosa] Convert TOSA enumerations from StringBasedAttr
to Tosa_I32EnumAttr
#152856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg Author: Annu Singh (AnnuCode) ChangesFixes #152129 Patch is 157.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152856.diff 29 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..2aafed26a4e29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -207,7 +207,7 @@ def Tosa_VariableOpBuilder : OpBuilder<
}]>;
-// Wrapper over base I32EnumAttr to set common fields.
+ // Wrapper over base I32EnumAttr to set common fields.
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
let genSpecializedAttr = 0;
@@ -276,6 +276,7 @@ def Tosa_ProfileAttr
def Tosa_ProfileArrayAttr
: TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -381,6 +382,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
+//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
+def Tosa_RESIZE_BILINEAR : I32EnumAttrCase<"BILINEAR", 1>;
+def Tosa_RESIZE_NEAREST_NEIGHBOR : I32EnumAttrCase<"NEAREST_NEIGHBOR", 2>;
+
+def Tosa_ResizeTypeAttr
+ : Tosa_I32EnumAttr<"ResizeType", "Supported resize/upsampling strategies", "resize_type",
+ [Tosa_RESIZE_BILINEAR, Tosa_RESIZE_NEAREST_NEIGHBOR]>;
+
+def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
+def Tosa_NANPROPAGATION_IGNORE : I32EnumAttrCase<"IGNORE", 2>;
+
+def Tosa_NanPropagationAttr
+ : Tosa_I32EnumAttr<"NanPropagation", "Supported NaN propagation strategies", "nan",
+ [Tosa_NANPROPAGATION_PROPAGATE, Tosa_NANPROPAGATION_IGNORE]>;
+
+def Tosa_ROUNDING_SINGLE_ROUND : I32EnumAttrCase<"SINGLE_ROUND", 1>;
+def Tosa_ROUNDING_INEXACT_ROUND : I32EnumAttrCase<"INEXACT_ROUND", 2>;
+def Tosa_ROUNDING_DOUBLE_ROUND : I32EnumAttrCase<"DOUBLE_ROUND", 3>;
+
+def Tosa_RoundingTypeAttr
+ : Tosa_I32EnumAttr<"RoundingType", "Supported rounding modes", "rounding_type",
+ [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
+
+
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..fdb8f472dc060 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let arguments = (ins
Tosa_TensorAtLeast1D: $input,
I32Attr: $axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -357,7 +357,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -487,7 +487,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
Tosa_Tensor:$input,
Tosa_IntOrFloatAttr:$min_val,
Tosa_IntOrFloatAttr:$max_val,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -935,7 +935,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -964,7 +964,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -1711,7 +1711,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -1751,7 +1751,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 754640dca6561..b21ce51eb03b1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -13,11 +13,13 @@
#ifndef TOSA_TYPES_BASE
#define TOSA_TYPES_BASE
+
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//
@@ -234,29 +236,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
-//===----------------------------------------------------------------------===//
-// Iterable attributes.
-//===----------------------------------------------------------------------===//
-// Defined in `section 3. Enumerations` of the TOSA specification.
-
-// Supported regimes for tosa.resize.
-def Tosa_ResizeTypeAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
- "Supported resize/upsampling strategies">;
-
-// Supported NaN propagation strategies.
-def Tosa_NanPropagationAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
- "Supported NaN propagation strategies">;
-
-// Rounding mode for tosa.rescale
-def Tosa_RoundingTypeAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
- "Supported rounding modes">;
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 044b725c7d805..4a027ccdadd61 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -64,8 +64,9 @@ class ApplyScaleGenericOpConverter
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingType roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingType::DOUBLE_ROUND &&
+ roundingMode != RoundingType::SINGLE_ROUND) {
return failure();
}
@@ -100,7 +101,7 @@ class ApplyScaleGenericOpConverter
multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
// Apply double rounding if necessary.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingType roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingType::DOUBLE_ROUND &&
+ roundingMode != RoundingType::SINGLE_ROUND) {
return failure();
}
@@ -179,7 +181,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..35deeb43f51c2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
return result;
auto nanMode = op.getNanMode();
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagation::PROPAGATE)
return result;
// Unordered comparison of NaN against itself will always return true.
@@ -156,9 +156,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
- auto result = tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getStringAttr("SINGLE_ROUND"));
+ auto roundingAttr = RoundingTypeAttr::get(rewriter.getContext(),
+ RoundingType::SINGLE_ROUND);
+ auto result =
+ tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
+ b, shiftConst, roundingAttr);
if (elementTy.isInteger(32))
return result;
@@ -464,7 +466,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// In the case of "PROPAGATE" semantics no compare and selection is
// required.
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagation::PROPAGATE)
return result;
// In the case of "IGNORE" semantics materialize a comparison
@@ -1173,7 +1175,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
// NaN propagation has no meaning for non floating point types.
- if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
+ if (isa<FloatType>(elementTy) &&
+ op.getNanMode() == NanPropagation::IGNORE) {
isNanIgnoreMode = true;
// Because the TOSA spec requires the result be NaN iff all elements in
// the reduction are NaN we can't simply perform a compare and select.
@@ -1336,11 +1339,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
- if (op.getRoundingMode() == "INEXACT_ROUND")
+ if (op.getRoundingMode() == RoundingType::INEXACT_ROUND)
return rewriter.notifyMatchFailure(
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
"currently supported");
- if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
@@ -1386,11 +1389,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// is ever true.
bool doubleRound =
- op.getRoundingMode() == "DOUBLE_ROUND" &&
+ op.getRoundingMode() == RoundingType::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
- StringAttr roundingMode = doubleRound
- ? rewriter.getStringAttr("DOUBLE_ROUND")
- : rewriter.getStringAttr("SINGLE_ROUND");
+ RoundingType roundingMode =
+ doubleRound ? RoundingType::DOUBLE_ROUND : RoundingType::SINGLE_ROUND;
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1573,7 +1575,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
auto input = op.getInput();
auto inputTy = cast<RankedTensorType>(input.getType());
auto resultTy = cast<RankedTensorType>(op.getType());
- const bool isBilinear = op.getMode() == "BILINEAR";
+ const bool isBilinear = op.getMode() == ResizeType::BILINEAR;
auto inputH = inputTy.getDimSize(1);
auto inputW = inputTy.getDimSize(2);
@@ -1585,7 +1587,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
op, "tosa.resize is not a pure 1x1->1x1 image operation");
// TODO(suderman): These string values should be declared the TOSA dialect.
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeType::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1785,7 +1788,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
return rewriter.notifyMatchFailure(
op, "unable to get dynamic dimensions of tosa.resize");
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeType::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1890,7 +1894,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
}
- if (op.getMode() == "NEAREST_NEIGHBOR") {
+ if (op.getMode() == ResizeType::NEAREST_NEIGHBOR) {
auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1926,7 +1930,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
linalg::YieldOp::create(b, result);
} else {
// The mode here must be BILINEAR.
- assert(op.getMode() == "BILINEAR");
+ assert(op.getMode() == ResizeType::BILINEAR);
auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
@@ -2291,7 +2295,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
Value predicate;
if (isa<FloatType>(inElementTy)) {
- if (argmaxOp.getNanMode() == "IGNORE") {
+ if (argmaxOp.getNanMode() == NanPropagation::IGNORE) {
// Only update index & max value for non NaN values. If all
// values are NaNs, the initial index will be return which is 0.
predicate = arith::CmpFOp::create(rewriter, nestedLoc,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 12d85ca3768dd..0f738a848fcb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -803,7 +803,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
dilationAttr);
rewriter.setInsertionPointAfter(op);
- StringRef nanMode = op.getNanMode();
+ NanPropagation nanMode = op.getNanMode();
rewriter.replaceOp(op, resultOp);
// NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// we've already produced a named op we will just take its body and modify
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
- if (nanMode == "IGNORE") {
+ if (nanMode == NanPropagation::IGNORE) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
rewriter, loc, rewriter.getI8IntegerAttr(30));
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
- auto scaled =
- tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
- shift, rewriter.getStringAttr("SINGLE_ROUND"))
- .getResult();
+ auto roundingAttr = RoundingTypeAttr::get(
+ rewriter.getContext(), RoundingType::SINGLE_ROUND);
+
+ auto scaled = tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal,
+ multiplier, shift, roundingAttr)
+ .getResult();
// If we have quantization information we need to apply output
// zeropoint.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..8e27b267c83d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -555,7 +555,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
// Check we have a valid NaN propagation combination.
const auto opNanMode = op.getNanMode();
const auto clampNanMode = clampOp.getNanMode();
- if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+ if (opNanMode == NanPropagation::IGNORE &&
+ clampNanMode == NanPropagation::PROPAGATE)
return failure();
auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +637,14 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
}
}
+ auto newMode =
+ (opNanMode != clampNanMode) ? tosa::NanPropagation::IGNORE : opNanMode;
+
+ auto newModeAttr = NanPropagationAttr::get(rewriter.getContext(), newMode);
+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
- rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
- : opNanMode));
+ newModeAttr);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534f9e744..5c04874e494c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,13 +508,13...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Annu Singh (AnnuCode) ChangesFixes #152129 Patch is 157.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152856.diff 29 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..2aafed26a4e29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -207,7 +207,7 @@ def Tosa_VariableOpBuilder : OpBuilder<
}]>;
-// Wrapper over base I32EnumAttr to set common fields.
+ // Wrapper over base I32EnumAttr to set common fields.
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
let genSpecializedAttr = 0;
@@ -276,6 +276,7 @@ def Tosa_ProfileAttr
def Tosa_ProfileArrayAttr
: TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -381,6 +382,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
+//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
+def Tosa_RESIZE_BILINEAR : I32EnumAttrCase<"BILINEAR", 1>;
+def Tosa_RESIZE_NEAREST_NEIGHBOR : I32EnumAttrCase<"NEAREST_NEIGHBOR", 2>;
+
+def Tosa_ResizeTypeAttr
+ : Tosa_I32EnumAttr<"ResizeType", "Supported resize/upsampling strategies", "resize_type",
+ [Tosa_RESIZE_BILINEAR, Tosa_RESIZE_NEAREST_NEIGHBOR]>;
+
+def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
+def Tosa_NANPROPAGATION_IGNORE : I32EnumAttrCase<"IGNORE", 2>;
+
+def Tosa_NanPropagationAttr
+ : Tosa_I32EnumAttr<"NanPropagation", "Supported NaN propagation strategies", "nan",
+ [Tosa_NANPROPAGATION_PROPAGATE, Tosa_NANPROPAGATION_IGNORE]>;
+
+def Tosa_ROUNDING_SINGLE_ROUND : I32EnumAttrCase<"SINGLE_ROUND", 1>;
+def Tosa_ROUNDING_INEXACT_ROUND : I32EnumAttrCase<"INEXACT_ROUND", 2>;
+def Tosa_ROUNDING_DOUBLE_ROUND : I32EnumAttrCase<"DOUBLE_ROUND", 3>;
+
+def Tosa_RoundingTypeAttr
+ : Tosa_I32EnumAttr<"RoundingType", "Supported rounding modes", "rounding_type",
+ [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
+
+
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..fdb8f472dc060 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let arguments = (ins
Tosa_TensorAtLeast1D: $input,
I32Attr: $axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -357,7 +357,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -487,7 +487,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
Tosa_Tensor:$input,
Tosa_IntOrFloatAttr:$min_val,
Tosa_IntOrFloatAttr:$max_val,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -935,7 +935,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -964,7 +964,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -1711,7 +1711,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -1751,7 +1751,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 754640dca6561..b21ce51eb03b1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -13,11 +13,13 @@
#ifndef TOSA_TYPES_BASE
#define TOSA_TYPES_BASE
+
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//
@@ -234,29 +236,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
-//===----------------------------------------------------------------------===//
-// Iterable attributes.
-//===----------------------------------------------------------------------===//
-// Defined in `section 3. Enumerations` of the TOSA specification.
-
-// Supported regimes for tosa.resize.
-def Tosa_ResizeTypeAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
- "Supported resize/upsampling strategies">;
-
-// Supported NaN propagation strategies.
-def Tosa_NanPropagationAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
- "Supported NaN propagation strategies">;
-
-// Rounding mode for tosa.rescale
-def Tosa_RoundingTypeAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
- "Supported rounding modes">;
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 044b725c7d805..4a027ccdadd61 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -64,8 +64,9 @@ class ApplyScaleGenericOpConverter
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingType roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingType::DOUBLE_ROUND &&
+ roundingMode != RoundingType::SINGLE_ROUND) {
return failure();
}
@@ -100,7 +101,7 @@ class ApplyScaleGenericOpConverter
multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
// Apply double rounding if necessary.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingType roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingType::DOUBLE_ROUND &&
+ roundingMode != RoundingType::SINGLE_ROUND) {
return failure();
}
@@ -179,7 +181,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..35deeb43f51c2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
return result;
auto nanMode = op.getNanMode();
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagation::PROPAGATE)
return result;
// Unordered comparison of NaN against itself will always return true.
@@ -156,9 +156,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
- auto result = tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getStringAttr("SINGLE_ROUND"));
+ auto roundingAttr = RoundingTypeAttr::get(rewriter.getContext(),
+ RoundingType::SINGLE_ROUND);
+ auto result =
+ tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
+ b, shiftConst, roundingAttr);
if (elementTy.isInteger(32))
return result;
@@ -464,7 +466,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// In the case of "PROPAGATE" semantics no compare and selection is
// required.
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagation::PROPAGATE)
return result;
// In the case of "IGNORE" semantics materialize a comparison
@@ -1173,7 +1175,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
// NaN propagation has no meaning for non floating point types.
- if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
+ if (isa<FloatType>(elementTy) &&
+ op.getNanMode() == NanPropagation::IGNORE) {
isNanIgnoreMode = true;
// Because the TOSA spec requires the result be NaN iff all elements in
// the reduction are NaN we can't simply perform a compare and select.
@@ -1336,11 +1339,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
- if (op.getRoundingMode() == "INEXACT_ROUND")
+ if (op.getRoundingMode() == RoundingType::INEXACT_ROUND)
return rewriter.notifyMatchFailure(
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
"currently supported");
- if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
@@ -1386,11 +1389,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// is ever true.
bool doubleRound =
- op.getRoundingMode() == "DOUBLE_ROUND" &&
+ op.getRoundingMode() == RoundingType::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
- StringAttr roundingMode = doubleRound
- ? rewriter.getStringAttr("DOUBLE_ROUND")
- : rewriter.getStringAttr("SINGLE_ROUND");
+ RoundingType roundingMode =
+ doubleRound ? RoundingType::DOUBLE_ROUND : RoundingType::SINGLE_ROUND;
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1573,7 +1575,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
auto input = op.getInput();
auto inputTy = cast<RankedTensorType>(input.getType());
auto resultTy = cast<RankedTensorType>(op.getType());
- const bool isBilinear = op.getMode() == "BILINEAR";
+ const bool isBilinear = op.getMode() == ResizeType::BILINEAR;
auto inputH = inputTy.getDimSize(1);
auto inputW = inputTy.getDimSize(2);
@@ -1585,7 +1587,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
op, "tosa.resize is not a pure 1x1->1x1 image operation");
// TODO(suderman): These string values should be declared the TOSA dialect.
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeType::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1785,7 +1788,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
return rewriter.notifyMatchFailure(
op, "unable to get dynamic dimensions of tosa.resize");
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeType::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1890,7 +1894,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
}
- if (op.getMode() == "NEAREST_NEIGHBOR") {
+ if (op.getMode() == ResizeType::NEAREST_NEIGHBOR) {
auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1926,7 +1930,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
linalg::YieldOp::create(b, result);
} else {
// The mode here must be BILINEAR.
- assert(op.getMode() == "BILINEAR");
+ assert(op.getMode() == ResizeType::BILINEAR);
auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
@@ -2291,7 +2295,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
Value predicate;
if (isa<FloatType>(inElementTy)) {
- if (argmaxOp.getNanMode() == "IGNORE") {
+ if (argmaxOp.getNanMode() == NanPropagation::IGNORE) {
// Only update index & max value for non NaN values. If all
// values are NaNs, the initial index will be return which is 0.
predicate = arith::CmpFOp::create(rewriter, nestedLoc,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 12d85ca3768dd..0f738a848fcb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -803,7 +803,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
dilationAttr);
rewriter.setInsertionPointAfter(op);
- StringRef nanMode = op.getNanMode();
+ NanPropagation nanMode = op.getNanMode();
rewriter.replaceOp(op, resultOp);
// NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// we've already produced a named op we will just take its body and modify
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
- if (nanMode == "IGNORE") {
+ if (nanMode == NanPropagation::IGNORE) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
rewriter, loc, rewriter.getI8IntegerAttr(30));
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
- auto scaled =
- tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
- shift, rewriter.getStringAttr("SINGLE_ROUND"))
- .getResult();
+ auto roundingAttr = RoundingTypeAttr::get(
+ rewriter.getContext(), RoundingType::SINGLE_ROUND);
+
+ auto scaled = tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal,
+ multiplier, shift, roundingAttr)
+ .getResult();
// If we have quantization information we need to apply output
// zeropoint.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..8e27b267c83d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -555,7 +555,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
// Check we have a valid NaN propagation combination.
const auto opNanMode = op.getNanMode();
const auto clampNanMode = clampOp.getNanMode();
- if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+ if (opNanMode == NanPropagation::IGNORE &&
+ clampNanMode == NanPropagation::PROPAGATE)
return failure();
auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +637,14 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
}
}
+ auto newMode =
+ (opNanMode != clampNanMode) ? tosa::NanPropagation::IGNORE : opNanMode;
+
+ auto newModeAttr = NanPropagationAttr::get(rewriter.getContext(), newMode);
+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
- rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
- : opNanMode));
+ newModeAttr);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534f9e744..5c04874e494c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,13 +508,13...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for working on this.
It would be better to update your commit message, as this change could be considered a breaking change for others. Could you please rewrite the commit message with more detailed context to support future development? For example, you might consider using a revised description from #152129
… enumerations. This PR replaces `StringBasedAttr` with `Tosa_I32EnumAttr` to represent Tosa enumerations as per the specification. The intent is to make the IR and C++ APIs more type-safe and prevent fragile string comparisons in passes. Enumerations rewritten are: - `Tosa_ResizeTypeAttr` - `Tosa_NanPropagationAttr` - `Tosa_RoundingTypeAttr` BREAKING CHANGE: This commit changes attribute assembly and the C++ API surface for the listed attributes. Code that previously used `StringAttr` for these fields must now be updated to use the new enum representation. In `.mlir` files, replace string literals with the enum assembly (e.g. `mode = #tosa.resize_type<BILINEAR>`). In C++, update call sites to either pass the generated enum (e.g. `::mlir::tosa::RoundingType::SINGLE_ROUND`) into builder overloads or construct the typed attribute with `tosa::RoundingTypeAttr::get(context, /*enum*/)` and pass that.
@psunn, thank you for the corrections. |
Thanks for the patch @AnnuCode! Overall the changes look great. Since this is a breaking change, we have a few dependencies which we would like to update at the same time so that we don't break builds. I'm working on this, but would like to wait till these changes are ready before merging this patch |
StringBasedAttr
TOSA enumerations to Tosa_I32EnumAttr
StringBasedAttr
to Tosa_I32EnumAttr
✅ With the latest revision this PR passed the C/C++ code formatter. |
Fixes #152129
Use
Tosa_I32EnumAttr
instead ofStringBasedAttr
to represent Tosa enumerations.This PR replaces
StringBasedAttr
withTosa_I32EnumAttr
to represent Tosa enumerations as per the specification. The intent is to make the IR and C++ APIs more type-safe and prevent fragile string comparisons in passes.Enumerations rewritten are:
Tosa_ResizeTypeAttr
Tosa_NanPropagationAttr
Tosa_RoundingTypeAttr
BREAKING CHANGE:
This commit changes attribute assembly and the C++ API surface for the listed attributes.
Code that previously used
StringAttr
for these fields must now be updated to use the new enum representation. In.mlir
files, replace string literals with the enum assembly (e.g.mode = #tosa.resize_type<BILINEAR>
). In C++, update call sites to either pass the generated enum (e.g.::mlir::tosa::RoundingType::SINGLE_ROUND
) into builder overloads or construct the typed attribute withtosa::RoundingTypeAttr::get(context, /*enum*/)
and pass that.