Skip to content

[mlir][tosa] Convert TOSA enumerations from StringBasedAttr to Tosa_I32EnumAttr #152856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}

//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
// Defined in `section 3. Enumerations` of the TOSA specification.

def Tosa_RESIZE_NEAREST_NEIGHBOR : I32EnumAttrCase<"NEAREST_NEIGHBOR", 1>;
def Tosa_RESIZE_BILINEAR : I32EnumAttrCase<"BILINEAR", 2>;

def Tosa_ResizeModeAttr
: Tosa_I32EnumAttr<"ResizeMode", "Supported resize/upsampling strategies", "resize_mode",
[Tosa_RESIZE_NEAREST_NEIGHBOR, Tosa_RESIZE_BILINEAR]>;

def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
def Tosa_NANPROPAGATION_IGNORE : I32EnumAttrCase<"IGNORE", 2>;

def Tosa_NanPropagationModeAttr
: Tosa_I32EnumAttr<"NanPropagationMode", "Supported NaN propagation strategies", "nan_mode",
[Tosa_NANPROPAGATION_PROPAGATE, Tosa_NANPROPAGATION_IGNORE]>;

def Tosa_ROUNDING_SINGLE_ROUND : I32EnumAttrCase<"SINGLE_ROUND", 1>;
def Tosa_ROUNDING_INEXACT_ROUND : I32EnumAttrCase<"INEXACT_ROUND", 2>;
def Tosa_ROUNDING_DOUBLE_ROUND : I32EnumAttrCase<"DOUBLE_ROUND", 3>;

def Tosa_RoundingModeAttr
: Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
[Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;


//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 9 additions & 9 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let arguments = (ins
Tosa_TensorAtLeast1D: $input,
I32Attr: $axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -357,7 +357,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -487,7 +487,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
Tosa_Tensor:$input,
Tosa_IntOrFloatAttr:$min_val,
Tosa_IntOrFloatAttr:$max_val,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -935,7 +935,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -964,7 +964,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -1711,7 +1711,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -1751,7 +1751,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);

let results = (outs
Expand Down Expand Up @@ -2224,7 +2224,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
Rank4TosaShape:$scale,
Rank2TosaShape:$offset,
Rank2TosaShape:$border,
Tosa_ResizeTypeAttr:$mode
Tosa_ResizeModeAttr:$mode
);

let results = (outs
Expand Down Expand Up @@ -2374,7 +2374,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$output_zp,
BoolAttr:$scale32,
Tosa_RoundingTypeAttr:$rounding_mode,
Tosa_RoundingModeAttr:$rounding_mode,
BoolAttr:$per_channel,
BoolAttr: $input_unsigned,
BoolAttr: $output_unsigned
Expand Down
23 changes: 0 additions & 23 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -234,29 +234,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,

def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;

//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
// Defined in `section 3. Enumerations` of the TOSA specification.

// Supported regimes for tosa.resize.
def Tosa_ResizeTypeAttr : StringBasedAttr<
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
"Supported resize/upsampling strategies">;

// Supported NaN propagation strategies.
def Tosa_NanPropagationAttr : StringBasedAttr<
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
"Supported NaN propagation strategies">;

// Rounding mode for tosa.rescale
def Tosa_RoundingTypeAttr : StringBasedAttr<
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
"Supported rounding modes">;

def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
Tosa_IntLike:$value,
Tosa_IntLike:$multiplier,
Tosa_Int8Like:$shift,
Tosa_RoundingTypeAttr:$rounding_mode
Tosa_RoundingModeAttr:$rounding_mode
);

let results = (outs
Expand Down
14 changes: 8 additions & 6 deletions mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ class ApplyScaleGenericOpConverter

LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
StringRef roundingMode = op.getRoundingMode();
if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
RoundingMode roundingMode = op.getRoundingMode();
if (roundingMode != RoundingMode::DOUBLE_ROUND &&
roundingMode != RoundingMode::SINGLE_ROUND) {
return failure();
}

Expand Down Expand Up @@ -100,7 +101,7 @@ class ApplyScaleGenericOpConverter
multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);

// Apply double rounding if necessary.
if (op.getRoundingMode() == "DOUBLE_ROUND") {
if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
Expand Down Expand Up @@ -129,8 +130,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {

LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
StringRef roundingMode = op.getRoundingMode();
if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
RoundingMode roundingMode = op.getRoundingMode();
if (roundingMode != RoundingMode::DOUBLE_ROUND &&
roundingMode != RoundingMode::SINGLE_ROUND) {
return failure();
}

Expand Down Expand Up @@ -179,7 +181,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);

// Conditionally perform our double round.
if (op.getRoundingMode() == "DOUBLE_ROUND") {
if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
Expand Down
41 changes: 22 additions & 19 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
return result;

auto nanMode = op.getNanMode();
if (nanMode == "PROPAGATE")
if (nanMode == NanPropagationMode::PROPAGATE)
return result;

// Unordered comparison of NaN against itself will always return true.
Expand Down Expand Up @@ -156,9 +156,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);

auto result = tosa::ApplyScaleOp::create(
rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
rewriter.getStringAttr("SINGLE_ROUND"));
auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
RoundingMode::SINGLE_ROUND);
auto result =
tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
b, shiftConst, roundingAttr);

if (elementTy.isInteger(32))
return result;
Expand Down Expand Up @@ -464,7 +466,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(

// In the case of "PROPAGATE" semantics no compare and selection is
// required.
if (nanMode == "PROPAGATE")
if (nanMode == NanPropagationMode::PROPAGATE)
return result;

// In the case of "IGNORE" semantics materialize a comparison
Expand Down Expand Up @@ -1173,7 +1175,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
// NaN propagation has no meaning for non floating point types.
if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
if (isa<FloatType>(elementTy) &&
op.getNanMode() == NanPropagationMode::IGNORE) {
isNanIgnoreMode = true;
// Because the TOSA spec requires the result be NaN iff all elements in
// the reduction are NaN we can't simply perform a compare and select.
Expand Down Expand Up @@ -1336,11 +1339,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned rank = inputTy.getRank();

// This is an illegal configuration. terminate and log an error
if (op.getRoundingMode() == "INEXACT_ROUND")
if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
return rewriter.notifyMatchFailure(
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
"currently supported");
if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");

Expand Down Expand Up @@ -1386,11 +1389,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// is ever true.

bool doubleRound =
op.getRoundingMode() == "DOUBLE_ROUND" &&
op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
StringAttr roundingMode = doubleRound
? rewriter.getStringAttr("DOUBLE_ROUND")
: rewriter.getStringAttr("SINGLE_ROUND");
RoundingMode roundingMode =
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;

SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
Expand Down Expand Up @@ -1573,7 +1575,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
auto input = op.getInput();
auto inputTy = cast<RankedTensorType>(input.getType());
auto resultTy = cast<RankedTensorType>(op.getType());
const bool isBilinear = op.getMode() == "BILINEAR";
const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

auto inputH = inputTy.getDimSize(1);
auto inputW = inputTy.getDimSize(2);
Expand All @@ -1584,8 +1586,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
return rewriter.notifyMatchFailure(
op, "tosa.resize is not a pure 1x1->1x1 image operation");

// TODO(suderman): These string values should be declared the TOSA dialect.
if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
op.getMode() != ResizeMode::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

Expand Down Expand Up @@ -1785,7 +1787,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
return rewriter.notifyMatchFailure(
op, "unable to get dynamic dimensions of tosa.resize");

if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
op.getMode() != ResizeMode::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

Expand Down Expand Up @@ -1890,7 +1893,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
}

if (op.getMode() == "NEAREST_NEIGHBOR") {
if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
Expand Down Expand Up @@ -1926,7 +1929,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
linalg::YieldOp::create(b, result);
} else {
// The mode here must be BILINEAR.
assert(op.getMode() == "BILINEAR");
assert(op.getMode() == ResizeMode::BILINEAR);

auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

Expand Down Expand Up @@ -2291,7 +2294,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {

Value predicate;
if (isa<FloatType>(inElementTy)) {
if (argmaxOp.getNanMode() == "IGNORE") {
if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
// Only update index & max value for non NaN values. If all
// values are NaNs, the initial index will be return which is 0.
predicate = arith::CmpFOp::create(rewriter, nestedLoc,
Expand Down
16 changes: 9 additions & 7 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
dilationAttr);

rewriter.setInsertionPointAfter(op);
StringRef nanMode = op.getNanMode();
NanPropagationMode nanMode = op.getNanMode();
rewriter.replaceOp(op, resultOp);

// NaN propagation has no meaning for non floating point types.
Expand All @@ -817,7 +817,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// we've already produced a named op we will just take its body and modify
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
if (nanMode == "IGNORE") {
if (nanMode == NanPropagationMode::IGNORE) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
Expand Down Expand Up @@ -1040,11 +1040,13 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
rewriter, loc, rewriter.getI8IntegerAttr(30));
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);

auto scaled =
tosa::ApplyScaleOp::create(
rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
shift, rewriter.getStringAttr("SINGLE_ROUND"))
.getResult();
auto roundingAttr = RoundingModeAttr::get(
rewriter.getContext(), RoundingMode::SINGLE_ROUND);

auto scaled = tosa::ApplyScaleOp::create(
rewriter, loc, rewriter.getI32Type(), poolVal,
multiplier, shift, roundingAttr)
.getResult();

// If we have quantization information we need to apply output
// zeropoint.
Expand Down
13 changes: 10 additions & 3 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
// Check we have a valid NaN propagation combination.
const auto opNanMode = op.getNanMode();
const auto clampNanMode = clampOp.getNanMode();
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
if (opNanMode == NanPropagationMode::IGNORE &&
clampNanMode == NanPropagationMode::PROPAGATE)
return failure();

auto maxValAttr = op.getMaxValAttr();
Expand Down Expand Up @@ -636,10 +637,16 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
}
}

auto newMode = (opNanMode != clampNanMode)
? tosa::NanPropagationMode::IGNORE
: opNanMode;

auto newModeAttr =
NanPropagationModeAttr::get(rewriter.getContext(), newMode);

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
: opNanMode));
newModeAttr);
return success();
}
};
Expand Down
Loading