Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ void addTosaToLinalgPasses(
TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
tosa::TosaValidationOptions{
{"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
tosa::TosaValidationOptions{false, false});

/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
Expand Down
56 changes: 50 additions & 6 deletions mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,67 @@
namespace mlir {
namespace tosa {

struct TosaLevel {
int32_t MAX_RANK = 0;
int32_t MAX_KERNEL = 0;
int32_t MAX_STRIDE = 0;
int32_t MAX_SCALE = 0;
int32_t MAX_LOG2_SIZE = 0;
int32_t MAX_NESTING = 0;
int32_t MAX_TENSOR_LIST_SIZE = 0;

bool operator==(const TosaLevel &rhs) {
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
MAX_NESTING == rhs.MAX_NESTING &&
MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
}
};

static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
63, 256, 256};

TargetEnvAttr lookupTargetEnv(Operation *op);
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);

/// Queries the target environment recursively from enclosing symbol table ops
/// containing the given `op` or returns the default target environment as
/// returned by getDefaultTargetEnv() if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);

/// This class represents the capability enabled in the target implementation
/// such as profile, extension, and level.
/// such as profile, extension, and level. It's a wrapper class around
/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
const SmallVectorImpl<Extension> &extensions) {
explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
: level(level) {
enabledProfiles.insert_range(profiles);

enabledExtensions.insert_range(extensions);
}

explicit TargetEnv(TargetEnvAttr targetAttr)
: TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
targetAttr.getExtensions()) {}

void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }

// TODO implement the following utilities.
// Version getSpecVersion() const;
// TosaLevel getLevel() const;

TosaLevel getLevel() const {
if (level == Level::eightK)
return TOSA_LEVEL_EIGHTK;
else if (level == Level::none)
return TOSA_LEVEL_NONE;
else
llvm_unreachable("Unknown TOSA level");
};

// Returns true if the given profile is allowed.
bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
Expand All @@ -62,8 +105,9 @@ class TargetEnv {
}

private:
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 8> enabledExtensions;
llvm::SmallSet<Extension, 13> enabledExtensions;
};

} // namespace tosa
Expand Down
50 changes: 44 additions & 6 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;

def Tosa_ProfileAttr
: Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
[Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> {
let extraClassDeclaration = [{
static llvm::SmallVector<Profile, 2> getAllValues() {
return {Profile::pro_int, Profile::pro_fp};
}
}];
}

def Tosa_ProfileArrayAttr
: TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;

def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
Expand All @@ -264,17 +277,27 @@ def Tosa_ExtensionAttr
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
Tosa_EXT_DYNAMIC
]>;
]> {
let extraClassDeclaration = [{
static llvm::SmallVector<Extension, 11> getAllValues() {
return {
Extension::int16, Extension::int4, Extension::bf16,
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
Extension::variable, Extension::controlflow, Extension::doubleround,
Extension::inexactround, Extension::dynamic
};
}
}];
}

def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;

def Tosa_ProfileAttr
: Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
[Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>;
def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;

def Tosa_ProfileArrayAttr
: TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
def Tosa_LevelAttr
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;

// The base class for defining op availability dimensions.
class Availability {
Expand Down Expand Up @@ -381,6 +404,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}

//===----------------------------------------------------------------------===//
// TOSA target environment.
//===----------------------------------------------------------------------===//
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
let summary = "Target environment information.";
let parameters = ( ins
"Level": $level,
ArrayRefParameter<"Profile">: $profiles,
ArrayRefParameter<"Extension">: $extensions
);

let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
"`extensions` `=` `[` $extensions `]` `>`";
}

//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 0 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTosaPassIncGen)
add_dependencies(mlir-headers MLIRTosaPassIncGen)

Expand Down
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"

namespace mlir {
Expand Down
64 changes: 41 additions & 23 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass
}];
}

def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
[
I32EnumAttrCase<"None", 0, "none">,
I32EnumAttrCase<"EightK", 1, "8k">,
]>{
let cppNamespace = "mlir::tosa";
}

def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
Expand All @@ -81,28 +73,14 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
}];

let options = [
ListOption<"profile", "profile", "std::string",
"Validate if operations match for the given profile set">,
ListOption<"extension", "extension", "std::string",
"Validate if operations match for the given extension set">,
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
/*default=*/"false",
"Disable checks for operations that are determined to be invalid due to their "
"operand/result datatypes not aligning with the 'Supported Data Types' "
"sections of the specifciation">,
Option<"level", "level", "mlir::tosa::TosaLevelEnum",
/*default=*/"mlir::tosa::TosaLevelEnum::EightK",
"Validate if operator parameters are within specfication for the given level",
[{::llvm::cl::values(
clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
"Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
"Allows the full range of arguments specified by the operations according "
"to the operation data types.")
)}]>
"sections of the specifciation">
];
}

Expand Down Expand Up @@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle
}];
}

def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
let summary = "Attach tosa.target_env information to the given module.";

let description = [{
This pass allows the user to specify a TOSA target environment consisting of
the following components: level, profiles and extensions.

The target environment is attached to the module as an attribute, allowing other
transformations to query the selected target and adapt their behaviour based on
this information.
}];

let dependentDialects = [
"func::FuncDialect",
"tosa::TosaDialect",
];

let options = [
Option<"level", "level", "mlir::tosa::Level",
/*default=*/"mlir::tosa::Level::eightK",
"The TOSA level that operators should conform to. A TOSA level defines "
"operator argument ranges that an implementation shall support.",
[{::llvm::cl::values(
clEnumValN(mlir::tosa::Level::eightK, "8k",
"Ranges are expected to be sufficient for applications with frame "
"sizes up to 8K."),
clEnumValN(mlir::tosa::Level::none, "none",
"Allows the full range of arguments specified by the operations according "
"to the operation data types.")
)}]>,
ListOption<"profiles", "profiles", "std::string",
"The TOSA profile(s) that operators should conform to. TOSA profiles "
"enable efficient implementation on different classes of device. Each "
"profile is an independent set of operations and data type combinations.">,
ListOption<"extensions", "extensions", "std::string",
"The TOSA extension(s) that operators should conform to. TOSA profile "
"extensions define optional operation and data type combinations.">
];
}

#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
3 changes: 0 additions & 3 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
TosaToLinalgOptions tosaToLinalgOptions;
TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
TosaValidationOptions validationOptions;
validationOptions.profile = {"none"};
validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
validationOptions.allowInvalidOpDatatypeCombinations = false;
validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
validationOptions);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTosaDialect
IR/TosaOps.cpp
IR/TosaCanonicalizations.cpp
IR/TargetEnv.cpp
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp

Expand Down
42 changes: 42 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TargetEnv.h"

namespace mlir {
namespace tosa {

TargetEnvAttr lookupTargetEnv(Operation *op) {
while (op) {
op = SymbolTable::getNearestSymbolTable(op);
if (!op)
break;

if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
return attr;

op = op->getParentOp();
}

return {};
}

TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
return TargetEnvAttr::get(context, Level::eightK,
{Profile::pro_int, Profile::pro_fp}, {});
}

TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
if (auto attr = lookupTargetEnv(op))
return attr;

return getDefaultTargetEnv(op->getContext());
}

} // namespace tosa
} // namespace mlir
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaAttachTarget.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
Expand Down
Loading