Extending UniformQuantizedType with interface-based support for new storage types in Quant dialect #152966
Conversation
… built in types. Updated parser and printer in Quant dialect
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-ods
Author: None (Roman-Pevnyi)

Changes

Currently, UniformQuantizedType only supports built-in MLIR storage types such as Integer. LLM quantization research has introduced NF4 as a low-precision data type (see https://arxiv.org/pdf/2305.14314), and there is a growing need to make the system extensible and maintainable as more types are added. Ensuring that MLIR can natively support NF4 through a clean, extensible interface is essential for both current and future quantization workflows.

Current Approach and Its Limitations:
The present implementation relies on dynamic checks (e.g., type switches or if-else chains) to determine the storage type and retrieve type-specific information for legality checks. This approach works for a small, fixed set of types, but as the number of supported types grows, the code becomes harder to read, maintain, and extend.
Proposed Interface-Based Approach:
Benefits:
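To make the contrast concrete, here is a minimal sketch of the two styles of storage-type handling described above. It is not the PR's exact code; NF4Type is a hypothetical future storage type used only for illustration.

// Sketch only: hard-coded storage-type checks vs. interface-based dispatch.
// QuantizationInterface and its methods come from this PR; NF4Type is hypothetical.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/QuantizationInterface.h"

using namespace mlir;

// Today: every new storage type needs another branch in code like this.
static unsigned storageWidthHardcoded(Type storageType) {
  if (auto intTy = llvm::dyn_cast<IntegerType>(storageType))
    return intTy.getWidth();
  // if (auto nf4Ty = llvm::dyn_cast<NF4Type>(storageType)) return 4;  // ...and so on.
  return 0;
}

// With the interface: any type implementing QuantizationInterface is handled
// uniformly, with no per-type branches in the Quant dialect.
static unsigned storageWidthViaInterface(Type storageType) {
  if (auto quantTy = llvm::dyn_cast<QuantizationInterface>(storageType))
    return quantTy.getStorageWidth();
  return 0;
}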
Full diff: https://github.com/llvm/llvm-project/pull/152966.diff 10 Files Affected:
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index ff4269ed7acd2..c35308d57eadd 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -203,6 +203,14 @@ function(add_mlir_interface interface)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()
+# Declare a dialect in the include directory
+function(add_mlir_type_interface interface)
+ set(LLVM_TARGET_DEFINITIONS ${interface}.td)
+ mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
+ mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
+ add_public_tablegen_target(MLIR${interface}IncGen)
+ add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
+endfunction()
# Generate Documentation
function(add_mlir_doc doc_filename output_file output_directory command)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 86ec5c43970b1..204da9553e915 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//
+#include "mlir/IR/QuantizationInterface.h"
+
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..762f9262adbf2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/QuantizationInterface.td"
include "mlir/IR/CommonTypeConstraints.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -497,7 +498,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//
def Builtin_Integer : Builtin_Type<"Integer", "integer",
- [VectorElementTypeInterface]> {
+ [VectorElementTypeInterface, QuantizationInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -554,6 +555,32 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
+
+ /// QuantizationInterface method implementations
+ /// Return true if this is a signed integer type.
+ bool isStorageSigned() const { return !isUnsigned(); }
+ /// Get the bit width of this integer type.
+ unsigned getStorageWidth() const { return getWidth(); }
+
+ /// Get default minimum value for this integer type.
+ int64_t getDefaultMinimum() const {
+ if (isStorageSigned()) {
+ return llvm::minIntN(getStorageWidth());
+ }
+ return 0;
+ }
+ /// Get default maximum value for this integer type.
+ int64_t getDefaultMaximum() const {
+ if (isStorageSigned()) {
+ return llvm::maxIntN(getStorageWidth());
+ }
+ return llvm::maxUIntN(getStorageWidth());
+ }
+
+ /// Get the storage type as a string.
+ std::string getStorageType() const {
+ return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
+ }
}];
}
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 846547ff131e3..153502c6e981b 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,6 +1,8 @@
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
+add_mlir_type_interface(QuantizationInterface)
+
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantizationInterface.h
new file mode 100644
index 0000000000000..0d6709ff52065
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.h
@@ -0,0 +1,22 @@
+//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_QuantizationInterface_H
+#define MLIR_IR_QuantizationInterface_H
+
+#include "mlir/IR/Types.h"
+
+// Forward declarations for the types we need in the implementation
+namespace mlir {
+class IntegerType;
+} // namespace mlir
+
+#include "mlir/IR/QuantizationInterface.h.inc"
+
+#endif // MLIR_IR_QuantizationInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantizationInterface.td
new file mode 100644
index 0000000000000..1008ac8e1dcf1
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.td
@@ -0,0 +1,44 @@
+#ifndef MLIR_IR_QUANTIZATIONINTERFACE
+#define MLIR_IR_QUANTIZATIONINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+ let description = [{
+ Interface for types that can be used as storage types in Quant dialect.
+ This interface provides methods to determine storage characteristics for quantization purposes.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Check if the storage type is signed.
+ Returns true if the type represents signed values, false for unsigned.
+ }],
+ "bool", "isStorageSigned", (ins)>,
+
+ InterfaceMethod<[{
+ Get the bit width of this integer type.
+ Returns the number of bits used to store values of this type.
+ }],
+ "unsigned", "getStorageWidth", (ins)>,
+
+ InterfaceMethod<[{
+ Get default minimum value for this integer type.
+ }],
+ "int64_t", "getDefaultMinimum", (ins)>,
+
+ InterfaceMethod<[{
+ Get default maximum value for this integer type.
+ }],
+ "int64_t", "getDefaultMaximum", (ins)>,
+
+ InterfaceMethod<[{
+ Get the storage type as a string.
+ }],
+ "std::string", "getStorageType", (ins)>
+ ];
+
+}
+
+#endif // MLIR_IR_QUANTIZATIONINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index b2227792f32ca..e7f9b1dc8a7e1 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/IR/QuantizationInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
if (!intStorageType)
return emitError() << "storage type must be integral";
- unsigned integralWidth = intStorageType.getWidth();
-
- // Verify storage width.
- if (integralWidth == 0 || integralWidth > MaxStorageBits)
- return emitError() << "illegal storage type size: " << integralWidth;
-
- // Verify storageTypeMin and storageTypeMax.
- bool isSigned =
- (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
- int64_t defaultIntegerMin =
- getDefaultMinimumForInteger(isSigned, integralWidth);
- int64_t defaultIntegerMax =
- getDefaultMaximumForInteger(isSigned, integralWidth);
- if (storageTypeMax - storageTypeMin <= 0 ||
- storageTypeMin < defaultIntegerMin ||
- storageTypeMax > defaultIntegerMax) {
- return emitError() << "illegal storage min and storage max: ("
- << storageTypeMin << ":" << storageTypeMax << ")";
+
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
+ unsigned integralWidth = quantizationInterface.getStorageWidth();
+
+ // Verify storage width.
+ if (integralWidth == 0 || integralWidth > MaxStorageBits)
+ return emitError() << "illegal storage type size: " << integralWidth;
+
+ int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+ int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+ if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
+ storageTypeMax > defaultMax) {
+ return emitError() << "illegal storage min and storage max: ("
+ << storageTypeMin << ":" << storageTypeMax << ")";
+ }
+
+ return success();
}
- return success();
+
+ return emitError() << "storage type must implement QuantizationInterface";
}
Type QuantizedType::getStorageType() const {
@@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
}
bool QuantizedType::hasStorageTypeBounds() const {
- unsigned int integralWidth = getStorageTypeIntegralWidth();
- bool isSignedInteger = isSigned();
- int64_t defaultIntegerMin =
- getDefaultMinimumForInteger(isSignedInteger, integralWidth);
- int64_t defaultIntegerMax =
- getDefaultMaximumForInteger(isSignedInteger, integralWidth);
- return defaultIntegerMin != getStorageTypeMin() ||
- defaultIntegerMax != getStorageTypeMax();
+ Type storageType = static_cast<ImplType *>(impl)->storageType;
+ auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType);
+
+ int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+ int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+ return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
}
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
- // NOTE: If ever supporting non-integral storage types, some other scheme
- // for determining the width will be needed.
- return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+ Type storageType = static_cast<ImplType *>(impl)->storageType;
+ auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType);
+
+ return quantizationInterface.getStorageWidth();
}
Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 9a18cff24e62a..758399a2af5e8 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,15 +10,16 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/QuantizationInterface.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
using namespace mlir;
using namespace quant;
-static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
+static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
auto typeLoc = parser.getCurrentLocation();
- IntegerType type;
+ Type type;
// Parse storage type (alpha_ident, integer_literal).
StringRef identifier;
@@ -27,20 +28,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
if (result.has_value()) {
if (!succeeded(*result))
return nullptr;
- isSigned = !type.isUnsigned();
- storageTypeWidth = type.getWidth();
- } else if (succeeded(parser.parseKeyword(&identifier))) {
- // Otherwise, this must be an unsigned integer (`u` integer-literal).
- if (!identifier.consume_front("u")) {
- parser.emitError(typeLoc, "illegal storage type prefix");
+
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(type)) {
+ isSigned = quantizationInterface.isStorageSigned();
+ storageTypeWidth = quantizationInterface.getStorageWidth();
+ } else {
+ parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
}
- if (identifier.getAsInteger(10, storageTypeWidth)) {
- parser.emitError(typeLoc, "expected storage type width");
+ } else if (succeeded(parser.parseKeyword(&identifier))) {
+ // Otherwise, this must be an unsigned integer (`u` integer-literal)
+ if (identifier.consume_front("u")) {
+ if (identifier.getAsInteger(10, storageTypeWidth)) {
+ parser.emitError(typeLoc, "expected storage type width");
+ return nullptr;
+ }
+ isSigned = false;
+ type = parser.getBuilder().getIntegerType(storageTypeWidth);
+ } else {
+ parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
}
- isSigned = false;
- type = parser.getBuilder().getIntegerType(storageTypeWidth);
} else {
return nullptr;
}
@@ -55,17 +64,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
return type;
}
-static ParseResult parseStorageRange(DialectAsmParser &parser,
- IntegerType storageType, bool isSigned,
+static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
int64_t &storageTypeMin,
int64_t &storageTypeMax) {
- int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
- isSigned, storageType.getWidth());
- int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
- isSigned, storageType.getWidth());
+ int64_t defaultMin, defaultMax;
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
+ defaultMin = quantizationInterface.getDefaultMinimum();
+ defaultMax = quantizationInterface.getDefaultMaximum();
+ }
+
if (failed(parser.parseOptionalLess())) {
- storageTypeMin = defaultIntegerMin;
- storageTypeMax = defaultIntegerMax;
+ storageTypeMin = defaultMin;
+ storageTypeMax = defaultMax;
return success();
}
@@ -75,11 +86,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
parser.getCurrentLocation(&maxLoc) ||
parser.parseInteger(storageTypeMax) || parser.parseGreater())
return failure();
- if (storageTypeMin < defaultIntegerMin) {
+ if (storageTypeMin < defaultMin) {
return parser.emitError(minLoc, "illegal storage type minimum: ")
<< storageTypeMin;
}
- if (storageTypeMax > defaultIntegerMax) {
+ if (storageTypeMax > defaultMax) {
return parser.emitError(maxLoc, "illegal storage type maximum: ")
<< storageTypeMax;
}
@@ -113,7 +124,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
static Type parseAnyType(DialectAsmParser &parser) {
- IntegerType storageType;
+ Type storageType;
FloatType expressedType;
unsigned typeFlags = 0;
int64_t storageTypeMin;
@@ -134,8 +145,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
}
// Storage type range.
- if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
- storageTypeMax)) {
+ if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
return nullptr;
}
@@ -322,7 +332,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
/// scale-zero-tensor (`,` scale-zero-tensor)*
/// `}`
static Type parseUniformType(DialectAsmParser &parser) {
- IntegerType storageType;
+ Type storageType;
FloatType expressedType;
unsigned typeFlags = 0;
int64_t storageTypeMin;
@@ -350,8 +360,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
}
// Storage type range.
- if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
- storageTypeMax)) {
+ if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
return nullptr;
}
@@ -486,12 +495,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
// storage type
- unsigned storageWidth = type.getStorageTypeIntegralWidth();
- bool isSigned = type.isSigned();
- if (isSigned) {
- out << "i" << storageWidth;
- } else {
- out << "u" << storageWidth;
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
+ out << quantizationInterface.getStorageType();
}
// storageTypeMin and storageTypeMax if not default.
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3ef69cea18f0a..f539aca7fff48 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_library(MLIRIR
OperationSupport.cpp
PatternLoggingListener.cpp
PatternMatch.cpp
+ QuantizationInterface.cpp
Region.cpp
RegionKindInterface.cpp
SymbolTable.cpp
@@ -66,7 +67,8 @@ add_mlir_library(MLIRIR
MLIRSideEffectInterfacesIncGen
MLIRSymbolInterfacesIncGen
MLIRTensorEncodingIncGen
-
+ MLIRQuantizationInterfaceIncGen
+
LINK_LIBS PUBLIC
MLIRSupport
)
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantizationInterface.cpp
new file mode 100644
index 0000000000000..a93333278610e
--- /dev/null
+++ b/mlir/lib/IR/QuantizationInterface.cpp
@@ -0,0 +1,23 @@
+//===- QuantizationInterface.cpp
+//------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.cpp.inc"
|
I think we need to decide first if we want a contiguous storage type for sub-byte types or not. For example:
What this interface would look like depends on the answers. The second question is whether we want to have an MX type (a tuple of vectors, with payload, scaling factor, storage type and element type). If we do, then the conversion between the MX and non-MX forms would live in an MX dialect. I don't mind experimenting with it like your implementation, but it would be good to know what folks would prefer as a final destination, so that we all go to the same place. PS: This is something I wrote last year, so it is outdated, but it has a notion of an MX type.
What other types could be used as storage types in the future? I'm not very familiar with quantization, but I've always thought of the storage type as just a "bag of bits" without structure. In that case, an integer type would be sufficient.
def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
  let description = [{
    Interface for types that can be used as storage types in Quant dialect.
Ideally, this interface would live in the quant dialect. It would then be attached to the IntegerType as an external model. Unfortunately, that makes it impossible to declare the interface as "promised" due to layering constraints.
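For readers unfamiliar with external models, a rough sketch of what that suggestion could look like is below. The interface name and methods follow the PR; the model itself and its registration point are illustrative assumptions, not code from the PR.

// Illustrative only: attaching the interface to IntegerType as an external
// model, assuming the interface lived in the quant dialect.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/MathExtras.h"
#include <string>

using namespace mlir;

struct IntegerQuantStorageModel
    : public QuantizationInterface::ExternalModel<IntegerQuantStorageModel,
                                                  IntegerType> {
  bool isStorageSigned(Type type) const {
    return !llvm::cast<IntegerType>(type).isUnsigned();
  }
  unsigned getStorageWidth(Type type) const {
    return llvm::cast<IntegerType>(type).getWidth();
  }
  int64_t getDefaultMinimum(Type type) const {
    return isStorageSigned(type) ? llvm::minIntN(getStorageWidth(type)) : 0;
  }
  int64_t getDefaultMaximum(Type type) const {
    return isStorageSigned(type)
               ? llvm::maxIntN(getStorageWidth(type))
               : static_cast<int64_t>(llvm::maxUIntN(getStorageWidth(type)));
  }
  std::string getStorageType(Type type) const {
    return (isStorageSigned(type) ? "i" : "u") +
           std::to_string(getStorageWidth(type));
  }
};

// Registration would happen where both the interface and IntegerType are
// visible, e.g. in QuantDialect::initialize():
//   IntegerType::attachInterface<IntegerQuantStorageModel>(*getContext());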
Thanks for this in addition to https://discourse.llvm.org/t/rfc-extending-uniformquantizedtype-with-interface-based-support-for-new-storage-types-in-quant-dialect/87803/1.
I have some comments -- mostly inquiries and suggestions.
Thanks everyone for the review! However, types like NF4 represent non-uniform quantization. The storage is still a small integer (4 bits for NF4), but the interpretation is done via a lookup table (LUT) of quantile bin centers rather than a linear formula. This means the quantized type carries additional structure beyond just the storage bits:
Thus, the quantization semantics live in the quantized type that references the LUT, and the storage type remains an integer that holds the quantile index.
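As a rough illustration of that split, a dequantization sketch is below; the lookup table values are placeholders, not the exact NF4 quantile table from the QLoRA paper.

#include <array>
#include <cstdint>

// Non-uniform (LUT-based) dequantization: the 4-bit stored value is an index
// into a table of quantile bin centers, scaled by a per-block absmax, instead
// of the linear scale/zero-point formula of uniform quantization.
static constexpr std::array<float, 16> kQuantileLUT = {
    -1.00f, -0.70f, -0.53f, -0.39f, -0.28f, -0.18f, -0.09f, 0.00f,
    0.08f,  0.16f,  0.25f,  0.34f,  0.44f,  0.56f,  0.72f,  1.00f};

float dequantizeNF4(uint8_t code4, float absmax) {
  return kQuantileLUT[code4 & 0xF] * absmax; // code4 holds the quantile index
}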
@javedabsar1, I will add another use case other than Builtin_Integer to see the picture better.
Hi @rengolin; from our compiler experience I can argue that we'd like the freedom to have data either contiguously packed, or aligned to a byte boundary and unpacked, and to be able to distinguish between these two cases from user code. We have our own dialect to handle constant transformations and lazy folding: https://github.com/openvinotoolkit/npu_compiler/tree/develop/src/vpux_compiler/src/dialect/const. One of the first steps when bringing those constants into our const dialect is to unpack them and store them as separate bytes. So it's similar to the decoupling in the quant dialect, only that in the const dialect the expressed type is i4 while the storage type is i8. Having the data unpacked into separate bytes allows us to perform all of the data movement transformations on them (slice, reorder, concat, pad) without having to add explicit sub-byte logic to each transformation, which would be a massive undertaking. So similarly, I believe MLIR should be able to support both scenarios.
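A small sketch of the unpacking step being described, assuming signed i4 data packed two elements per byte with the low nibble first (the actual layout in that dialect may differ):

#include <cstddef>
#include <cstdint>
#include <vector>

// Unpack densely packed i4 data into one sign-extended int8_t per element, so
// that byte-oriented transformations (slice, reorder, concat, pad) need no
// sub-byte logic. `packed` must hold at least (count + 1) / 2 bytes.
std::vector<int8_t> unpackI4(const std::vector<uint8_t> &packed, size_t count) {
  std::vector<int8_t> out(count);
  for (size_t i = 0; i < count; ++i) {
    uint8_t nibble =
        (i % 2 == 0) ? (packed[i / 2] & 0xF) : (packed[i / 2] >> 4);
    // Sign-extend the 4-bit value into 8 bits.
    out[i] = static_cast<int8_t>(nibble << 4) >> 4;
  }
  return out;
}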
Great topic! We just got some MXFP4 models on our plate and started to consider how to enable them. MXFP4, for example, is just the FP4 E2M1 data type, which already has support across SW and HW platforms, with the only particularity that the quant scales are per group and FP8 E8M0 rather than your normal FP32 scales. The general challenge I see is drawing the distinction between what should be a data type detail and what should be a quantization detail. We initially went in the direction of making the quantile LUT an aspect of quantization, but because this compounds with uniform/per-axis quantization, we're seeing unnecessary complexity in the quantization class hierarchies; we would rather refactor and make the quantile LUT an aspect of the data type itself, and just reuse uniform/per-axis quantization. Take also the new class for Sub-Channel Quantization (https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694) added recently to MLIR. There are good reasons to use this class in combination with the NF4 data type, so to avoid complicated class hierarchies it makes more sense to have the quantile LUT be a detail of the data type. Then for MXFP4, I'd propose to also reuse Sub-Channel Quantization. @rengolin, what are your thoughts on this direction?
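For concreteness, here is a sketch of how a single MXFP4 element decodes, as I understand the OCP MX scheme: an FP4 E2M1 code plus a shared per-block E8M0 scale that encodes a power of two with bias 127. Treat the details as assumptions rather than a spec quote.

#include <cmath>
#include <cstdint>

// Positive FP4 E2M1 magnitudes; bit 3 of the code is the sign.
static constexpr float kE2M1Values[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                         2.0f, 3.0f, 4.0f, 6.0f};

// Decode one MXFP4 element from its 4-bit code and the block's shared E8M0
// scale byte (2^(scale - 127); the special NaN encoding is ignored here).
float decodeMXFP4(uint8_t code4, uint8_t e8m0Scale) {
  float magnitude = kE2M1Values[code4 & 0x7];
  float value = (code4 & 0x8) ? -magnitude : magnitude;
  return std::ldexp(value, static_cast<int>(e8m0Scale) - 127);
}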
Agreed. This may need a method in the interface (
Interesting. We have a prototype that explodes the tensor into
What would be the storage type of |
Thanks @ZoranZomborat, how do we represent
I’ve added support for Float8E5M2 and Float8E4M3FN. I will present below NF4, which is not a built-in type and has a defined structure.
It already has get, print, and parse methods implemented. Example code skeleton:
Example IR:
Great topic also! So we have the static quant operations as:
and the dynamic quant operations as:
where for the input type of dynamic dequantize we use uniform quant type with dummy scales/zp, to at least signal the storage type information.
While I believe we could do something cleaner for the dummy quant type, either use a raw integer tensor
Off the top of my head I have this
Or we could instead have such a tuple type in main MLIR; however, we need to weigh the trade-off.
Great discussion. We have a similar approach to @ZoranZomborat for static/dynamic quantization in our compiler, in that the storage type is separated during lowering and the quantization parameters (scales, LUTs, etc.) are handled at the op/type level. I agree with the idea that the storage type should just be a clean carrier that describes how the bits are stored (signedness, width, packing). In the NF4 case described above, the LUT could be an operand in the dynamic case and an attribute in the static case. To make new storage types easy to use in generic passes and legal for different backends, maybe it would be helpful if the interface exposed a few basic packing/alignment facts. BTW, there are also lower-bit formats like i2 and i1 gaining traction. Misc. point: if someone brings up 6-bit (FP6) formats, packing gets even trickier.
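A sketch of the kind of layout queries meant here; the method names are hypothetical and are not the ones added to the PR.

// Hypothetical layout facts a storage type could expose to generic passes.
class StorageLayoutFacts {
public:
  virtual ~StorageLayoutFacts() = default;
  // Number of value bits per element, e.g. 4 for i4 or NF4.
  virtual unsigned getElementBitWidth() const = 0;
  // True if elements are densely bit-packed, false if each element occupies
  // the byte-aligned container described below.
  virtual bool isPacked() const = 0;
  // Width of the container used when elements are not densely packed,
  // e.g. 8 for i4 stored one element per byte.
  virtual unsigned getContainerBitWidth() const = 0;
};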
Sounds good, and yeah, FP6 seems hard for sure: not only do you need a combination of 4+2 bit packing, you'd also need to know at which point in the FP6 sequence you are in order to decode correctly.
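For example, with a plain dense 6-bit little-endian bit stream (a simpler layout than the 4+2 split mentioned above, used here only to show the position dependence), the starting bit offset cycles with the element index:

#include <cstddef>
#include <cstdint>

// Extract the i-th 6-bit element from a densely packed bit stream: element i
// starts at bit 6*i, so the in-byte shift cycles through 0, 6, 4, 2 and some
// elements straddle a byte boundary.
uint8_t extractFp6(const uint8_t *packed, size_t i) {
  size_t bitOffset = i * 6;
  size_t byteIndex = bitOffset / 8;
  unsigned shift = bitOffset % 8;
  uint16_t window = packed[byteIndex];
  if (shift + 6 > 8) // the element spills into the next byte
    window |= static_cast<uint16_t>(packed[byteIndex + 1]) << 8;
  return (window >> shift) & 0x3F;
}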
Related to MX types, I would still argue that we can separate their data type and quantization details and represent them, for example MXFP4, as simple f4E2M1FN storage together with a sub-channel quantization scheme, like in the original proposal https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694.
Please correct me if I'm missing any MXFP detail that can't be covered by the above representation. Also, even more interesting, back to the discussion around DynamicDequantize/Dequantize operations: we're discovering that with the continuous increase in quantization parameter sizes, it becomes more optimal to provide them at runtime as operands rather than as attributes to these operations, even in the static scenario.
6e4b4e8 to d7447a8
In the last commit I have added 4 interface methods which expose a few basic packing and alignment facts.
+1. Hardware will extract these separately.
What are the issues with instead treating scales/zps as SSA operands on the consuming ops in the case of dynamic quantization? |
The fundamental issue is that quantization is a type-level property, not just an operational detail. When we strip this information from types and relegate it to operation operands, we lose:
Quant-as-type is useful to preserve model intent from a frontend perspective but it is ultimately a numerical transformation eventually lowered to plain ints + params by hardware.
Wouldn't exposing scale/zp as operands increase reasoning opportunities (e.g., CSE, LICM), because they operate on values?
This is indeed a plus of the tuple-type approach, though the guarantees that we can ensure (storage type vs. expressed type legality, the axis being legal in the per-channel case, etc.) could still be type parameters?
Yes, of course; however, we don't want to expose this at a higher level, as we want to keep the semantics and patterns available for any potential transformation at this level.