Skip to content

[mlir][tosa] Convert TOSA enumerations from StringBasedAttr to Tosa_I32EnumAttr #152856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

AnnuCode
Copy link
Contributor

@AnnuCode AnnuCode commented Aug 9, 2025

Fixes #152129

Use Tosa_I32EnumAttr instead of StringBasedAttr to represent Tosa enumerations.

This PR replaces StringBasedAttr with Tosa_I32EnumAttr to represent Tosa enumerations as per the specification. The intent is to make the IR and C++ APIs more type-safe and prevent fragile string comparisons in passes.

Enumerations rewritten are:

  • Tosa_ResizeTypeAttr
  • Tosa_NanPropagationAttr
  • Tosa_RoundingTypeAttr

BREAKING CHANGE:

This commit changes attribute assembly and the C++ API surface for the listed attributes.

Code that previously used StringAttr for these fields must now be updated to use the new enum representation. In .mlir files, replace string literals with the enum assembly (e.g. mode = #tosa.resize_type<BILINEAR>). In C++, update call sites to either pass the generated enum (e.g. ::mlir::tosa::RoundingType::SINGLE_ROUND) into builder overloads or construct the typed attribute with tosa::RoundingTypeAttr::get(context, /*enum*/) and pass that.

@llvmbot
Copy link
Member

llvmbot commented Aug 9, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

Author: Annu Singh (AnnuCode)

Changes

Fixes #152129


Patch is 157.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152856.diff

29 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+30-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+7-7)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+2-23)
  • (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+8-6)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-18)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+9-7)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+8-3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+3-3)
  • (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir (+1-1)
  • (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir (+4-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+1-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+4-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+16-16)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+42-42)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+14-14)
  • (modified) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+15-15)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+28-28)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+7-7)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+10-10)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-valid.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+1-1)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+3-1)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..2aafed26a4e29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -207,7 +207,7 @@ def Tosa_VariableOpBuilder : OpBuilder<
   }]>;
 
 
-// Wrapper over base I32EnumAttr to set common fields.
+ // Wrapper over base I32EnumAttr to set common fields.
 class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
      : I32EnumAttr<name, description, cases> {
    let genSpecializedAttr = 0;
@@ -276,6 +276,7 @@ def Tosa_ProfileAttr
 def Tosa_ProfileArrayAttr
     : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
 
+
 // The base class for defining op availability dimensions.
 class Availability {
   // The following are fields for controlling the generated C++ OpInterface.
@@ -381,6 +382,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
   let instance = "ref";
 }
 
+//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
+def Tosa_RESIZE_BILINEAR          : I32EnumAttrCase<"BILINEAR", 1>;
+def Tosa_RESIZE_NEAREST_NEIGHBOR  : I32EnumAttrCase<"NEAREST_NEIGHBOR", 2>;
+
+def Tosa_ResizeTypeAttr
+    : Tosa_I32EnumAttr<"ResizeType", "Supported resize/upsampling strategies", "resize_type",
+                    [Tosa_RESIZE_BILINEAR, Tosa_RESIZE_NEAREST_NEIGHBOR]>;       
+
+def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
+def Tosa_NANPROPAGATION_IGNORE    : I32EnumAttrCase<"IGNORE", 2>;
+
+def Tosa_NanPropagationAttr
+    : Tosa_I32EnumAttr<"NanPropagation", "Supported NaN propagation strategies", "nan",
+                    [Tosa_NANPROPAGATION_PROPAGATE, Tosa_NANPROPAGATION_IGNORE]>;  
+
+def Tosa_ROUNDING_SINGLE_ROUND    : I32EnumAttrCase<"SINGLE_ROUND", 1>;
+def Tosa_ROUNDING_INEXACT_ROUND   : I32EnumAttrCase<"INEXACT_ROUND", 2>;
+def Tosa_ROUNDING_DOUBLE_ROUND    : I32EnumAttrCase<"DOUBLE_ROUND", 3>;
+
+def Tosa_RoundingTypeAttr
+    : Tosa_I32EnumAttr<"RoundingType", "Supported rounding modes", "rounding_type",
+                    [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;                         
+
+
 //===----------------------------------------------------------------------===//
 // TOSA Interfaces.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..fdb8f472dc060 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D: $input,
     I32Attr: $axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -357,7 +357,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -487,7 +487,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
     Tosa_Tensor:$input,
     Tosa_IntOrFloatAttr:$min_val,
     Tosa_IntOrFloatAttr:$max_val,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -935,7 +935,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -964,7 +964,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -1711,7 +1711,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -1751,7 +1751,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 754640dca6561..b21ce51eb03b1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -13,11 +13,13 @@
 #ifndef TOSA_TYPES_BASE
 #define TOSA_TYPES_BASE
 
+
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
 
 include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
 
+
 //===----------------------------------------------------------------------===//
 // Tosa Type Definitions.
 //===----------------------------------------------------------------------===//
@@ -234,29 +236,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
 
 def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
 
-//===----------------------------------------------------------------------===//
-// Iterable attributes.
-//===----------------------------------------------------------------------===//
-// Defined in `section 3. Enumerations` of the TOSA specification.
-
-// Supported regimes for tosa.resize.
-def Tosa_ResizeTypeAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
-    "Supported resize/upsampling strategies">;
-
-// Supported NaN propagation strategies.
-def Tosa_NanPropagationAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
-    "Supported NaN propagation strategies">;
-
-// Rounding mode for tosa.rescale
-def Tosa_RoundingTypeAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
-    "Supported rounding modes">;
 
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 044b725c7d805..4a027ccdadd61 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -64,8 +64,9 @@ class ApplyScaleGenericOpConverter
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
-    StringRef roundingMode = op.getRoundingMode();
-    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+    RoundingType roundingMode = op.getRoundingMode();
+    if (roundingMode != RoundingType::DOUBLE_ROUND &&
+        roundingMode != RoundingType::SINGLE_ROUND) {
       return failure();
     }
 
@@ -100,7 +101,7 @@ class ApplyScaleGenericOpConverter
     multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
 
     // Apply double rounding if necessary.
-    if (op.getRoundingMode() == "DOUBLE_ROUND") {
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
       int64_t roundInt = 1 << 30;
       Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
       Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
-    StringRef roundingMode = op.getRoundingMode();
-    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+    RoundingType roundingMode = op.getRoundingMode();
+    if (roundingMode != RoundingType::DOUBLE_ROUND &&
+        roundingMode != RoundingType::SINGLE_ROUND) {
       return failure();
     }
 
@@ -179,7 +181,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
         arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
 
     // Conditionally perform our double round.
-    if (op.getRoundingMode() == "DOUBLE_ROUND") {
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
       Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
       Value valuePositive = arith::CmpIOp::create(
           rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..35deeb43f51c2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
     return result;
 
   auto nanMode = op.getNanMode();
-  if (nanMode == "PROPAGATE")
+  if (nanMode == NanPropagation::PROPAGATE)
     return result;
 
   // Unordered comparison of NaN against itself will always return true.
@@ -156,9 +156,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         if (!b.getType().isInteger(32))
           b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
 
-        auto result = tosa::ApplyScaleOp::create(
-            rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
-            rewriter.getStringAttr("SINGLE_ROUND"));
+        auto roundingAttr = RoundingTypeAttr::get(rewriter.getContext(),
+                                                  RoundingType::SINGLE_ROUND);
+        auto result =
+            tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
+                                       b, shiftConst, roundingAttr);
 
         if (elementTy.isInteger(32))
           return result;
@@ -464,7 +466,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
     // In the case of "PROPAGATE" semantics no compare and selection is
     // required.
-    if (nanMode == "PROPAGATE")
+    if (nanMode == NanPropagation::PROPAGATE)
       return result;
 
     // In the case of "IGNORE" semantics materialize a comparison
@@ -1173,7 +1175,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
     // NaN propagation has no meaning for non floating point types.
-    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
+    if (isa<FloatType>(elementTy) &&
+        op.getNanMode() == NanPropagation::IGNORE) {
       isNanIgnoreMode = true;
       // Because the TOSA spec requires the result be NaN iff all elements in
       // the reduction are NaN we can't simply perform a compare and select.
@@ -1336,11 +1339,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     unsigned rank = inputTy.getRank();
 
     // This is an illegal configuration. terminate and log an error
-    if (op.getRoundingMode() == "INEXACT_ROUND")
+    if (op.getRoundingMode() == RoundingType::INEXACT_ROUND)
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
               "currently supported");
-    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND && !op.getScale32())
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
@@ -1386,11 +1389,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // is ever true.
 
     bool doubleRound =
-        op.getRoundingMode() == "DOUBLE_ROUND" &&
+        op.getRoundingMode() == RoundingType::DOUBLE_ROUND &&
         llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
-    StringAttr roundingMode = doubleRound
-                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
-                                  : rewriter.getStringAttr("SINGLE_ROUND");
+    RoundingType roundingMode =
+        doubleRound ? RoundingType::DOUBLE_ROUND : RoundingType::SINGLE_ROUND;
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1573,7 +1575,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
     auto input = op.getInput();
     auto inputTy = cast<RankedTensorType>(input.getType());
     auto resultTy = cast<RankedTensorType>(op.getType());
-    const bool isBilinear = op.getMode() == "BILINEAR";
+    const bool isBilinear = op.getMode() == ResizeType::BILINEAR;
 
     auto inputH = inputTy.getDimSize(1);
     auto inputW = inputTy.getDimSize(2);
@@ -1585,7 +1587,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
           op, "tosa.resize is not a pure 1x1->1x1 image operation");
 
     // TODO(suderman): These string values should be declared the TOSA dialect.
-    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+    if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+        op.getMode() != ResizeType::BILINEAR)
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
@@ -1785,7 +1788,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       return rewriter.notifyMatchFailure(
           op, "unable to get dynamic dimensions of tosa.resize");
 
-    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+    if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+        op.getMode() != ResizeType::BILINEAR)
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
@@ -1890,7 +1894,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
         getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
       }
 
-      if (op.getMode() == "NEAREST_NEIGHBOR") {
+      if (op.getMode() == ResizeType::NEAREST_NEIGHBOR) {
         auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
 
         auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1926,7 +1930,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
         linalg::YieldOp::create(b, result);
       } else {
         // The mode here must be BILINEAR.
-        assert(op.getMode() == "BILINEAR");
+        assert(op.getMode() == ResizeType::BILINEAR);
 
         auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
 
@@ -2291,7 +2295,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
 
           Value predicate;
           if (isa<FloatType>(inElementTy)) {
-            if (argmaxOp.getNanMode() == "IGNORE") {
+            if (argmaxOp.getNanMode() == NanPropagation::IGNORE) {
               // Only update index & max value for non NaN values. If all
               // values are NaNs, the initial index will be return which is 0.
               predicate = arith::CmpFOp::create(rewriter, nestedLoc,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 12d85ca3768dd..0f738a848fcb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -803,7 +803,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
         dilationAttr);
 
     rewriter.setInsertionPointAfter(op);
-    StringRef nanMode = op.getNanMode();
+    NanPropagation nanMode = op.getNanMode();
     rewriter.replaceOp(op, resultOp);
 
     // NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
     // we've already produced a named op we will just take its body and modify
     // it to include the appropriate checks. If the current value is NaN the
     // old value of pool will be taken otherwise we use the result.
-    if (nanMode == "IGNORE") {
+    if (nanMode == NanPropagation::IGNORE) {
       auto genericOp = linalg::GenericOp::create(
           rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
           resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
                 rewriter, loc, rewriter.getI8IntegerAttr(30));
             Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
 
-            auto scaled =
-                tosa::ApplyScaleOp::create(
-                    rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
-                    shift, rewriter.getStringAttr("SINGLE_ROUND"))
-                    .getResult();
+            auto roundingAttr = RoundingTypeAttr::get(
+                rewriter.getContext(), RoundingType::SINGLE_ROUND);
+
+            auto scaled = tosa::ApplyScaleOp::create(
+                              rewriter, loc, rewriter.getI32Type(), poolVal,
+                              multiplier, shift, roundingAttr)
+                              .getResult();
 
             // If we have quantization information we need to apply output
             // zeropoint.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..8e27b267c83d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -555,7 +555,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
     // Check we have a valid NaN propagation combination.
     const auto opNanMode = op.getNanMode();
     const auto clampNanMode = clampOp.getNanMode();
-    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+    if (opNanMode == NanPropagation::IGNORE &&
+        clampNanMode == NanPropagation::PROPAGATE)
       return failure();
 
     auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +637,14 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
       }
     }
 
+    auto newMode =
+        (opNanMode != clampNanMode) ? tosa::NanPropagation::IGNORE : opNanMode;
+
+    auto newModeAttr = NanPropagationAttr::get(rewriter.getContext(), newMode);
+
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
         op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
-        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
-                                                           : opNanMode));
+        newModeAttr);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534f9e744..5c04874e494c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,13 +508,13...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 9, 2025

@llvm/pr-subscribers-mlir

Author: Annu Singh (AnnuCode)

Changes

Fixes #152129


Patch is 157.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152856.diff

29 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+30-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+7-7)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+2-23)
  • (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+8-6)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-18)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+9-7)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+8-3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+3-3)
  • (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir (+1-1)
  • (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir (+4-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+1-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+4-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+16-16)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+42-42)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+14-14)
  • (modified) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+15-15)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+28-28)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+7-7)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+10-10)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-valid.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+1-1)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+3-1)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..2aafed26a4e29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -207,7 +207,7 @@ def Tosa_VariableOpBuilder : OpBuilder<
   }]>;
 
 
-// Wrapper over base I32EnumAttr to set common fields.
+ // Wrapper over base I32EnumAttr to set common fields.
 class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
      : I32EnumAttr<name, description, cases> {
    let genSpecializedAttr = 0;
@@ -276,6 +276,7 @@ def Tosa_ProfileAttr
 def Tosa_ProfileArrayAttr
     : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
 
+
 // The base class for defining op availability dimensions.
 class Availability {
   // The following are fields for controlling the generated C++ OpInterface.
@@ -381,6 +382,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
   let instance = "ref";
 }
 
+//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
+def Tosa_RESIZE_BILINEAR          : I32EnumAttrCase<"BILINEAR", 1>;
+def Tosa_RESIZE_NEAREST_NEIGHBOR  : I32EnumAttrCase<"NEAREST_NEIGHBOR", 2>;
+
+def Tosa_ResizeTypeAttr
+    : Tosa_I32EnumAttr<"ResizeType", "Supported resize/upsampling strategies", "resize_type",
+                    [Tosa_RESIZE_BILINEAR, Tosa_RESIZE_NEAREST_NEIGHBOR]>;       
+
+def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
+def Tosa_NANPROPAGATION_IGNORE    : I32EnumAttrCase<"IGNORE", 2>;
+
+def Tosa_NanPropagationAttr
+    : Tosa_I32EnumAttr<"NanPropagation", "Supported NaN propagation strategies", "nan",
+                    [Tosa_NANPROPAGATION_PROPAGATE, Tosa_NANPROPAGATION_IGNORE]>;  
+
+def Tosa_ROUNDING_SINGLE_ROUND    : I32EnumAttrCase<"SINGLE_ROUND", 1>;
+def Tosa_ROUNDING_INEXACT_ROUND   : I32EnumAttrCase<"INEXACT_ROUND", 2>;
+def Tosa_ROUNDING_DOUBLE_ROUND    : I32EnumAttrCase<"DOUBLE_ROUND", 3>;
+
+def Tosa_RoundingTypeAttr
+    : Tosa_I32EnumAttr<"RoundingType", "Supported rounding modes", "rounding_type",
+                    [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;                         
+
+
 //===----------------------------------------------------------------------===//
 // TOSA Interfaces.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..fdb8f472dc060 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D: $input,
     I32Attr: $axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -357,7 +357,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -487,7 +487,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
     Tosa_Tensor:$input,
     Tosa_IntOrFloatAttr:$min_val,
     Tosa_IntOrFloatAttr:$max_val,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -935,7 +935,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -964,7 +964,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -1711,7 +1711,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -1751,7 +1751,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 754640dca6561..b21ce51eb03b1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -13,11 +13,13 @@
 #ifndef TOSA_TYPES_BASE
 #define TOSA_TYPES_BASE
 
+
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
 
 include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
 
+
 //===----------------------------------------------------------------------===//
 // Tosa Type Definitions.
 //===----------------------------------------------------------------------===//
@@ -234,29 +236,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
 
 def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
 
-//===----------------------------------------------------------------------===//
-// Iterable attributes.
-//===----------------------------------------------------------------------===//
-// Defined in `section 3. Enumerations` of the TOSA specification.
-
-// Supported regimes for tosa.resize.
-def Tosa_ResizeTypeAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
-    "Supported resize/upsampling strategies">;
-
-// Supported NaN propagation strategies.
-def Tosa_NanPropagationAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
-    "Supported NaN propagation strategies">;
-
-// Rounding mode for tosa.rescale
-def Tosa_RoundingTypeAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
-    "Supported rounding modes">;
 
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 044b725c7d805..4a027ccdadd61 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -64,8 +64,9 @@ class ApplyScaleGenericOpConverter
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
-    StringRef roundingMode = op.getRoundingMode();
-    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+    RoundingType roundingMode = op.getRoundingMode();
+    if (roundingMode != RoundingType::DOUBLE_ROUND &&
+        roundingMode != RoundingType::SINGLE_ROUND) {
       return failure();
     }
 
@@ -100,7 +101,7 @@ class ApplyScaleGenericOpConverter
     multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
 
     // Apply double rounding if necessary.
-    if (op.getRoundingMode() == "DOUBLE_ROUND") {
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
       int64_t roundInt = 1 << 30;
       Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
       Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
-    StringRef roundingMode = op.getRoundingMode();
-    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+    RoundingType roundingMode = op.getRoundingMode();
+    if (roundingMode != RoundingType::DOUBLE_ROUND &&
+        roundingMode != RoundingType::SINGLE_ROUND) {
       return failure();
     }
 
@@ -179,7 +181,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
         arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
 
     // Conditionally perform our double round.
-    if (op.getRoundingMode() == "DOUBLE_ROUND") {
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
       Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
       Value valuePositive = arith::CmpIOp::create(
           rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..35deeb43f51c2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
     return result;
 
   auto nanMode = op.getNanMode();
-  if (nanMode == "PROPAGATE")
+  if (nanMode == NanPropagation::PROPAGATE)
     return result;
 
   // Unordered comparison of NaN against itself will always return true.
@@ -156,9 +156,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         if (!b.getType().isInteger(32))
           b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
 
-        auto result = tosa::ApplyScaleOp::create(
-            rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
-            rewriter.getStringAttr("SINGLE_ROUND"));
+        auto roundingAttr = RoundingTypeAttr::get(rewriter.getContext(),
+                                                  RoundingType::SINGLE_ROUND);
+        auto result =
+            tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
+                                       b, shiftConst, roundingAttr);
 
         if (elementTy.isInteger(32))
           return result;
@@ -464,7 +466,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
     // In the case of "PROPAGATE" semantics no compare and selection is
     // required.
-    if (nanMode == "PROPAGATE")
+    if (nanMode == NanPropagation::PROPAGATE)
       return result;
 
     // In the case of "IGNORE" semantics materialize a comparison
@@ -1173,7 +1175,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
     // NaN propagation has no meaning for non floating point types.
-    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
+    if (isa<FloatType>(elementTy) &&
+        op.getNanMode() == NanPropagation::IGNORE) {
       isNanIgnoreMode = true;
       // Because the TOSA spec requires the result be NaN iff all elements in
       // the reduction are NaN we can't simply perform a compare and select.
@@ -1336,11 +1339,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     unsigned rank = inputTy.getRank();
 
     // This is an illegal configuration. terminate and log an error
-    if (op.getRoundingMode() == "INEXACT_ROUND")
+    if (op.getRoundingMode() == RoundingType::INEXACT_ROUND)
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
               "currently supported");
-    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND && !op.getScale32())
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
@@ -1386,11 +1389,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // is ever true.
 
     bool doubleRound =
-        op.getRoundingMode() == "DOUBLE_ROUND" &&
+        op.getRoundingMode() == RoundingType::DOUBLE_ROUND &&
         llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
-    StringAttr roundingMode = doubleRound
-                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
-                                  : rewriter.getStringAttr("SINGLE_ROUND");
+    RoundingType roundingMode =
+        doubleRound ? RoundingType::DOUBLE_ROUND : RoundingType::SINGLE_ROUND;
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1573,7 +1575,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
     auto input = op.getInput();
     auto inputTy = cast<RankedTensorType>(input.getType());
     auto resultTy = cast<RankedTensorType>(op.getType());
-    const bool isBilinear = op.getMode() == "BILINEAR";
+    const bool isBilinear = op.getMode() == ResizeType::BILINEAR;
 
     auto inputH = inputTy.getDimSize(1);
     auto inputW = inputTy.getDimSize(2);
@@ -1585,7 +1587,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
           op, "tosa.resize is not a pure 1x1->1x1 image operation");
 
     // TODO(suderman): These string values should be declared the TOSA dialect.
-    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+    if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+        op.getMode() != ResizeType::BILINEAR)
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
@@ -1785,7 +1788,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       return rewriter.notifyMatchFailure(
           op, "unable to get dynamic dimensions of tosa.resize");
 
-    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+    if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+        op.getMode() != ResizeType::BILINEAR)
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
@@ -1890,7 +1894,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
         getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
       }
 
-      if (op.getMode() == "NEAREST_NEIGHBOR") {
+      if (op.getMode() == ResizeType::NEAREST_NEIGHBOR) {
         auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
 
         auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1926,7 +1930,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
         linalg::YieldOp::create(b, result);
       } else {
         // The mode here must be BILINEAR.
-        assert(op.getMode() == "BILINEAR");
+        assert(op.getMode() == ResizeType::BILINEAR);
 
         auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
 
@@ -2291,7 +2295,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
 
           Value predicate;
           if (isa<FloatType>(inElementTy)) {
-            if (argmaxOp.getNanMode() == "IGNORE") {
+            if (argmaxOp.getNanMode() == NanPropagation::IGNORE) {
               // Only update index & max value for non NaN values. If all
               // values are NaNs, the initial index will be return which is 0.
               predicate = arith::CmpFOp::create(rewriter, nestedLoc,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 12d85ca3768dd..0f738a848fcb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -803,7 +803,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
         dilationAttr);
 
     rewriter.setInsertionPointAfter(op);
-    StringRef nanMode = op.getNanMode();
+    NanPropagation nanMode = op.getNanMode();
     rewriter.replaceOp(op, resultOp);
 
     // NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
     // we've already produced a named op we will just take its body and modify
     // it to include the appropriate checks. If the current value is NaN the
     // old value of pool will be taken otherwise we use the result.
-    if (nanMode == "IGNORE") {
+    if (nanMode == NanPropagation::IGNORE) {
       auto genericOp = linalg::GenericOp::create(
           rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
           resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
                 rewriter, loc, rewriter.getI8IntegerAttr(30));
             Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
 
-            auto scaled =
-                tosa::ApplyScaleOp::create(
-                    rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
-                    shift, rewriter.getStringAttr("SINGLE_ROUND"))
-                    .getResult();
+            auto roundingAttr = RoundingTypeAttr::get(
+                rewriter.getContext(), RoundingType::SINGLE_ROUND);
+
+            auto scaled = tosa::ApplyScaleOp::create(
+                              rewriter, loc, rewriter.getI32Type(), poolVal,
+                              multiplier, shift, roundingAttr)
+                              .getResult();
 
             // If we have quantization information we need to apply output
             // zeropoint.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..8e27b267c83d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -555,7 +555,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
     // Check we have a valid NaN propagation combination.
     const auto opNanMode = op.getNanMode();
     const auto clampNanMode = clampOp.getNanMode();
-    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+    if (opNanMode == NanPropagation::IGNORE &&
+        clampNanMode == NanPropagation::PROPAGATE)
       return failure();
 
     auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +637,14 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
       }
     }
 
+    auto newMode =
+        (opNanMode != clampNanMode) ? tosa::NanPropagation::IGNORE : opNanMode;
+
+    auto newModeAttr = NanPropagationAttr::get(rewriter.getContext(), newMode);
+
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
         op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
-        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
-                                                           : opNanMode));
+        newModeAttr);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534f9e744..5c04874e494c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,13 +508,13...
[truncated]

Copy link
Contributor

@psunn psunn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for working on this.

It would be better to update your commit message, as this change could be considered a breaking change for others. Could you please rewrite the commit message with more detailed context to support future development? For example, you might consider using a revised description from #152129

psunn

This comment was marked as duplicate.

@psunn psunn requested a review from lhutton1 August 11, 2025 21:54
… enumerations.

This PR replaces `StringBasedAttr` with `Tosa_I32EnumAttr` to represent Tosa enumerations as per the specification.
The intent is to make the IR and C++ APIs more type-safe and prevent fragile string comparisons in passes.

Enumerations rewritten are:

- `Tosa_ResizeTypeAttr`
- `Tosa_NanPropagationAttr`
- `Tosa_RoundingTypeAttr`

BREAKING CHANGE:

This commit changes attribute assembly and the C++ API surface for the listed attributes.
Code that previously used `StringAttr` for these fields must now be updated to use the new enum representation.
In `.mlir` files, replace string literals with the enum assembly (e.g. `mode = #tosa.resize_type<BILINEAR>`).
In C++, update call sites to either pass the generated enum (e.g. `::mlir::tosa::RoundingType::SINGLE_ROUND`) into builder overloads or construct the typed attribute with `tosa::RoundingTypeAttr::get(context, /*enum*/)` and pass that.
@AnnuCode
Copy link
Contributor Author

@psunn, thank you for the corrections.

@lhutton1
Copy link
Contributor

Thanks for the patch @AnnuCode! Overall the changes look great. Since this is a breaking change, we have a few dependencies which we would like to update at the same time so that we don't break builds. I'm working on this, but would like to wait till these changes are ready before merging this patch

@joker-eph joker-eph changed the title [mlir][tosa] StringBasedAttr TOSA enumerations to Tosa_I32EnumAttr [mlir][tosa] Convert TOSA enumerations from StringBasedAttr to Tosa_I32EnumAttr Aug 13, 2025
Copy link

github-actions bot commented Aug 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][tosa] Replace StringBasedAttr TOSA enumerations with Tosa_I32EnumAttr
4 participants