diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index f4823858e3893..ab9b9f24ef3dd 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -39,8 +39,7 @@ void addTosaToLinalgPasses( TosaToLinalgNamedOptions(), // Note: Default to 'none' level unless otherwise specified. std::optional validationOptions = - tosa::TosaValidationOptions{ - {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None}); + tosa::TosaValidationOptions{false, false}); /// Populates TOSA to linalg pipelines /// Currently, this includes only the "tosa-to-linalg-pipeline". diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 9ee5079559d2b..10491f65d37af 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -20,24 +20,67 @@ namespace mlir { namespace tosa { +struct TosaLevel { + int32_t MAX_RANK = 0; + int32_t MAX_KERNEL = 0; + int32_t MAX_STRIDE = 0; + int32_t MAX_SCALE = 0; + int32_t MAX_LOG2_SIZE = 0; + int32_t MAX_NESTING = 0; + int32_t MAX_TENSOR_LIST_SIZE = 0; + + bool operator==(const TosaLevel &rhs) { + return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && + MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && + MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && + MAX_NESTING == rhs.MAX_NESTING && + MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; + } +}; + +static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; +static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, + 63, 256, 256}; + +TargetEnvAttr lookupTargetEnv(Operation *op); +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); + +/// Queries the target environment recursively from enclosing symbol table ops +/// containing the given `op` or returns the default target environment as +/// returned by getDefaultTargetEnv() if not provided. +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); + /// This class represents the capability enabled in the target implementation -/// such as profile, extension, and level. +/// such as profile, extension, and level. It's a wrapper class around +/// tosa::TargetEnvAttr. class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(const SmallVectorImpl &profiles, - const SmallVectorImpl &extensions) { + explicit TargetEnv(Level level, const ArrayRef &profiles, + const ArrayRef &extensions) + : level(level) { enabledProfiles.insert_range(profiles); - enabledExtensions.insert_range(extensions); } + explicit TargetEnv(TargetEnvAttr targetAttr) + : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(), + targetAttr.getExtensions()) {} + void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } // TODO implement the following utilities. // Version getSpecVersion() const; - // TosaLevel getLevel() const; + + TosaLevel getLevel() const { + if (level == Level::eightK) + return TOSA_LEVEL_EIGHTK; + else if (level == Level::none) + return TOSA_LEVEL_NONE; + else + llvm_unreachable("Unknown TOSA level"); + }; // Returns true if the given profile is allowed. bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; } @@ -62,8 +105,9 @@ class TargetEnv { } private: + Level level; llvm::SmallSet enabledProfiles; - llvm::SmallSet enabledExtensions; + llvm::SmallSet enabledExtensions; }; } // namespace tosa diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index e048f8af7cc33..9e9ff21db1fa1 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>; def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>; def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>; +def Tosa_ProfileAttr + : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", + [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> { + let extraClassDeclaration = [{ + static llvm::SmallVector getAllValues() { + return {Profile::pro_int, Profile::pro_fp}; + } + }]; +} + +def Tosa_ProfileArrayAttr + : TypedArrayAttrBase; + def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>; def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>; def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>; @@ -264,17 +277,27 @@ def Tosa_ExtensionAttr Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC - ]>; + ]> { + let extraClassDeclaration = [{ + static llvm::SmallVector getAllValues() { + return { + Extension::int16, Extension::int4, Extension::bf16, + Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft, + Extension::variable, Extension::controlflow, Extension::doubleround, + Extension::inexactround, Extension::dynamic + }; + } + }]; +} def Tosa_ExtensionArrayAttr : TypedArrayAttrBase; -def Tosa_ProfileAttr - : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", - [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>; +def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; -def Tosa_ProfileArrayAttr - : TypedArrayAttrBase; +def Tosa_LevelAttr + : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; // The base class for defining op availability dimensions. class Availability { @@ -381,6 +404,21 @@ class Extension extensions> : Availability { let instance = "ref"; } +//===----------------------------------------------------------------------===// +// TOSA target environment. +//===----------------------------------------------------------------------===// +def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { + let summary = "Target environment information."; + let parameters = ( ins + "Level": $level, + ArrayRefParameter<"Profile">: $profiles, + ArrayRefParameter<"Extension">: $extensions + ); + + let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " + "`extensions` `=` `[` $extensions `]` `>`"; +} + //===----------------------------------------------------------------------===// // TOSA Interfaces. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt index d4e2661838314..b1363b5a179df 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,7 +1,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) -mlir_tablegen(PassesEnums.h.inc -gen-enum-decls) -mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRTosaPassIncGen) add_dependencies(mlir-headers MLIRTosaPassIncGen) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 306e4b1f218e7..ba99d2f1d2727 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -15,7 +15,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index b96682843538c..6ae19d81e0820 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass }]; } -def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level", - [ - I32EnumAttrCase<"None", 0, "none">, - I32EnumAttrCase<"EightK", 1, "8k">, - ]>{ - let cppNamespace = "mlir::tosa"; -} - def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { let summary = "Validates TOSA dialect"; let description = [{ @@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { }]; let options = [ - ListOption<"profile", "profile", "std::string", - "Validate if operations match for the given profile set">, - ListOption<"extension", "extension", "std::string", - "Validate if operations match for the given extension set">, Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool", /*default=*/"false", "Verify if the properties of certain operations align the spec requirement">, @@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { /*default=*/"false", "Disable checks for operations that are determined to be invalid due to their " "operand/result datatypes not aligning with the 'Supported Data Types' " - "sections of the specifciation">, - Option<"level", "level", "mlir::tosa::TosaLevelEnum", - /*default=*/"mlir::tosa::TosaLevelEnum::EightK", - "Validate if operator parameters are within specfication for the given level", - [{::llvm::cl::values( - clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k", - "Ranges are expected to be sufficient for applications with frame sizes up to 8K."), - clEnumValN(mlir::tosa::TosaLevelEnum::None, "none", - "Allows the full range of arguments specified by the operations according " - "to the operation data types.") - )}]> + "sections of the specifciation"> ]; } @@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle }]; } +def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { + let summary = "Attach tosa.target_env information to the given module."; + + let description = [{ + This pass allows the user to specify a TOSA target environment consisting of + the following components: level, profiles and extensions. + + The target environment is attached to the module as an attribute, allowing other + transformations to query the selected target and adapt their behaviour based on + this information. + }]; + + let dependentDialects = [ + "func::FuncDialect", + "tosa::TosaDialect", + ]; + + let options = [ + Option<"level", "level", "mlir::tosa::Level", + /*default=*/"mlir::tosa::Level::eightK", + "The TOSA level that operators should conform to. A TOSA level defines " + "operator argument ranges that an implementation shall support.", + [{::llvm::cl::values( + clEnumValN(mlir::tosa::Level::eightK, "8k", + "Ranges are expected to be sufficient for applications with frame " + "sizes up to 8K."), + clEnumValN(mlir::tosa::Level::none, "none", + "Allows the full range of arguments specified by the operations according " + "to the operation data types.") + )}]>, + ListOption<"profiles", "profiles", "std::string", + "The TOSA profile(s) that operators should conform to. TOSA profiles " + "enable efficient implementation on different classes of device. Each " + "profile is an independent set of operations and data type combinations.">, + ListOption<"extensions", "extensions", "std::string", + "The TOSA extension(s) that operators should conform to. TOSA profile " + "extensions define optional operation and data type combinations."> + ]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index c6a3ba9f1439f..e7602b4508cf1 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() { TosaToLinalgOptions tosaToLinalgOptions; TosaToLinalgNamedOptions tosaToLinalgNamedOptions; TosaValidationOptions validationOptions; - validationOptions.profile = {"none"}; - validationOptions.extension = {"none"}; validationOptions.strictOpSpecAlignment = false; validationOptions.allowInvalidOpDatatypeCombinations = false; - validationOptions.level = tosa::TosaLevelEnum::EightK; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions, validationOptions); diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index c6a438d348946..a95906aa8352e 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTosaDialect IR/TosaOps.cpp IR/TosaCanonicalizations.cpp + IR/TargetEnv.cpp Utils/ConversionUtils.cpp Utils/QuantUtils.cpp diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp new file mode 100644 index 0000000000000..5aad67173cc61 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -0,0 +1,42 @@ +//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" + +namespace mlir { +namespace tosa { + +TargetEnvAttr lookupTargetEnv(Operation *op) { + while (op) { + op = SymbolTable::getNearestSymbolTable(op); + if (!op) + break; + + if (auto attr = op->getAttrOfType(TargetEnvAttr::name)) + return attr; + + op = op->getParentOp(); + } + + return {}; +} + +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { + return TargetEnvAttr::get(context, Level::eightK, + {Profile::pro_int, Profile::pro_fp}, {}); +} + +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { + if (auto attr = lookupTargetEnv(op)) + return attr; + + return getDefaultTargetEnv(op->getContext()); +} + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 803993bb1008d..41b338d6e7189 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms + TosaAttachTarget.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp new file mode 100644 index 0000000000000..bcb880a808b36 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -0,0 +1,87 @@ +//===- TosaAttachTarget.cpp +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Attach target information to a TOSA module. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { + +#define GEN_PASS_DEF_TOSAATTACHTARGET +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +namespace { + +class TosaAttachTarget + : public tosa::impl::TosaAttachTargetBase { + using Base::Base; + +public: + void runOnOperation() override { + llvm::SmallVector selectedProfiles; + if (!profiles.empty()) { + for (const std::string &prof : profiles) { + std::optional profSymbol = symbolizeProfile(prof); + if (!profSymbol) { + llvm::SmallVector allProfiles = ProfileAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allProfiles, + "profile", prof); + return signalPassFailure(); + } + selectedProfiles.push_back(profSymbol.value()); + } + } + + llvm::SmallVector selectedExtensions; + if (!extensions.empty()) { + for (const std::string &ext : extensions) { + std::optional extSymbol = symbolizeExtension(ext); + if (!extSymbol) { + llvm::SmallVector allExtensions = + ExtensionAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allExtensions, + "extension", ext); + return signalPassFailure(); + } + selectedExtensions.push_back(extSymbol.value()); + } + } + + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + const auto targetEnvAttr = + TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + mod->setAttr(TargetEnvAttr::name, targetEnvAttr); + } + +private: + template + std::string buildUnkownParameterErrorMessage(llvm::SmallVector &enumValues, + std::string enumName, + std::string unknownArgument) { + std::string message; + llvm::raw_string_ostream os(message); + os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument + << "', supported " << enumName << "s are: "; + llvm::interleaveComma(enumValues, os); + os << "\n"; + return message; + } +}; + +} // namespace + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index c7b9534f9e744..f2dcef15d8517 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Tosa/IR/TargetEnv.h" #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" #include @@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op, return success(); } -struct TosaLevel { - int32_t MAX_RANK = 0; - int32_t MAX_KERNEL = 0; - int32_t MAX_STRIDE = 0; - int32_t MAX_SCALE = 0; - int32_t MAX_LOG2_SIZE = 0; - int32_t MAX_NESTING = 0; - int32_t MAX_TENSOR_LIST_SIZE = 0; - - bool operator==(const TosaLevel &rhs) { - return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && - MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && - MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && - MAX_NESTING == rhs.MAX_NESTING && - MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; - } -}; - -static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; -static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, - 63, 256, 256}; - //===----------------------------------------------------------------------===// // TOSA Validation Pass. //===----------------------------------------------------------------------===// @@ -162,12 +139,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { explicit TosaValidation(const TosaValidationOptions &options) : TosaValidation() { - this->profile = options.profile; - this->extension = options.extension; this->strictOpSpecAlignment = options.strictOpSpecAlignment; this->allowInvalidOpDatatypeCombinations = options.allowInvalidOpDatatypeCombinations; - this->level = options.level; } void runOnOperation() final; @@ -206,7 +180,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_KERNEL) { + if (v > targetEnv.getLevel().MAX_KERNEL) { op->emitOpError() << "failed level check: " << checkDesc; return false; } @@ -214,7 +188,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_STRIDE) { + if (v > targetEnv.getLevel().MAX_STRIDE) { op->emitOpError() << "failed level check: " << checkDesc; return false; } @@ -222,7 +196,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_SCALE) { + if (v > targetEnv.getLevel().MAX_SCALE) { op->emitOpError() << "failed level check: " << checkDesc; return false; } @@ -230,7 +204,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) { + if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE) { op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc; return false; @@ -291,6 +265,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { template bool levelCheckRanks(T tosaOp) { auto op = tosaOp.getOperation(); + const TosaLevel tosaLevel = targetEnv.getLevel(); for (auto v : op->getOperands()) { if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)) return false; @@ -472,7 +447,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { int32_t maxNestedDepth = 0; getMaxNestedDepth(op, maxNestedDepth); - if (maxNestedDepth >= tosaLevel.MAX_NESTING) { + if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) { op->emitOpError() << "failed level check: " << maxNestedDepth << " >= MAX_NESTING"; return false; @@ -525,43 +500,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { return true; } - // configure profile and level values from pass options profileName and - // levelName - void configLevelAndProfile() { - tosaLevel = TOSA_LEVEL_NONE; - if (level == TosaLevelEnum::EightK) { - tosaLevel = TOSA_LEVEL_EIGHTK; - } - - if (!profile.empty()) { - for (std::string &prof : profile) { - auto profSymbol = symbolizeProfile(prof); - if (profSymbol) { - targetEnv.addProfile(profSymbol.value()); - } else { - llvm::errs() << "unknown TOSA profile name passed in: " << prof - << ", supported profiles are `pro_int` and `pro_fp`\n"; - return signalPassFailure(); - } - } - } - - if (!extension.empty()) { - for (std::string &ext : extension) { - auto extSymbol = symbolizeExtension(ext); - if (extSymbol) { - targetEnv.addExtension(extSymbol.value()); - } else { - llvm::errs() << "unknown TOSA extension name passed in: " << ext - << ", supported extension are int16, int4, bf16, " - << "fp8e4m3, fp8e5m2, fft, variable, controlflow, " - << "doubleround, inexactround and dynamic\n"; - return signalPassFailure(); - } - } - } - } - bool CheckVariable(Operation *op); bool CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type, const bool allowUnsigned = false); @@ -569,7 +507,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { SmallVector< std::function> constCheckers; - TosaLevel tosaLevel; DenseMap variablesMap; TosaProfileCompliance profileComp; tosa::TargetEnv targetEnv; @@ -578,11 +515,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { template <> bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { auto op = tosaOp.getOperation(); - if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)) + if (!levelCheckRank(op, tosaOp.getInput(), "operand", + targetEnv.getLevel().MAX_RANK)) return false; // rank(output) = rank(input) - 1 - if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1)) + if (!levelCheckRank(op, tosaOp.getOutput(), "result", + targetEnv.getLevel().MAX_RANK - 1)) return false; return true; @@ -593,7 +532,8 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { auto op = tosaOp.getOperation(); // Only the condition input has rank limitation. - if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK)) + if (!levelCheckRank(op, tosaOp.getCondition(), "operand", + targetEnv.getLevel().MAX_RANK)) return false; return true; @@ -603,7 +543,8 @@ template <> bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { auto op = tosaOp.getOperation(); auto variableType = getVariableType(tosaOp); - if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK)) + if (!levelCheckRank(op, variableType, "variable type", + targetEnv.getLevel().MAX_RANK)) return false; return true; @@ -763,7 +704,8 @@ bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck, // defined in 1.7. Levels. // For each tensor, the number of tensor elements multiplied by the // element size in bytes must be representable as a tensor_size_t. - const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; + const int64_t max_size = + (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1; if (size > max_size) { op->emitOpError() << "failed level check: " << operandOrResult @@ -775,7 +717,7 @@ bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck, } LogicalResult TosaValidation::applyLevelCheck(Operation *op) { - if (tosaLevel == TOSA_LEVEL_NONE) { + if (targetEnv.getLevel() == TOSA_LEVEL_NONE) { // no need to do level checks return success(); } @@ -1333,12 +1275,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { - configLevelAndProfile(); - TosaDialect *tosaDialect = getContext().getLoadedDialect(); if (!tosaDialect) return; + targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); + getOperation().walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir index e23ce43031a24..f73b41cf434fc 100644 --- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir +++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir @@ -2,7 +2,7 @@ // Check operations when the dynamic extension is enabled. //-------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations" // ----- diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index fad1bec0e3ecc..c57287fa8e368 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment" // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 3bccb32c5b9f4..68d0596b513e7 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_const() -> tensor<1xf32> { // expected-error@+1{{'tosa.const' op expected same attr/result element types}} diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 3154f541e0519..40622e6d233c7 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -2,7 +2,7 @@ // Enable all supported profiles to focus the verification of expected extension requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 0184d2b05f0ee..c824ca0e33753 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index 225b962589df9..09e96eca776e2 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index fad4859351251..0bdba0a1d3fbd 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_f16() -> tensor<3x11x11x3xf16> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index 9438179622aad..a4ba71a72d798 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_i1() -> tensor<3x11x11x3xi1> { diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir new file mode 100644 index 0000000000000..d6c886c44b013 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K +// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT + +// ----- + +// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-LABEL: test_simple +func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { + %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + return %1 : tensor<1x1x1x1xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir index cab14201dc0ce..79032f39af99d 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s // -----