Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/cmake/modules/AddMLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ function(add_mlir_interface interface)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

# Declare a dialect in the include directory
function(add_mlir_type_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIR${interface}IncGen)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

# Generate Documentation
function(add_mlir_doc doc_filename output_file output_directory command)
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//

#include "mlir/IR/QuantStorageTypeInterface.h"

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"

Expand Down
101 changes: 98 additions & 3 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/QuantStorageTypeInterface.td"
include "mlir/IR/CommonTypeConstraints.td"

// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
Expand Down Expand Up @@ -100,7 +101,8 @@ class Builtin_CachedFloatType<string name, string mnemonic,
// Float8E5M2Type
//===----------------------------------------------------------------------===//

def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
["QuantStorageTypeInterface"]> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
Expand All @@ -116,6 +118,33 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {

Described in: https://arxiv.org/abs/2209.05433
}];

let extraClassDeclaration = [{
/// QuantStorageTypeInterface method implementations
bool isStorageSigned() const { return true; }
/// Get the bit width of this 8-bit floating point type.
unsigned getStorageWidth() const { return 8; }

/// Get default maximum value for this 8-bit floating point type.
int64_t getDefaultMaximum() const { return 57344; }
/// Get default minimum value for this 8-bit floating point type.
int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }

/// Get the storage type as a string.
std::string getStorageType() const { return "f8E5M2"; }

/// Check if this 8-bit floating point type uses packed representation.
bool isPacked() const { return false; }

/// Get the logical bit width per value for this 8-bit floating point type.
unsigned getLogicalBitWidth() const { return 8; }

/// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
unsigned getElementsPerByte() const { return 1; }

/// Get the preferred alignment in bytes for this 8-bit floating point type.
std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}

//===----------------------------------------------------------------------===//
Expand All @@ -142,7 +171,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
// Float8E4M3FNType
//===----------------------------------------------------------------------===//

def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
["QuantStorageTypeInterface"]> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
Expand All @@ -159,6 +189,33 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {

Described in: https://arxiv.org/abs/2209.05433
}];

let extraClassDeclaration = [{
/// QuantStorageTypeInterface method implementations
bool isStorageSigned() const { return true; }
/// Get the bit width of this 8-bit floating point type.
unsigned getStorageWidth() const { return 8; }

/// Get default maximum value for this 8-bit floating point type.
int64_t getDefaultMaximum() const { return 448; }
/// Get default minimum value for this 8-bit floating point type.
int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }

/// Get the storage type as a string.
std::string getStorageType() const { return "f8E4M3FN"; }

/// Check if this 8-bit floating point type uses packed representation.
bool isPacked() const { return false; }

/// Get the logical bit width per value for this 8-bit floating point type.
unsigned getLogicalBitWidth() const { return 8; }

/// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
unsigned getElementsPerByte() const { return 1; }

/// Get the preferred alignment in bytes for this 8-bit floating point type.
std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -497,7 +554,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//

def Builtin_Integer : Builtin_Type<"Integer", "integer",
[VectorElementTypeInterface]> {
[VectorElementTypeInterface, QuantStorageTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
Expand Down Expand Up @@ -554,6 +611,44 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;

/// QuantStorageTypeInterface method implementations
/// Return true if this is a signed or signless integer type.
bool isStorageSigned() const { return !isUnsigned(); }
/// Get the bit width of this integer type.
unsigned getStorageWidth() const { return getWidth(); }

/// Get default maximum value for this integer type.
int64_t getDefaultMaximum() const {
if (isStorageSigned()) {
return llvm::maxIntN(getStorageWidth());
}
return llvm::maxUIntN(getStorageWidth());
}
/// Get default minimum value for this integer type.
int64_t getDefaultMinimum() const {
if (isStorageSigned()) {
return llvm::minIntN(getStorageWidth());
}
return 0;
}

/// Get the storage type as a string.
std::string getStorageType() const {
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
}

/// Check if this integer type uses packed representation.
bool isPacked() const { return false; }

/// Get the logical bit width per value for this integer type.
unsigned getLogicalBitWidth() const { return getWidth(); }

/// Get the number of logical elements that fit in one byte for this integer type.
unsigned getElementsPerByte() const { return 1; }

/// Get the preferred alignment in bytes for this integer type.
std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)

add_mlir_type_interface(QuantStorageTypeInterface)

set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
Expand Down
22 changes: 22 additions & 0 deletions mlir/include/mlir/IR/QuantStorageTypeInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- QuantStorageTypeInterface.h - Quantzation Interfaces --------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_QuantStorageTypeInterface_H
#define MLIR_IR_QuantStorageTypeInterface_H

#include "mlir/IR/Types.h"

// Forward declarations for the types we need in the implementation
namespace mlir {
class IntegerType;
} // namespace mlir

#include "mlir/IR/QuantStorageTypeInterface.h.inc"

#endif // MLIR_IR_QuantStorageTypeInterface_H
69 changes: 69 additions & 0 deletions mlir/include/mlir/IR/QuantStorageTypeInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#ifndef MLIR_IR_QUANTSTORAGETYPEINTERFACE
#define MLIR_IR_QUANTSTORAGETYPEINTERFACE

include "mlir/IR/OpBase.td"

def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
let description = [{
Interface for types that can be used as storage types in Quant dialect.
This interface provides methods to determine storage characteristics for quantization purposes,
including packing behavior, and alignment requirements.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Check if the storage type is signed.
Returns true if the type represents signed values, false for unsigned.
}],
"bool", "isStorageSigned", (ins)>,

InterfaceMethod<[{
Get the bit width of this type.
Returns the number of bits used to store values of this type.
}],
"unsigned", "getStorageWidth", (ins)>,

InterfaceMethod<[{
Get default minimum value for this type.
}],
"int64_t", "getDefaultMinimum", (ins)>,

InterfaceMethod<[{
Get default maximum value for this type.
}],
"int64_t", "getDefaultMaximum", (ins)>,

InterfaceMethod<[{
Get the storage type as a string.
}],
"std::string", "getStorageType", (ins)>,

InterfaceMethod<[{
Check if the storage type uses packed representation.
Returns true if multiple values are packed into one byte (e.g., sub-byte types),
false if value uses full byte.
}],
"bool", "isPacked", (ins)>,

InterfaceMethod<[{
Get the logical bit width per value.
For packed sub-byte types, this may differ from getStorageWidth().
}],
"unsigned", "getLogicalBitWidth", (ins)>,

InterfaceMethod<[{
Get the number of logical elements that fit in one byte.
For packed sub-byte types, this returns how many values can be stored per byte.
}],
"unsigned", "getElementsPerByte", (ins)>,

InterfaceMethod<[{
Returns the preferred alignment for this type, in bytes.
}],
"std::optional<unsigned>", "getPreferredAlignmentBytes", (ins)>
];

}

#endif // MLIR_IR_QUANTSTORAGETYPEINTERFACE
70 changes: 34 additions & 36 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/IR/QuantStorageTypeInterface.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
Expand Down Expand Up @@ -46,32 +47,27 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
if (!intStorageType)
return emitError() << "storage type must be integral";
unsigned integralWidth = intStorageType.getWidth();

// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitError() << "illegal storage type size: " << integralWidth;

// Verify storageTypeMin and storageTypeMax.
bool isSigned =
(flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
int64_t defaultIntegerMin =
getDefaultMinimumForInteger(isSigned, integralWidth);
int64_t defaultIntegerMax =
getDefaultMaximumForInteger(isSigned, integralWidth);
if (storageTypeMax - storageTypeMin <= 0 ||
storageTypeMin < defaultIntegerMin ||
storageTypeMax > defaultIntegerMax) {
return emitError() << "illegal storage min and storage max: ("
<< storageTypeMin << ":" << storageTypeMax << ")";
if (auto quantStorageTypeInterface =
llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
unsigned integralWidth = quantStorageTypeInterface.getStorageWidth();

// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitError() << "illegal storage type size: " << integralWidth;

int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();

if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
storageTypeMax > defaultMax) {
return emitError() << "illegal storage min and storage max: ("
<< storageTypeMin << ":" << storageTypeMax << ")";
}

return success();
}
return success();

return emitError() << "storage type must implement QuantStorageTypeInterface";
}

Type QuantizedType::getStorageType() const {
Expand All @@ -87,20 +83,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
}

bool QuantizedType::hasStorageTypeBounds() const {
unsigned int integralWidth = getStorageTypeIntegralWidth();
bool isSignedInteger = isSigned();
int64_t defaultIntegerMin =
getDefaultMinimumForInteger(isSignedInteger, integralWidth);
int64_t defaultIntegerMax =
getDefaultMaximumForInteger(isSignedInteger, integralWidth);
return defaultIntegerMin != getStorageTypeMin() ||
defaultIntegerMax != getStorageTypeMax();
Type storageType = static_cast<ImplType *>(impl)->storageType;
auto quantStorageTypeInterface =
llvm::dyn_cast<QuantStorageTypeInterface>(storageType);

int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();

return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
}

unsigned QuantizedType::getStorageTypeIntegralWidth() const {
// NOTE: If ever supporting non-integral storage types, some other scheme
// for determining the width will be needed.
return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
Type storageType = static_cast<ImplType *>(impl)->storageType;
auto quantStorageTypeInterface =
llvm::dyn_cast<QuantStorageTypeInterface>(storageType);

return quantStorageTypeInterface.getStorageWidth();
}

Type QuantizedType::getExpressedType() const {
Expand Down
Loading