-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][tosa] Add the concept of a TOSA target environment #153771
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This commit introduces a new module-level attribute `tosa.target_env`. IT encapsulates target information for use during compilation such as: level, profiles and extensions. For example: ```mlir module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int], extensions = [int16, int4]>} { <my-tosa-program> } ``` Previously the validation pass accepted target infomation as a series of command line pass options. This commit changes the behaviour to query the attached target environment from the module attribute. This refactoring allows other passes to query the same target information. A new target environment can be atached using the `--tosa-attach-target` pass, which takes the same command line options as the previous validation pass arguments. For example: ```bash mlir-opt --tosa-attach-target="profiles=pro_int extensions=int4,int16 level=none" test.mlir ``` Change-Id: I74a254855f6320dc70b29ae3509997764e3e5d95
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Luke Hutton (lhutton1) ChangesThis commit introduces a new module-level attribute module attributes {tosa.target_env =
#tosa.target_env<level = none, profiles = [pro_int], extensions = [int16, int4]>} {
<my-tosa-program>
} Previously the validation pass accepted target information as a series of command line pass options. This commit changes the behaviour to query the attached target environment from the module attribute. This refactoring allows other passes to query the same target information. A new target environment can be attached using the mlir-opt --tosa-attach-target="profiles=pro_int extensions=int4,int16 level=none" test.mlir Patch is 37.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153771.diff 22 Files Affected:
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index f4823858e3893..ab9b9f24ef3dd 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -39,8 +39,7 @@ void addTosaToLinalgPasses(
TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
- tosa::TosaValidationOptions{
- {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
+ tosa::TosaValidationOptions{false, false});
/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 9ee5079559d2b..10491f65d37af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -20,24 +20,67 @@
namespace mlir {
namespace tosa {
+struct TosaLevel {
+ int32_t MAX_RANK = 0;
+ int32_t MAX_KERNEL = 0;
+ int32_t MAX_STRIDE = 0;
+ int32_t MAX_SCALE = 0;
+ int32_t MAX_LOG2_SIZE = 0;
+ int32_t MAX_NESTING = 0;
+ int32_t MAX_TENSOR_LIST_SIZE = 0;
+
+ bool operator==(const TosaLevel &rhs) {
+ return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
+ MAX_NESTING == rhs.MAX_NESTING &&
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+ }
+};
+
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
+ 63, 256, 256};
+
+TargetEnvAttr lookupTargetEnv(Operation *op);
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
+
+/// Queries the target environment recursively from enclosing symbol table ops
+/// containing the given `op` or returns the default target environment as
+/// returned by getDefaultTargetEnv() if not provided.
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+
/// This class represents the capability enabled in the target implementation
-/// such as profile, extension, and level.
+/// such as profile, extension, and level. It's a wrapper class around
+/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
- const SmallVectorImpl<Extension> &extensions) {
+ explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ const ArrayRef<Extension> &extensions)
+ : level(level) {
enabledProfiles.insert_range(profiles);
-
enabledExtensions.insert_range(extensions);
}
+ explicit TargetEnv(TargetEnvAttr targetAttr)
+ : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
+ targetAttr.getExtensions()) {}
+
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
// TODO implement the following utilities.
// Version getSpecVersion() const;
- // TosaLevel getLevel() const;
+
+ TosaLevel getLevel() const {
+ if (level == Level::eightK)
+ return TOSA_LEVEL_EIGHTK;
+ else if (level == Level::none)
+ return TOSA_LEVEL_NONE;
+ else
+ llvm_unreachable("Unknown TOSA level");
+ };
// Returns true if the given profile is allowed.
bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
@@ -62,8 +105,9 @@ class TargetEnv {
}
private:
+ Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
- llvm::SmallSet<Extension, 8> enabledExtensions;
+ llvm::SmallSet<Extension, 13> enabledExtensions;
};
} // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..9e9ff21db1fa1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
+def Tosa_ProfileAttr
+ : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
+ [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Profile, 2> getAllValues() {
+ return {Profile::pro_int, Profile::pro_fp};
+ }
+ }];
+}
+
+def Tosa_ProfileArrayAttr
+ : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
@@ -264,17 +277,27 @@ def Tosa_ExtensionAttr
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
Tosa_EXT_DYNAMIC
- ]>;
+ ]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Extension, 11> getAllValues() {
+ return {
+ Extension::int16, Extension::int4, Extension::bf16,
+ Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
+ Extension::variable, Extension::controlflow, Extension::doubleround,
+ Extension::inexactround, Extension::dynamic
+ };
+ }
+ }];
+}
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_ProfileAttr
- : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
- [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>;
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-def Tosa_ProfileArrayAttr
- : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
// The base class for defining op availability dimensions.
class Availability {
@@ -381,6 +404,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
+//===----------------------------------------------------------------------===//
+// TOSA target environment.
+//===----------------------------------------------------------------------===//
+def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
+ let summary = "Target environment information.";
+ let parameters = ( ins
+ "Level": $level,
+ ArrayRefParameter<"Profile">: $profiles,
+ ArrayRefParameter<"Extension">: $extensions
+ );
+
+ let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ "`extensions` `=` `[` $extensions `]` `>`";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
index d4e2661838314..b1363b5a179df 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,7 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
-mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
-mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTosaPassIncGen)
add_dependencies(mlir-headers MLIRTosaPassIncGen)
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 306e4b1f218e7..ba99d2f1d2727 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index b96682843538c..6ae19d81e0820 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass
}];
}
-def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
- [
- I32EnumAttrCase<"None", 0, "none">,
- I32EnumAttrCase<"EightK", 1, "8k">,
- ]>{
- let cppNamespace = "mlir::tosa";
-}
-
def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
@@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
}];
let options = [
- ListOption<"profile", "profile", "std::string",
- "Validate if operations match for the given profile set">,
- ListOption<"extension", "extension", "std::string",
- "Validate if operations match for the given extension set">,
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
@@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
/*default=*/"false",
"Disable checks for operations that are determined to be invalid due to their "
"operand/result datatypes not aligning with the 'Supported Data Types' "
- "sections of the specifciation">,
- Option<"level", "level", "mlir::tosa::TosaLevelEnum",
- /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
- "Validate if operator parameters are within specfication for the given level",
- [{::llvm::cl::values(
- clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
- "Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
- clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
- "Allows the full range of arguments specified by the operations according "
- "to the operation data types.")
- )}]>
+ "sections of the specifciation">
];
}
@@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle
}];
}
+def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
+ let summary = "Attach tosa.target_env information to the given module.";
+
+ let description = [{
+ This pass allows the user to specify a TOSA target environment consisting of
+ the following components: level, profiles and extensions.
+
+ The target environment is attached to the module as an attribute, allowing other
+ transformations to query the selected target and adapt their behaviour based on
+ this information.
+ }];
+
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tosa::TosaDialect",
+ ];
+
+ let options = [
+ Option<"level", "level", "mlir::tosa::Level",
+ /*default=*/"mlir::tosa::Level::eightK",
+ "The TOSA level that operators should conform to. A TOSA level defines "
+ "operator argument ranges that an implementation shall support.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::Level::eightK, "8k",
+ "Ranges are expected to be sufficient for applications with frame "
+ "sizes up to 8K."),
+ clEnumValN(mlir::tosa::Level::none, "none",
+ "Allows the full range of arguments specified by the operations according "
+ "to the operation data types.")
+ )}]>,
+ ListOption<"profiles", "profiles", "std::string",
+ "The TOSA profile(s) that operators should conform to. TOSA profiles "
+ "enable efficient implementation on different classes of device. Each "
+ "profile is an independent set of operations and data type combinations.">,
+ ListOption<"extensions", "extensions", "std::string",
+ "The TOSA extension(s) that operators should conform to. TOSA profile "
+ "extensions define optional operation and data type combinations.">
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index c6a3ba9f1439f..e7602b4508cf1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
TosaToLinalgOptions tosaToLinalgOptions;
TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
TosaValidationOptions validationOptions;
- validationOptions.profile = {"none"};
- validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
validationOptions.allowInvalidOpDatatypeCombinations = false;
- validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
validationOptions);
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index c6a438d348946..a95906aa8352e 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTosaDialect
IR/TosaOps.cpp
IR/TosaCanonicalizations.cpp
+ IR/TargetEnv.cpp
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
new file mode 100644
index 0000000000000..5aad67173cc61
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -0,0 +1,42 @@
+//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+
+namespace mlir {
+namespace tosa {
+
+TargetEnvAttr lookupTargetEnv(Operation *op) {
+ while (op) {
+ op = SymbolTable::getNearestSymbolTable(op);
+ if (!op)
+ break;
+
+ if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
+ return attr;
+
+ op = op->getParentOp();
+ }
+
+ return {};
+}
+
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
+ return TargetEnvAttr::get(context, Level::eightK,
+ {Profile::pro_int, Profile::pro_fp}, {});
+}
+
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
+ if (auto attr = lookupTargetEnv(op))
+ return attr;
+
+ return getDefaultTargetEnv(op->getContext());
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 803993bb1008d..41b338d6e7189 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaAttachTarget.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
new file mode 100644
index 0000000000000..bcb880a808b36
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -0,0 +1,87 @@
+//===- TosaAttachTarget.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Attach target information to a TOSA module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSAATTACHTARGET
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+
+class TosaAttachTarget
+ : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> {
+ using Base::Base;
+
+public:
+ void runOnOperation() override {
+ llvm::SmallVector<Profile, 2> selectedProfiles;
+ if (!profiles.empty()) {
+ for (const std::string &prof : profiles) {
+ std::optional<Profile> profSymbol = symbolizeProfile(prof);
+ if (!profSymbol) {
+ llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allProfiles,
+ "profile", prof);
+ return signalPassFailure();
+ }
+ selectedProfiles.push_back(profSymbol.value());
+ }
+ }
+
+ llvm::SmallVector<Extension, 10> selectedExtensions;
+ if (!extensions.empty()) {
+ for (const std::string &ext : extensions) {
+ std::optional<Extension> extSymbol = symbolizeExtension(ext);
+ if (!extSymbol) {
+ llvm::SmallVector<Extension> allExtensions =
+ ExtensionAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allExtensions,
+ "extension", ext);
+ return signalPassFailure();
+ }
+ selectedExtensions.push_back(extSymbol.value());
+ }
+ }
+
+ ModuleOp mod = getOperation();
+ MLIRContext *ctx = &getContext();
+ const auto targetEnvAttr =
+ TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
+ }
+
+private:
+ template <typename T>
+ std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues,
+ std::string enumName,
+ std::string unknownArgument) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument
+ << "', supported " << enumName << "s are: ";
+ llvm::interleaveComma(enumValues, os);
+ os << "\n";
+ return message;
+ }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534f9e744..f2dcef15d8517 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
@@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
return success();
}
-struct TosaLevel {
- int32_t MAX_RANK = 0;
- int32_t MAX_KERNEL = 0;
- int32_t MAX_STRIDE = 0;
- int32_t MAX_SCALE = 0;
- int32_t MAX_LOG2_SIZE = 0;
- int32_t MAX_NESTING = 0;
- int32_t MAX_TENSOR_LIST_SIZE = 0;
-
- bool operator==(const TosaLevel &rhs) {
- return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
- MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
- MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_...
[truncated]
|
@llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesThis commit introduces a new module-level attribute module attributes {tosa.target_env =
#tosa.target_env<level = none, profiles = [pro_int], extensions = [int16, int4]>} {
<my-tosa-program>
} Previously the validation pass accepted target information as a series of command line pass options. This commit changes the behaviour to query the attached target environment from the module attribute. This refactoring allows other passes to query the same target information. A new target environment can be attached using the mlir-opt --tosa-attach-target="profiles=pro_int extensions=int4,int16 level=none" test.mlir Patch is 37.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153771.diff 22 Files Affected:
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index f4823858e3893..ab9b9f24ef3dd 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -39,8 +39,7 @@ void addTosaToLinalgPasses(
TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
- tosa::TosaValidationOptions{
- {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
+ tosa::TosaValidationOptions{false, false});
/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 9ee5079559d2b..10491f65d37af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -20,24 +20,67 @@
namespace mlir {
namespace tosa {
+struct TosaLevel {
+ int32_t MAX_RANK = 0;
+ int32_t MAX_KERNEL = 0;
+ int32_t MAX_STRIDE = 0;
+ int32_t MAX_SCALE = 0;
+ int32_t MAX_LOG2_SIZE = 0;
+ int32_t MAX_NESTING = 0;
+ int32_t MAX_TENSOR_LIST_SIZE = 0;
+
+ bool operator==(const TosaLevel &rhs) {
+ return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
+ MAX_NESTING == rhs.MAX_NESTING &&
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+ }
+};
+
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
+ 63, 256, 256};
+
+TargetEnvAttr lookupTargetEnv(Operation *op);
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
+
+/// Queries the target environment recursively from enclosing symbol table ops
+/// containing the given `op` or returns the default target environment as
+/// returned by getDefaultTargetEnv() if not provided.
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+
/// This class represents the capability enabled in the target implementation
-/// such as profile, extension, and level.
+/// such as profile, extension, and level. It's a wrapper class around
+/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
- const SmallVectorImpl<Extension> &extensions) {
+ explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ const ArrayRef<Extension> &extensions)
+ : level(level) {
enabledProfiles.insert_range(profiles);
-
enabledExtensions.insert_range(extensions);
}
+ explicit TargetEnv(TargetEnvAttr targetAttr)
+ : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
+ targetAttr.getExtensions()) {}
+
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
// TODO implement the following utilities.
// Version getSpecVersion() const;
- // TosaLevel getLevel() const;
+
+ TosaLevel getLevel() const {
+ if (level == Level::eightK)
+ return TOSA_LEVEL_EIGHTK;
+ else if (level == Level::none)
+ return TOSA_LEVEL_NONE;
+ else
+ llvm_unreachable("Unknown TOSA level");
+ };
// Returns true if the given profile is allowed.
bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
@@ -62,8 +105,9 @@ class TargetEnv {
}
private:
+ Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
- llvm::SmallSet<Extension, 8> enabledExtensions;
+ llvm::SmallSet<Extension, 13> enabledExtensions;
};
} // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..9e9ff21db1fa1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
+def Tosa_ProfileAttr
+ : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
+ [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Profile, 2> getAllValues() {
+ return {Profile::pro_int, Profile::pro_fp};
+ }
+ }];
+}
+
+def Tosa_ProfileArrayAttr
+ : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
@@ -264,17 +277,27 @@ def Tosa_ExtensionAttr
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
Tosa_EXT_DYNAMIC
- ]>;
+ ]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Extension, 11> getAllValues() {
+ return {
+ Extension::int16, Extension::int4, Extension::bf16,
+ Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
+ Extension::variable, Extension::controlflow, Extension::doubleround,
+ Extension::inexactround, Extension::dynamic
+ };
+ }
+ }];
+}
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_ProfileAttr
- : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
- [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>;
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-def Tosa_ProfileArrayAttr
- : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
// The base class for defining op availability dimensions.
class Availability {
@@ -381,6 +404,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
+//===----------------------------------------------------------------------===//
+// TOSA target environment.
+//===----------------------------------------------------------------------===//
+def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
+ let summary = "Target environment information.";
+ let parameters = ( ins
+ "Level": $level,
+ ArrayRefParameter<"Profile">: $profiles,
+ ArrayRefParameter<"Extension">: $extensions
+ );
+
+ let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ "`extensions` `=` `[` $extensions `]` `>`";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
index d4e2661838314..b1363b5a179df 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,7 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
-mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
-mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTosaPassIncGen)
add_dependencies(mlir-headers MLIRTosaPassIncGen)
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 306e4b1f218e7..ba99d2f1d2727 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index b96682843538c..6ae19d81e0820 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass
}];
}
-def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
- [
- I32EnumAttrCase<"None", 0, "none">,
- I32EnumAttrCase<"EightK", 1, "8k">,
- ]>{
- let cppNamespace = "mlir::tosa";
-}
-
def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
@@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
}];
let options = [
- ListOption<"profile", "profile", "std::string",
- "Validate if operations match for the given profile set">,
- ListOption<"extension", "extension", "std::string",
- "Validate if operations match for the given extension set">,
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
@@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
/*default=*/"false",
"Disable checks for operations that are determined to be invalid due to their "
"operand/result datatypes not aligning with the 'Supported Data Types' "
- "sections of the specifciation">,
- Option<"level", "level", "mlir::tosa::TosaLevelEnum",
- /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
- "Validate if operator parameters are within specfication for the given level",
- [{::llvm::cl::values(
- clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
- "Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
- clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
- "Allows the full range of arguments specified by the operations according "
- "to the operation data types.")
- )}]>
+ "sections of the specifciation">
];
}
@@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle
}];
}
+def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
+ let summary = "Attach tosa.target_env information to the given module.";
+
+ let description = [{
+ This pass allows the user to specify a TOSA target environment consisting of
+ the following components: level, profiles and extensions.
+
+ The target environment is attached to the module as an attribute, allowing other
+ transformations to query the selected target and adapt their behaviour based on
+ this information.
+ }];
+
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tosa::TosaDialect",
+ ];
+
+ let options = [
+ Option<"level", "level", "mlir::tosa::Level",
+ /*default=*/"mlir::tosa::Level::eightK",
+ "The TOSA level that operators should conform to. A TOSA level defines "
+ "operator argument ranges that an implementation shall support.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::Level::eightK, "8k",
+ "Ranges are expected to be sufficient for applications with frame "
+ "sizes up to 8K."),
+ clEnumValN(mlir::tosa::Level::none, "none",
+ "Allows the full range of arguments specified by the operations according "
+ "to the operation data types.")
+ )}]>,
+ ListOption<"profiles", "profiles", "std::string",
+ "The TOSA profile(s) that operators should conform to. TOSA profiles "
+ "enable efficient implementation on different classes of device. Each "
+ "profile is an independent set of operations and data type combinations.">,
+ ListOption<"extensions", "extensions", "std::string",
+ "The TOSA extension(s) that operators should conform to. TOSA profile "
+ "extensions define optional operation and data type combinations.">
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index c6a3ba9f1439f..e7602b4508cf1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
TosaToLinalgOptions tosaToLinalgOptions;
TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
TosaValidationOptions validationOptions;
- validationOptions.profile = {"none"};
- validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
validationOptions.allowInvalidOpDatatypeCombinations = false;
- validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
validationOptions);
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index c6a438d348946..a95906aa8352e 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTosaDialect
IR/TosaOps.cpp
IR/TosaCanonicalizations.cpp
+ IR/TargetEnv.cpp
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
new file mode 100644
index 0000000000000..5aad67173cc61
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -0,0 +1,42 @@
+//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+
+namespace mlir {
+namespace tosa {
+
+TargetEnvAttr lookupTargetEnv(Operation *op) {
+ while (op) {
+ op = SymbolTable::getNearestSymbolTable(op);
+ if (!op)
+ break;
+
+ if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
+ return attr;
+
+ op = op->getParentOp();
+ }
+
+ return {};
+}
+
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
+ return TargetEnvAttr::get(context, Level::eightK,
+ {Profile::pro_int, Profile::pro_fp}, {});
+}
+
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
+ if (auto attr = lookupTargetEnv(op))
+ return attr;
+
+ return getDefaultTargetEnv(op->getContext());
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 803993bb1008d..41b338d6e7189 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaAttachTarget.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
new file mode 100644
index 0000000000000..bcb880a808b36
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -0,0 +1,87 @@
+//===- TosaAttachTarget.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Attach target information to a TOSA module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSAATTACHTARGET
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+
+class TosaAttachTarget
+ : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> {
+ using Base::Base;
+
+public:
+ void runOnOperation() override {
+ llvm::SmallVector<Profile, 2> selectedProfiles;
+ if (!profiles.empty()) {
+ for (const std::string &prof : profiles) {
+ std::optional<Profile> profSymbol = symbolizeProfile(prof);
+ if (!profSymbol) {
+ llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allProfiles,
+ "profile", prof);
+ return signalPassFailure();
+ }
+ selectedProfiles.push_back(profSymbol.value());
+ }
+ }
+
+ llvm::SmallVector<Extension, 10> selectedExtensions;
+ if (!extensions.empty()) {
+ for (const std::string &ext : extensions) {
+ std::optional<Extension> extSymbol = symbolizeExtension(ext);
+ if (!extSymbol) {
+ llvm::SmallVector<Extension> allExtensions =
+ ExtensionAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allExtensions,
+ "extension", ext);
+ return signalPassFailure();
+ }
+ selectedExtensions.push_back(extSymbol.value());
+ }
+ }
+
+ ModuleOp mod = getOperation();
+ MLIRContext *ctx = &getContext();
+ const auto targetEnvAttr =
+ TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
+ }
+
+private:
+ template <typename T>
+ std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues,
+ std::string enumName,
+ std::string unknownArgument) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument
+ << "', supported " << enumName << "s are: ";
+ llvm::interleaveComma(enumValues, os);
+ os << "\n";
+ return message;
+ }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534f9e744..f2dcef15d8517 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
@@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
return success();
}
-struct TosaLevel {
- int32_t MAX_RANK = 0;
- int32_t MAX_KERNEL = 0;
- int32_t MAX_STRIDE = 0;
- int32_t MAX_SCALE = 0;
- int32_t MAX_LOG2_SIZE = 0;
- int32_t MAX_NESTING = 0;
- int32_t MAX_TENSOR_LIST_SIZE = 0;
-
- bool operator==(const TosaLevel &rhs) {
- return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
- MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
- MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_...
[truncated]
|
This commit introduces a new module-level attribute
tosa.target_env
. It encapsulates target information for use during compilation such as: level, profiles and extensions. For example:Previously the validation pass accepted target information as a series of command line pass options. This commit changes the behaviour to query the attached target environment from the module attribute. This refactoring allows other passes to query the same target information.
A new target environment can be attached using the
--tosa-attach-target
pass, which takes the same command line options as the previous validation pass arguments. For example:mlir-opt --tosa-attach-target="profiles=pro_int extensions=int4,int16 level=none" test.mlir