Conversation

Roman-Pevnyi

Currently, UniformQuantizedType only supports built-in MLIR storage types such as Integer. LLM quantization research has introduced NF4 as a low-precision data type (see https://arxiv.org/pdf/2305.14314), so there is a growing need to make the system extensible and maintainable as more types are added. Ensuring that MLIR can natively support NF4 through a clean, extensible interface is essential for both current and future quantization workflows.

Current Approach and Its Limitations:

  • The present implementation relies on dynamic checks (e.g., type switches or if-else chains) to determine the storage type and retrieve type-specific information for legality checks.

  • This approach works for a small, fixed set of types, but as the number of supported types grows, the code becomes harder to read, maintain, and extend.

Proposed Interface-Based Approach:

  • Define a StorageTypeInterface that specifies the required methods any storage type must implement to be used in UniformQuantizedType.
  • Each storage type (Integer, Float8E5M2, Float8E4M3FN, and new types like NF4) would implement this interface, encapsulating their type-specific logic.
  • When UniformQuantizedType needs to check legality or retrieve information, it can use MLIR’s dyn_cast mechanism to check if the type implements the interface and then call the required methods.
  • This design decouples UniformQuantizedType from the specifics of each storage type, making it easy to add new types (such as NF4) without modifying the core logic or introducing more type checks.

Benefits:

  • Extensibility: New storage types can be added by simply implementing the interface, without touching the core UniformQuantizedType logic.
  • Readability: The code is cleaner, as it avoids large switch statements or if-else chains.
  • Maintainability: Type-specific logic is encapsulated within each type, reducing the risk of errors and making the codebase easier to understand and update.

… built in types. Updated parser and printer in Quant dialect

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In that case, you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "pinging" the PR: add a comment containing "Ping". The common courtesy ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Aug 11, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-ods

Author: None (Roman-Pevnyi)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/152966.diff

10 Files Affected:

  • (modified) mlir/cmake/modules/AddMLIR.cmake (+8)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+2)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+28-1)
  • (modified) mlir/include/mlir/IR/CMakeLists.txt (+2)
  • (added) mlir/include/mlir/IR/QuantizationInterface.h (+22)
  • (added) mlir/include/mlir/IR/QuantizationInterface.td (+44)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+35-30)
  • (modified) mlir/lib/Dialect/Quant/IR/TypeParser.cpp (+40-34)
  • (modified) mlir/lib/IR/CMakeLists.txt (+3-1)
  • (added) mlir/lib/IR/QuantizationInterface.cpp (+23)
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index ff4269ed7acd2..c35308d57eadd 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -203,6 +203,14 @@ function(add_mlir_interface interface)
   add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
 endfunction()
 
+# Declare a type interface in the include directory, generating the
+# declarations and definitions from the given .td file.
+function(add_mlir_type_interface interface)
+  set(LLVM_TARGET_DEFINITIONS ${interface}.td)
+  mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
+  mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
+  add_public_tablegen_target(MLIR${interface}IncGen)
+  add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
+endfunction()
 
 # Generate Documentation
 function(add_mlir_doc doc_filename output_file output_directory command)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 86ec5c43970b1..204da9553e915 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
 // Tablegen Type Declarations
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/QuantizationInterface.h"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.h.inc"
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..762f9262adbf2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/QuantizationInterface.td"
 include "mlir/IR/CommonTypeConstraints.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -497,7 +498,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
 //===----------------------------------------------------------------------===//
 
 def Builtin_Integer : Builtin_Type<"Integer", "integer",
-    [VectorElementTypeInterface]> {
+    [VectorElementTypeInterface, QuantizationInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -554,6 +555,32 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
+
+    /// QuantizationInterface method implementations
+    /// Return true if this is a signed integer type.
+    bool isStorageSigned() const { return !isUnsigned(); }
+    /// Get the bit width of this integer type.
+    unsigned getStorageWidth() const { return getWidth(); }
+    
+    /// Get default minimum value for this integer type.
+    int64_t getDefaultMinimum() const {
+      if (isStorageSigned()) {
+        return llvm::minIntN(getStorageWidth());
+      }
+      return 0;
+    }
+    /// Get default maximum value for this integer type.
+    int64_t getDefaultMaximum() const {
+      if (isStorageSigned()) {
+        return llvm::maxIntN(getStorageWidth());
+      }
+      return llvm::maxUIntN(getStorageWidth());
+    }
+    
+    /// Get the storage type as a string.
+    std::string getStorageType() const {
+      return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 846547ff131e3..153502c6e981b 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,6 +1,8 @@
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+add_mlir_type_interface(QuantizationInterface)
+
 set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
 mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
 mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantizationInterface.h
new file mode 100644
index 0000000000000..0d6709ff52065
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.h
@@ -0,0 +1,22 @@
+//===- QuantizationInterface.h - Quantization Interfaces -------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_QuantizationInterface_H
+#define MLIR_IR_QuantizationInterface_H
+
+#include "mlir/IR/Types.h"
+
+// Forward declarations for the types we need in the implementation
+namespace mlir {
+class IntegerType;
+} // namespace mlir
+
+#include "mlir/IR/QuantizationInterface.h.inc"
+
+#endif // MLIR_IR_QuantizationInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantizationInterface.td
new file mode 100644
index 0000000000000..1008ac8e1dcf1
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.td
@@ -0,0 +1,44 @@
+#ifndef MLIR_IR_QUANTIZATIONINTERFACE
+#define MLIR_IR_QUANTIZATIONINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+  let description = [{
+    Interface for types that can be used as storage types in Quant dialect.
+    This interface provides methods to determine storage characteristics for quantization purposes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Check if the storage type is signed.
+      Returns true if the type represents signed values, false for unsigned.
+    }],
+    "bool", "isStorageSigned", (ins)>,
+
+    InterfaceMethod<[{
+      Get the bit width of this storage type.
+      Returns the number of bits used to store values of this type.
+    }],
+    "unsigned", "getStorageWidth", (ins)>,
+
+    InterfaceMethod<[{
+      Get default minimum value for this storage type.
+    }],
+    "int64_t", "getDefaultMinimum", (ins)>,
+
+    InterfaceMethod<[{
+      Get default maximum value for this storage type.
+    }],
+    "int64_t", "getDefaultMaximum", (ins)>,
+
+    InterfaceMethod<[{
+      Get the storage type as a string.
+    }],
+    "std::string", "getStorageType", (ins)>
+  ];
+
+}
+
+#endif // MLIR_IR_QUANTIZATIONINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index b2227792f32ca..e7f9b1dc8a7e1 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "TypeDetail.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/IR/QuantizationInterface.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
   if (!intStorageType)
     return emitError() << "storage type must be integral";
-  unsigned integralWidth = intStorageType.getWidth();
-
-  // Verify storage width.
-  if (integralWidth == 0 || integralWidth > MaxStorageBits)
-    return emitError() << "illegal storage type size: " << integralWidth;
-
-  // Verify storageTypeMin and storageTypeMax.
-  bool isSigned =
-      (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
-  int64_t defaultIntegerMin =
-      getDefaultMinimumForInteger(isSigned, integralWidth);
-  int64_t defaultIntegerMax =
-      getDefaultMaximumForInteger(isSigned, integralWidth);
-  if (storageTypeMax - storageTypeMin <= 0 ||
-      storageTypeMin < defaultIntegerMin ||
-      storageTypeMax > defaultIntegerMax) {
-    return emitError() << "illegal storage min and storage max: ("
-                       << storageTypeMin << ":" << storageTypeMax << ")";
+
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    unsigned integralWidth = quantizationInterface.getStorageWidth();
+
+    // Verify storage width.
+    if (integralWidth == 0 || integralWidth > MaxStorageBits)
+      return emitError() << "illegal storage type size: " << integralWidth;
+
+    int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+    int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+    if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
+        storageTypeMax > defaultMax) {
+      return emitError() << "illegal storage min and storage max: ("
+                         << storageTypeMin << ":" << storageTypeMax << ")";
+    }
+
+    return success();
   }
-  return success();
+
+  return emitError() << "storage type must implement QuantizationInterface";
 }
 
 Type QuantizedType::getStorageType() const {
@@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
 }
 
 bool QuantizedType::hasStorageTypeBounds() const {
-  unsigned int integralWidth = getStorageTypeIntegralWidth();
-  bool isSignedInteger = isSigned();
-  int64_t defaultIntegerMin =
-      getDefaultMinimumForInteger(isSignedInteger, integralWidth);
-  int64_t defaultIntegerMax =
-      getDefaultMaximumForInteger(isSignedInteger, integralWidth);
-  return defaultIntegerMin != getStorageTypeMin() ||
-         defaultIntegerMax != getStorageTypeMax();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+  auto quantizationInterface =
+      llvm::dyn_cast<QuantizationInterface>(storageType);
+
+  int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+  int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+  return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
 }
 
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
-  // NOTE: If ever supporting non-integral storage types, some other scheme
-  // for determining the width will be needed.
-  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+  auto quantizationInterface =
+      llvm::dyn_cast<QuantizationInterface>(storageType);
+
+  return quantizationInterface.getStorageWidth();
 }
 
 Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 9a18cff24e62a..758399a2af5e8 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,15 +10,16 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/QuantizationInterface.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
 
 using namespace mlir;
 using namespace quant;
 
-static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
+static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   auto typeLoc = parser.getCurrentLocation();
-  IntegerType type;
+  Type type;
 
   // Parse storage type (alpha_ident, integer_literal).
   StringRef identifier;
@@ -27,20 +28,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   if (result.has_value()) {
     if (!succeeded(*result))
       return nullptr;
-    isSigned = !type.isUnsigned();
-    storageTypeWidth = type.getWidth();
-  } else if (succeeded(parser.parseKeyword(&identifier))) {
-    // Otherwise, this must be an unsigned integer (`u` integer-literal).
-    if (!identifier.consume_front("u")) {
-      parser.emitError(typeLoc, "illegal storage type prefix");
+
+    if (auto quantizationInterface =
+            llvm::dyn_cast<QuantizationInterface>(type)) {
+      isSigned = quantizationInterface.isStorageSigned();
+      storageTypeWidth = quantizationInterface.getStorageWidth();
+    } else {
+      parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
     }
-    if (identifier.getAsInteger(10, storageTypeWidth)) {
-      parser.emitError(typeLoc, "expected storage type width");
+  } else if (succeeded(parser.parseKeyword(&identifier))) {
+    // Otherwise, this must be an unsigned integer (`u` integer-literal)
+    if (identifier.consume_front("u")) {
+      if (identifier.getAsInteger(10, storageTypeWidth)) {
+        parser.emitError(typeLoc, "expected storage type width");
+        return nullptr;
+      }
+      isSigned = false;
+      type = parser.getBuilder().getIntegerType(storageTypeWidth);
+    } else {
+      parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
     }
-    isSigned = false;
-    type = parser.getBuilder().getIntegerType(storageTypeWidth);
   } else {
     return nullptr;
   }
@@ -55,17 +64,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   return type;
 }
 
-static ParseResult parseStorageRange(DialectAsmParser &parser,
-                                     IntegerType storageType, bool isSigned,
+static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
                                      int64_t &storageTypeMin,
                                      int64_t &storageTypeMax) {
-  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
-      isSigned, storageType.getWidth());
-  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
-      isSigned, storageType.getWidth());
+  int64_t defaultMin, defaultMax;
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    defaultMin = quantizationInterface.getDefaultMinimum();
+    defaultMax = quantizationInterface.getDefaultMaximum();
+  }
+
   if (failed(parser.parseOptionalLess())) {
-    storageTypeMin = defaultIntegerMin;
-    storageTypeMax = defaultIntegerMax;
+    storageTypeMin = defaultMin;
+    storageTypeMax = defaultMax;
     return success();
   }
 
@@ -75,11 +86,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
       parser.getCurrentLocation(&maxLoc) ||
       parser.parseInteger(storageTypeMax) || parser.parseGreater())
     return failure();
-  if (storageTypeMin < defaultIntegerMin) {
+  if (storageTypeMin < defaultMin) {
     return parser.emitError(minLoc, "illegal storage type minimum: ")
            << storageTypeMin;
   }
-  if (storageTypeMax > defaultIntegerMax) {
+  if (storageTypeMax > defaultMax) {
     return parser.emitError(maxLoc, "illegal storage type maximum: ")
            << storageTypeMax;
   }
@@ -113,7 +124,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
 ///   storage-type ::= (`i` | `u`) integer-literal
 ///   expressed-type-spec ::= `:` `f` integer-literal
 static Type parseAnyType(DialectAsmParser &parser) {
-  IntegerType storageType;
+  Type storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
   int64_t storageTypeMin;
@@ -134,8 +145,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
   }
 
   // Storage type range.
-  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
-                        storageTypeMax)) {
+  if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
     return nullptr;
   }
 
@@ -322,7 +332,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
 ///     scale-zero-tensor (`,` scale-zero-tensor)*
 ///   `}`
 static Type parseUniformType(DialectAsmParser &parser) {
-  IntegerType storageType;
+  Type storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
   int64_t storageTypeMin;
@@ -350,8 +360,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
   }
 
   // Storage type range.
-  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
-                        storageTypeMax)) {
+  if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
     return nullptr;
   }
 
@@ -486,12 +495,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
 
 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
   // storage type
-  unsigned storageWidth = type.getStorageTypeIntegralWidth();
-  bool isSigned = type.isSigned();
-  if (isSigned) {
-    out << "i" << storageWidth;
-  } else {
-    out << "u" << storageWidth;
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
+    out << quantizationInterface.getStorageType();
   }
 
   // storageTypeMin and storageTypeMax if not default.
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3ef69cea18f0a..f539aca7fff48 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_library(MLIRIR
   OperationSupport.cpp
   PatternLoggingListener.cpp
   PatternMatch.cpp
+  QuantizationInterface.cpp  
   Region.cpp
   RegionKindInterface.cpp
   SymbolTable.cpp
@@ -66,7 +67,8 @@ add_mlir_library(MLIRIR
   MLIRSideEffectInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
-
+  MLIRQuantizationInterfaceIncGen
+  
   LINK_LIBS PUBLIC
   MLIRSupport
   )
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantizationInterface.cpp
new file mode 100644
index 0000000000000..a93333278610e
--- /dev/null
+++ b/mlir/lib/IR/QuantizationInterface.cpp
@@ -0,0 +1,23 @@
+//===- QuantizationInterface.cpp
+//------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.cpp.inc"

@llvmbot
Member

llvmbot commented Aug 11, 2025

@llvm/pr-subscribers-mlir

@@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
   if (!intStorageType)
     return emitError() << "storage type must be integral";
-  unsigned integralWidth = intStorageType.getWidth();
-
-  // Verify storage width.
-  if (integralWidth == 0 || integralWidth > MaxStorageBits)
-    return emitError() << "illegal storage type size: " << integralWidth;
-
-  // Verify storageTypeMin and storageTypeMax.
-  bool isSigned =
-      (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
-  int64_t defaultIntegerMin =
-      getDefaultMinimumForInteger(isSigned, integralWidth);
-  int64_t defaultIntegerMax =
-      getDefaultMaximumForInteger(isSigned, integralWidth);
-  if (storageTypeMax - storageTypeMin <= 0 ||
-      storageTypeMin < defaultIntegerMin ||
-      storageTypeMax > defaultIntegerMax) {
-    return emitError() << "illegal storage min and storage max: ("
-                       << storageTypeMin << ":" << storageTypeMax << ")";
+
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    unsigned integralWidth = quantizationInterface.getStorageWidth();
+
+    // Verify storage width.
+    if (integralWidth == 0 || integralWidth > MaxStorageBits)
+      return emitError() << "illegal storage type size: " << integralWidth;
+
+    int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+    int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+    if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
+        storageTypeMax > defaultMax) {
+      return emitError() << "illegal storage min and storage max: ("
+                         << storageTypeMin << ":" << storageTypeMax << ")";
+    }
+
+    return success();
   }
-  return success();
+
+  return emitError() << "storage type must implement QuantizationInterface";
 }
 
 Type QuantizedType::getStorageType() const {
@@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
 }
 
 bool QuantizedType::hasStorageTypeBounds() const {
-  unsigned int integralWidth = getStorageTypeIntegralWidth();
-  bool isSignedInteger = isSigned();
-  int64_t defaultIntegerMin =
-      getDefaultMinimumForInteger(isSignedInteger, integralWidth);
-  int64_t defaultIntegerMax =
-      getDefaultMaximumForInteger(isSignedInteger, integralWidth);
-  return defaultIntegerMin != getStorageTypeMin() ||
-         defaultIntegerMax != getStorageTypeMax();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+  auto quantizationInterface =
+      llvm::dyn_cast<QuantizationInterface>(storageType);
+
+  int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+  int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+  return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
 }
 
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
-  // NOTE: If ever supporting non-integral storage types, some other scheme
-  // for determining the width will be needed.
-  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+  auto quantizationInterface =
+      llvm::dyn_cast<QuantizationInterface>(storageType);
+
+  return quantizationInterface.getStorageWidth();
 }
 
 Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 9a18cff24e62a..758399a2af5e8 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,15 +10,16 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/QuantizationInterface.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
 
 using namespace mlir;
 using namespace quant;
 
-static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
+static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   auto typeLoc = parser.getCurrentLocation();
-  IntegerType type;
+  Type type;
 
   // Parse storage type (alpha_ident, integer_literal).
   StringRef identifier;
@@ -27,20 +28,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   if (result.has_value()) {
     if (!succeeded(*result))
       return nullptr;
-    isSigned = !type.isUnsigned();
-    storageTypeWidth = type.getWidth();
-  } else if (succeeded(parser.parseKeyword(&identifier))) {
-    // Otherwise, this must be an unsigned integer (`u` integer-literal).
-    if (!identifier.consume_front("u")) {
-      parser.emitError(typeLoc, "illegal storage type prefix");
+
+    if (auto quantizationInterface =
+            llvm::dyn_cast<QuantizationInterface>(type)) {
+      isSigned = quantizationInterface.isStorageSigned();
+      storageTypeWidth = quantizationInterface.getStorageWidth();
+    } else {
+      parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
     }
-    if (identifier.getAsInteger(10, storageTypeWidth)) {
-      parser.emitError(typeLoc, "expected storage type width");
+  } else if (succeeded(parser.parseKeyword(&identifier))) {
+    // Otherwise, this must be an unsigned integer (`u` integer-literal)
+    if (identifier.consume_front("u")) {
+      if (identifier.getAsInteger(10, storageTypeWidth)) {
+        parser.emitError(typeLoc, "expected storage type width");
+        return nullptr;
+      }
+      isSigned = false;
+      type = parser.getBuilder().getIntegerType(storageTypeWidth);
+    } else {
+      parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
     }
-    isSigned = false;
-    type = parser.getBuilder().getIntegerType(storageTypeWidth);
   } else {
     return nullptr;
   }
@@ -55,17 +64,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   return type;
 }
 
-static ParseResult parseStorageRange(DialectAsmParser &parser,
-                                     IntegerType storageType, bool isSigned,
+static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
                                      int64_t &storageTypeMin,
                                      int64_t &storageTypeMax) {
-  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
-      isSigned, storageType.getWidth());
-  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
-      isSigned, storageType.getWidth());
+  int64_t defaultMin, defaultMax;
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    defaultMin = quantizationInterface.getDefaultMinimum();
+    defaultMax = quantizationInterface.getDefaultMaximum();
+  }
+
   if (failed(parser.parseOptionalLess())) {
-    storageTypeMin = defaultIntegerMin;
-    storageTypeMax = defaultIntegerMax;
+    storageTypeMin = defaultMin;
+    storageTypeMax = defaultMax;
     return success();
   }
 
@@ -75,11 +86,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
       parser.getCurrentLocation(&maxLoc) ||
       parser.parseInteger(storageTypeMax) || parser.parseGreater())
     return failure();
-  if (storageTypeMin < defaultIntegerMin) {
+  if (storageTypeMin < defaultMin) {
     return parser.emitError(minLoc, "illegal storage type minimum: ")
            << storageTypeMin;
   }
-  if (storageTypeMax > defaultIntegerMax) {
+  if (storageTypeMax > defaultMax) {
     return parser.emitError(maxLoc, "illegal storage type maximum: ")
            << storageTypeMax;
   }
@@ -113,7 +124,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
 ///   storage-type ::= (`i` | `u`) integer-literal
 ///   expressed-type-spec ::= `:` `f` integer-literal
 static Type parseAnyType(DialectAsmParser &parser) {
-  IntegerType storageType;
+  Type storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
   int64_t storageTypeMin;
@@ -134,8 +145,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
   }
 
   // Storage type range.
-  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
-                        storageTypeMax)) {
+  if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
     return nullptr;
   }
 
@@ -322,7 +332,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
 ///     scale-zero-tensor (`,` scale-zero-tensor)*
 ///   `}`
 static Type parseUniformType(DialectAsmParser &parser) {
-  IntegerType storageType;
+  Type storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
   int64_t storageTypeMin;
@@ -350,8 +360,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
   }
 
   // Storage type range.
-  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
-                        storageTypeMax)) {
+  if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
     return nullptr;
   }
 
@@ -486,12 +495,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
 
 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
   // storage type
-  unsigned storageWidth = type.getStorageTypeIntegralWidth();
-  bool isSigned = type.isSigned();
-  if (isSigned) {
-    out << "i" << storageWidth;
-  } else {
-    out << "u" << storageWidth;
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
+    out << quantizationInterface.getStorageType();
   }
 
   // storageTypeMin and storageTypeMax if not default.
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3ef69cea18f0a..f539aca7fff48 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_library(MLIRIR
   OperationSupport.cpp
   PatternLoggingListener.cpp
   PatternMatch.cpp
+  QuantizationInterface.cpp  
   Region.cpp
   RegionKindInterface.cpp
   SymbolTable.cpp
@@ -66,7 +67,8 @@ add_mlir_library(MLIRIR
   MLIRSideEffectInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
-
+  MLIRQuantizationInterfaceIncGen
+  
   LINK_LIBS PUBLIC
   MLIRSupport
   )
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantizationInterface.cpp
new file mode 100644
index 0000000000000..a93333278610e
--- /dev/null
+++ b/mlir/lib/IR/QuantizationInterface.cpp
@@ -0,0 +1,23 @@
+//===- QuantizationInterface.cpp
+//------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.cpp.inc"

@rengolin
Member

rengolin commented Aug 11, 2025

I think we need to decide first if we want a contiguous storage type for sub-byte types or not.

For example:

  • int4 and fp4 can have size = 4, storage_size = 8 but still pack two elements per byte.
  • fp6 can be represented in two lists (4 + 2 bits), and those lists themselves be packed or not.

Depending on the answers is what this interface would look like.

The second question is whether we want to have an MX type (tuple of vectors, with payload, scaling factor, storage type and element type). If we do, then the conversion between the MX and non-MX would be in an MX dialect (potentially quant), and if we make MX types native in MLIR, then in theory, we could tile and fuse them by teaching those patterns to descend into the sub-types.

I don't mind experimenting with it like your implementation, but it would be good to know what folks would prefer as a final destination, so that we go all to the same place.

PS: This is something I wrote last year, so outdated, but has a notion of an MX type.

Member

@matthias-springer matthias-springer left a comment


What other types could be used as storage types in the future? I'm not very familiar with quantization, but I've always thought of the storage type as just a "bag of bits" without structure. In that case, an integer type would be sufficient.


def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
let description = [{
Interface for types that can be used as storage types in Quant dialect.
Member


Ideally, this interface would live in the quant dialect. It would then be attached to the IntegerType as an external model. Unfortunately, that makes it impossible to declare the interface as "promised" due to layering constraints.

Contributor

@javedabsar1 javedabsar1 left a comment


@Roman-Pevnyi
Author

What other types could be used as storage types in the future? I'm not very familiar with quantization, but I've always thought of the storage type as just a "bag of bits" without structure. In that case, an integer type would be sufficient.

Thanks everyone for the review!
@matthias-springer, the current UniformQuantizedType in MLIR indeed treats storage as a simple integer type - just a "bag of bits", which works well for uniform quantization where values are interpreted via a scale and zero point.

However, types like NF4 represent non-uniform quantization. The storage is still a small integer (NF4 - 4 bits), but the interpretation is done via a lookup table (LUT) of quantile bin centers rather than a linear formula. This means the quantized type carries additional structure beyond just the storage bits:

  • Storage type: (i4).
  • Expressed type: (f32 or f16).
  • Bins: an attribute holding the LUT values (quantiles).

Thus, the quantization semantics live in the quantized type that references the LUT, while the storage type remains a plain integer whose values index into the quantile table.
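A minimal numeric sketch of that LUT interpretation (plain C++; the 4-entry table below is an illustrative stand-in, not the real 16-entry NF4 quantile table):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// The stored integer indexes a table of bin centers; dequantization is a
// lookup plus the usual scale, not a linear (value - zp) * scale formula.
// Illustrative 2-bit table only; NF4 would have 16 entries in [-1, 1].
constexpr std::array<float, 4> kBins = {-1.0f, -0.25f, 0.25f, 1.0f};

float dequantizeLUT(uint8_t storage, float scale) {
  return kBins[storage & 0x3] * scale;
}
```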

@Roman-Pevnyi
Author

Other than for Builtin_Integer is it possible to show another use-case (example) for QuantizationInterface ?

@javedabsar1, I will add another use-case other than Builtin_Integer to see the picture better.

@ZoranZomborat

ZoranZomborat commented Aug 12, 2025

decide first if we want a contiguous storage type for sub-byte types or not.

Hi @rengolin; from our compiler experience I can argue that we'd like the freedom to have data either contiguously packed or aligned to a byte boundary and unpacked, and to be able to make a distinction between these two cases from user code;

We have our own dialect to handle constant transformations and lazy folding https://github.com/openvinotoolkit/npu_compiler/tree/develop/src/vpux_compiler/src/dialect/const
where we have a similar decoupling of storage type and expressed type;

One of the first steps when bringing those constants into our const dialect is to unpack them and store them as separate bytes; so it's similar to the decoupling in the quant dialect, only that in the const dialect the expressed type is i4 while the storage type is i8;

And having this data unpacked in separate bytes allows us to perform all of the data-movement transformations on them (slice, reorder, concat, pad) without having to add explicit sub-byte logic to each transformation, which would be a massive undertaking;
Then towards the end of compilation we pack them back and prepare them for HW execution;

So similarly I believe in MLIR we should be able to support both scenarios.

@ZoranZomborat

The second question is if we want to have an MX type

Great topic! We just got some MXFP4 models on our table and started to consider how to enable them.
Personally, I find it confusing that we've all agreed to slap a datatype name onto a particularity of quantization;

MXFP4, for example, is just the FP4 e2m1 datatype, which already has support across SW and HW platforms, with the only particularity that the quant scales are per group and FP8 E8M0 rather than your normal FP32 scales;

The general challenge I see is drawing the line between what should be a datatype detail and what should be a quantization detail;
Take for example this older thread https://discourse.llvm.org/t/rfc-add-suport-for-quantilequantizedtype-in-quant-dialect/80346 debating how the quantile LUT aspect of the NF4 and similar datatypes should be represented;

Since then, we went in the direction of making the quantile LUT an aspect of quantization, but because this compounds with the uniform/per_axis quantization, we're seeing unnecessary complexity in the quantization class hierarchies and would rather refactor to make the quantile LUT an aspect of the datatype itself, and just reuse quantile uniform/per_axis.

Take also the new class for Sub Channel Quantization https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694 added recently to MLIR. There's every reason to use this class in combination with the NF4 datatype; so, to avoid complicated class hierarchies, it makes more sense to make the quantile LUT a detail of the datatype.

Then for MXFP4, I'd propose to also reuse Sub Channel Quantization;
!qalias = !quant.uniform<Float4E2M1FN:f32:{0:1, 1:2}, {{2.0:120,3.0:127}, {4.0,5.0}}> (just copied a case from LIT test we can work together to define something that makes more sense)
with added information in the representation to cover also the type of the scales and zero points.

@rengolin what are your thoughts on this direction?
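To ground the "plain f4E2M1FN storage plus per-group f8E8M0 scales" framing numerically, a self-contained sketch (bit layouts follow the usual E2M1/E8M0 conventions; NaN encodings are ignored):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode one FP4 E2M1 code (1 sign, 2 exponent, 1 mantissa bit; bias 1).
// Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
float decodeE2M1(uint8_t code) {
  int sign = (code >> 3) & 0x1;
  int exp = (code >> 1) & 0x3;
  int man = code & 0x1;
  float mag = (exp == 0) ? man * 0.5f // subnormal
                         : std::ldexp(1.0f + man * 0.5f, exp - 1);
  return sign ? -mag : mag;
}

// Per-group E8M0 scale: an 8-bit biased exponent, i.e. 2^(e - 127).
float decodeE8M0(uint8_t e) { return std::ldexp(1.0f, int(e) - 127); }

// An MXFP4 element is then just storage value times group scale.
float dequantMXFP4(uint8_t code, uint8_t groupScale) {
  return decodeE2M1(code) * decodeE8M0(groupScale);
}
```

Nothing here requires a new datatype beyond f4E2M1FN itself; the group structure lives entirely in how scales are indexed.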

@rengolin
Member

So similarly I believe in MLIR we should be able to support both scenarios.

Agreed. This may need a method in the interface (isPacked()), an attribute in the ops/types that implement it and some mutually agreed ABI that teaches ops/rewrites how to unpack it.

Since then, we went the direction of making the quantile LUT an aspect of quantization, but because this compounds together with the uniform/per_axis quantization, we're seeing an unnecessary complexity in the quantization class hierarchies and would rather refactor and make quantile LUT an aspect of the datatype itself, and just reuse quantile uniform/per_axis.

Interesting. We have a prototype that explodes the tensor into {storage, factor} and if both are tensors, we need to pass them along and tile/fuse them in "smart" ways, so having a type to carry that info would be preferred. Your experience adds to the idea of proposing an actual MXFP type: a composite type that some rewrites can see through (ex. tiling) while others need help (ex. fusion, LUT decomposition).

The storage is still a small integer (NF4 - 4 bits), but the interpretation is done via a lookup table (LUT) of quantile bin centers rather than a linear formula. This means the quantized type carries additional structure beyond just the storage bits:

  • Storage type: (i4).
  • Expressed type: (f32 or f16).
  • Bins: an attribute holding the LUT values (quantiles).

What would be the storage type of fp6? Could it be a tuple ((i4, i2))? Would it be easy to create a tuple of tensors from a tuple type, instead of a tensor of tuples? We'll need some extended logic for MX types, and it would have to be quite pervasive through MLIR rewrites and ops.

@shahidact
Contributor

Then for MXFP4, I'd propose to also reuse Sub Channel Quantization;
!qalias = !quant.uniform<Float4E2M1FN:f32:{0:1, 1:2}, {{2.0:120,3.0:127}, {4.0,5.0}}> (just copied a case from LIT test we can work together to define something that makes more sense)
with added information in the representation to cover also the type of the scales and zero points.

Thanks @ZoranZomborat. How do we represent scale or zero_point if these are SSA values, which could be another use case for dynamic quantization?

@Roman-Pevnyi
Author

I’ve added support for Float8E5M2 and Float8E4M3FN.
As shown in the second commit, this only required extending the float types themselves by implementing the QuantizationInterface methods. The Quant dialect code was not touched because it can already accept any type that implements this interface.

I will present below NF4, which is not a built-in type and has a defined structure.
You can see the current NF4 type implementation here.
NF4 consists of:

  • Storage type: i4,
  • Quantile type: f16,
  • Bins: an attribute holding the LUT values (quantiles).

It already has get, print, and parse methods implemented.
To make NF4 usable as a UniformQuantizedType storage type, I would add the QuantizationInterface method implementations, following the same approach used for the float8 formats in the second commit.

Example code skeleton:

class NF4Type : public mlir::Type::TypeBase<
                    NF4Type,
                    QuantileFloatType,
                    vpux::detail::QuantileFloatTypeStorage,
                    mlir::QuantizationInterface::Trait> {
  /// Existing code...

  // QuantizationInterface method implementations
  bool isStorageSigned() const { return true; }
  /// Get the bit width of this 4-bit normalized floating point type.
  unsigned getStorageWidth() const { return 4; }

  /// Get default maximum value for this 4-bit normalized floating point type.
  int64_t getDefaultMaximum() const { return 16; }
  /// Get default minimum value for this 4-bit normalized floating point type.
  int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }

  /// Get the storage type name as a string.
  std::string getStorageType() const {
    std::string result = "!QuantileFloat.nf4<";
    llvm::raw_string_ostream os(result);

    os << "ui4"; // underlying storage spelling; a recursive getStorageType() call here would loop
    os << ":";
    os << getQuantileType();
    os << ", {";

    ArrayRef<double> quantiles = this->getQuantiles();
    printQuantiles(quantiles);

    os << "}>";
    os.flush();
    return result;
}
};

Example IR:

!qalias = !quant.uniform<!QuantileFloat.nf4<ui4:f16, {-1.000000e+00, ..., 1.000000e+00}> : f16, 1.0 : 0>
!qalias = !quant.uniform<!QuantileFloat.nf4<storage_type:quantiles_type, {quantiles}> : expressed_type, scale : zero_point>

@ZoranZomborat

How do we represent scale or zero_point if these are SSA value which could be another use case for dynamic quantization?

Great topic also!
What we did at least, in the dynamic quantization case we defined a new operation set and shifted from having the quant details as part of the statically defined datatype to receiving them as arguments;

So we have the static quant operations as:

def DequantizeOp :

    let summary = "Dequantize layer";
    let arguments = (ins
        RankedTensorOf<[quant_QuantizedType]>:$input,
        TypeAttr:$dstElemType
    );

and the dynamic quant operations as:

def DynamicDequantizeOp :

    let summary = "Dynamic Dequantize layer";
    let arguments = (ins
        RankedTensorOf<[quant_QuantizedType]>:$input,
        RankedTensorOf<[F16, F32]>:$scale, // to be extended in the future to support also FP8 E8M0 scales;
        Optional<RankedTensorOf<[I8, I4, I2, U2]>>:$zp,

        TypeAttr:$dstElemType
    );

where for the input type of dynamic dequantize we use a uniform quant type with dummy scales/zp, to at least signal the storage type information.

!qElemType = !quant.uniform<i4:f16, 1.0:0>
%dynamic_dequant = DynamicDequantize(%weights, %scale, %zp) {dstElemType = f16} :
        tensor<16x16x1x1x!qElemType>, tensor<16x1x1x1xf16>, tensor<1x16x1x1xi4> -> tensor<16x16x1x1xf16>

While I believe we could do something cleaner for the dummy quant type (either use a raw integer tensor tensor<16x16x1x1xi4>, or find a way to omit the scales/zp if not static: !quant.uniform<i4:f16>), we should still keep the quant params as direct attributes to the operations that handle the dynamic quantization.
Any thoughts on this?

@shahidact
Contributor

Any thoughts on this?

On top of my mind I have this

  1. Define a non-uniform quant type, say quant.nonuniform, which is a tuple of quant.uniform, i.e. (quant.uniform, quant.uniform), whose 1st element signifies the value and whose 2nd element is the scale.
  2. Define accessor ops quant.ExtractValueOp, quant.ExtractScaleOp
  3. Define quant.ScaleOp to dynamically compute the scale for a given value and return a value of quant.nonuniform type

OR we can better have such a tuple type in main MLIR however we need to weigh-in the trade-off.

@anuragsingh-tt

Great discussion.

We have a similar approach to @ZoranZomborat for static/dynamic quantization in our compiler in that the storage type is separated during lowering and quantization parameters (scales, LUTs, etc.) are handled at the op/type level. I agree with the idea that storage type should just be a clean carrier that describes how the bits are stored (signedness, width, packing). In the NF4 case described above, the LUT could be an operand in the dynamic case and an attribute in the static case.

To make new storage types easy to use in generic passes and legal for different backends, maybe it would be helpful if the interface exposed a few basic packing/alignment facts:
  • bool isPacked() as suggested
  • unsigned getLogicalBitWidth() (e.g., 4 for NF4)
  • unsigned getElementsPerByte() (e.g., 2 for NF4)
  • Optional<unsigned> getPreferredAlignmentBytes() (for DMA/vector/tile alignment)
These allow MLIR transforms to compute correct byte strides, respect sub-byte legality rules and insert pack/unpack only where necessary. I realize some of these may lean toward backend implementation details but even simple defaults would let generic passes make safe decisions while still allowing hardware-specific encodings to refine them.

BTW, there are also lower-bit formats like i2 and i1 gaining traction.
• For ternary, the “canonical” compact encoding is 2 bits per value so isPacked=true, bitWidth=2, elementsPerByte=4. Though if a backend stores ternary as signed int8 (−1,0,+1) for simplicity it would be isPacked=false, bitWidth=8, elementsPerByte=1. The interface makes both legal and the layout/encodings decide which is used where.

Misc. point: If someone brings up 6‑bit (fp6) formats those generally require a composite packing scheme (ex: 4+2 bit streams) and don’t divide 8 cleanly. That case argues for either (a) a richer packing descriptor than just elementsPerByte, or (b) modeling fp6 as a composite/“MX” type (as discussed above) rather than a single simple carrier. So these helpers might be useful for the common divisors {1,2,4,8} but outliers would need a separate path.
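A sketch of the stride arithmetic a generic pass could derive from such facts (the struct and field names are hypothetical stand-ins for the proposed interface methods; only bit widths that divide 8 evenly are handled, per the divisor caveat above):

```cpp
#include <cassert>

// Illustrative packing facts as plain data, mirroring the proposed
// interface methods.
struct PackingFacts {
  bool isPacked;
  unsigned logicalBitWidth;
  unsigned elementsPerByte; // only meaningful for divisors of 8
};

// Byte stride of a row of numElements values. Unpacked storage is
// modeled as one byte-aligned storage unit per element (as in the
// ternary-as-int8 example).
unsigned rowStrideBytes(const PackingFacts &f, unsigned numElements) {
  if (!f.isPacked)
    return numElements;
  return (numElements + f.elementsPerByte - 1) / f.elementsPerByte;
}
```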

@ZoranZomborat

Misc. point: If someone brings up 6‑bit (fp6) formats those generally require a composite packing scheme (ex: 4+2 bit streams) and don’t divide 8 cleanly. That case argues for either (a) a richer packing descriptor than just elementsPerByte, or (b) modeling fp6 as a composite/“MX” type (as discussed above) rather than a single simple carrier. So these helpers might be useful for the common divisors {1,2,4,8} but outliers would need a separate path.

Sounds good, and yeah FP6 seems hard for sure; not only do you need a combination of 4+2 bit packing, you'd also need to know at which point in the fp6 sequence you are so as to correctly decode:
[[6+2],[4,4],[2,6] ....] definitely a head scratcher;
As you said, int2 is also getting a lot of traction, and there are more and more 4-bit float variants;
Even if HW has some native int6/fp6 processing, I see a lot of challenges in getting this functional and optimal.
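The fp6 decode-position problem is visible directly in the bit arithmetic; a tiny standalone sketch:

```cpp
#include <cassert>
#include <cstdint>

// Element k of a densely packed 6-bit stream starts at bit 6*k, so its
// phase within a byte cycles 0, 6, 4, 2, 0, ... producing the
// [6+2], [4,4], [2,6] straddle pattern mentioned above.
uint32_t startByte(uint32_t k) { return (6 * k) / 8; }
uint32_t startBitInByte(uint32_t k) { return (6 * k) % 8; }
```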

@ZoranZomborat

Related to MX types, I still would argue that we can separate their datatype and quantization details, and represent them; for example, MXFP4 as simple f4E2M1FN storage together with a sub-channel quantization scheme.

Like in the original proposal https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694

Let's imagine a tensor with the following specs:
  //   tensor<16384x3072xMXFP4>
  //   quantizationDimensions : [0,1]
  //   blockSizes: [1,32]
  //   scales: [[s0/0, s0/95], [s1/0,s1/95], .., [s16383/0, ... s16383/95]] : tensor<16384x96xf8E8M0FN>

Then we'd have the following
tensor<16384x3072x!quant.uniform<i8:f32:{0:1, 1:32}:{f8E8M0FN}, {{s0/0, s0/95}, {s1/0,s1/95}, .., {s16383/0, ... s16383/95}}>> 

Please correct me if I'm missing any MXFP detail, which can't be covered by above representation;
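Under that representation, the element-to-scale mapping is plain integer division by the block sizes; a tiny sketch using the shapes above:

```cpp
#include <cassert>
#include <utility>

// With quantizationDimensions [0,1] and blockSizes [1,32], element (i, j)
// of the 16384x3072 tensor reads its scale from position (i / 1, j / 32)
// of the 16384x96 scale tensor (3072 / 32 == 96).
std::pair<int, int> scaleIndex(int i, int j) { return {i / 1, j / 32}; }
```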

Also, even more interesting, back to the discussion around DynamicDequantize/Dequantize operations; we're discovering that with the continuous increase in quantization params size, it becomes more optimal to provide them at runtime as operands rather than attributes to these operations, even in the static scenario;
This of course would differ based on HW, compiler, runtime implementations, but we're seeing better performance handling them as such and having a more finer grain control over their scheduling;
Not a general rule to follow, but in a scenario where we represent MXFP types in a dialect closer to HW, we'd no longer choose to embed the scale information as part of the tensor/datatype information.

@Roman-Pevnyi Roman-Pevnyi force-pushed the quant_dialect_add_quantization_interface branch from 6e4b4e8 to d7447a8 on August 20, 2025 11:28
@Roman-Pevnyi
Author

In the last commit I have added 4 interface methods which expose a few basic packing and alignment facts.
@anuragsingh-tt please take a look.

@anuragsingh-tt

we'd no longer choose to embed the scale information as part of the tensor/datatype information

+1. Hardware will extract these separately.

Define a non-uniform quant type, say quant.nonuniform, as a tuple of quant.uniform types, i.e. (quant.uniform, quant.uniform), whose 1st element signifies the value and whose 2nd element is the scale.

What are the issues with instead treating scales/zps as SSA operands on the consuming ops in the case of dynamic quantization?

@shahidact
Contributor

What are the issues with instead treating scales/zps as SSA operands on the consuming ops in the case of dynamic quantization?

The fundamental issue is that quantization is a type-level property, not just an operational detail. When we strip this information from types and relegate it to operation operands, we lose:

  • Compiler reasoning capabilities
  • Type safety guarantees
  • Optimization opportunities
  • Clean abstraction boundaries

@anuragsingh-tt

The fundamental issue is that quantization is a type-level property, not just an operational detail. When we strip this information from types and relegate it to operation operands, we lose:

Quant-as-type is useful for preserving model intent from a frontend perspective, but quantization is ultimately a numerical transformation that is eventually lowered to plain ints plus params by hardware.

  • Compiler reasoning capabilities

Wouldn't exposing scale/zp as operands increase reasoning opportunities (e.g. CSE, LICM), since those transformations operate on values?

  • Type safety guarantees

This is indeed a plus of the tuple-type approach. Though couldn't the guarantees that we can ensure (storage type vs. expressed type legality, legal axis in the per-channel case, etc.) still be expressed as type parameters?

@shahidact
Contributor

Wouldn't exposing scale/zp as operands increase reasoning opportunities (e.g: CSE, LICM) because they operate on values?

Yes, of course; however, we don't want to expose them at a higher level, as we want to keep the semantics and patterns available for any potential transformation at this level.
